Coverage for transformer_lens/model_bridge/generalized_components/t5gemma_decoder_block.py: 36%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""T5Gemma-specific decoder block bridge. 

2 

3T5GemmaDecoderLayer uses Gemma-style flat attribute access (not T5's .layer[] indexing). 

4It has: self-attention + cross-attention + MLP, each with pre/post norms. 

5This bridge monkey-patches the layer forward to insert intermediate hook points. 

6""" 

7from __future__ import annotations 

8 

9import types 

10from typing import Any, Callable, Dict, Optional 

11 

12import torch 

13 

14from transformer_lens.hook_points import HookPoint 

15from transformer_lens.model_bridge.generalized_components.base import ( 

16 GeneralizedComponent, 

17) 

18 

19 

class T5GemmaDecoderBlockBridge(GeneralizedComponent):
    """Bridge for T5Gemma decoder layers.

    Inserts hook points between the three sub-components of each decoder layer:
    - hook_in (hook_resid_pre): residual before self-attention pre-norm
    - hook_resid_mid: residual after self-attention + residual add, before cross-attn pre-norm
    - hook_resid_mid2: residual after cross-attention + residual add, before MLP pre-norm
    - hook_out (hook_resid_post): residual after MLP + residual add
    """

    # Marks this bridge as describing one element of a list of layers.
    is_list_item: bool = True
    # HookedTransformer-style names mapped onto the bridge's own hook names.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
    }

    # Attribute stored on the patched forward function so that a later re-patch
    # can recover the layer's genuine, unpatched forward.
    _UNPATCHED_FORWARD_ATTR = "_t5gemma_unpatched_forward"

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Create the bridge and register its two intermediate hook points.

        Args:
            name: Component name forwarded to ``GeneralizedComponent``.
            config: Optional configuration object; ``get_list_size`` reads
                ``n_layers`` from it.
            submodules: Optional mapping of sub-component name to bridge;
                an empty dict is used when omitted.
        """
        super().__init__(name, config, submodules=submodules or {})
        self.hook_resid_mid = HookPoint()
        self._register_hook("hook_resid_mid", self.hook_resid_mid)
        self.hook_resid_mid2 = HookPoint()
        self._register_hook("hook_resid_mid2", self.hook_resid_mid2)
        # Filled in by _patch_decoder_layer_forward with the layer's own forward.
        self._original_block_forward: Optional[Callable[..., Any]] = None

    def set_original_component(self, component: torch.nn.Module) -> None:
        """Attach the wrapped decoder layer and patch its forward with hooks."""
        super().set_original_component(component)
        self._patch_decoder_layer_forward()

    def _patch_decoder_layer_forward(self) -> None:
        """Monkey-patch T5GemmaDecoderLayer.forward to insert hook points.

        The patched forward preserves the original residual-stream semantics but
        fires hook_in, hook_resid_mid, hook_resid_mid2, and hook_out at the
        canonical HookedTransformer positions.

        Does nothing when no original component has been set. When the layer has
        already been patched, the genuine unpatched forward is recovered from
        the existing patch, so ``_original_block_forward`` never ends up
        pointing at a patched function.
        """
        component = self.original_component
        if component is None:
            return

        # Bound methods expose attributes of their underlying function, so a
        # previously installed patch hands back the forward it replaced.
        current_forward = component.forward
        self._original_block_forward = getattr(
            current_forward, self._UNPATCHED_FORWARD_ATTR, current_forward
        )

        # Capture the hook points directly so the closure does not need `self`.
        hook_in = self.hook_in  # fires at hook_resid_pre
        hook_resid_mid = self.hook_resid_mid
        hook_resid_mid2 = self.hook_resid_mid2
        hook_out = self.hook_out  # fires at hook_resid_post

        def patched_forward(
            layer_self,
            hidden_states: torch.Tensor,
            position_embeddings=None,
            attention_mask=None,
            position_ids=None,
            past_key_values=None,
            use_cache=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            **kwargs: Any,
        ) -> torch.Tensor:
            hidden_states = hook_in(hidden_states)

            # --- self-attention sub-layer ---
            residual = hidden_states
            hidden_states = layer_self.pre_self_attn_layernorm(hidden_states)
            # Self-attention only receives the self-attention part of the cache.
            sa_out, _ = layer_self.self_attn(
                hidden_states=hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=(
                    past_key_values.self_attention_cache if past_key_values is not None else None
                ),
                use_cache=use_cache,
                **kwargs,
            )
            hidden_states = layer_self.post_self_attn_layernorm(sa_out)
            hidden_states = residual + layer_self.dropout(hidden_states)
            hidden_states = hook_resid_mid(hidden_states)

            # --- cross-attention sub-layer ---
            residual = hidden_states
            hidden_states = layer_self.pre_cross_attn_layernorm(hidden_states)
            # Cross-attention is handed the full cache object and the encoder mask.
            ca_out, _ = layer_self.cross_attn(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )
            hidden_states = layer_self.post_cross_attn_layernorm(ca_out)
            hidden_states = residual + layer_self.dropout(hidden_states)
            hidden_states = hook_resid_mid2(hidden_states)

            # --- MLP sub-layer ---
            residual = hidden_states
            hidden_states = layer_self.pre_feedforward_layernorm(hidden_states)
            hidden_states = layer_self.mlp(hidden_states)
            hidden_states = layer_self.post_feedforward_layernorm(hidden_states)
            hidden_states = residual + layer_self.dropout(hidden_states)
            hidden_states = hook_out(hidden_states)

            return hidden_states

        # Remember what this patch replaced so a later re-patch can unwrap it.
        setattr(patched_forward, self._UNPATCHED_FORWARD_ATTR, self._original_block_forward)
        component.forward = types.MethodType(patched_forward, component)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped decoder layer with the given arguments.

        Raises:
            RuntimeError: If no original component has been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )
        return self.original_component(*args, **kwargs)

    def get_expected_parameter_names(self, prefix: str = "") -> list[str]:
        """Collect the expected parameter names of every submodule.

        Args:
            prefix: Optional dotted prefix; when non-empty each submodule name
                is appended to it as ``"{prefix}.{sub_name}"``.

        Returns:
            The concatenated names reported by each submodule, in the
            iteration order of ``self.submodules``.
        """
        param_names: list[str] = []
        for sub_name, sub_component in self.submodules.items():
            sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
            param_names.extend(sub_component.get_expected_parameter_names(sub_prefix))
        return param_names

    def get_list_size(self) -> int:
        """Return ``config.n_layers``, or 0 when the config or attribute is missing."""
        if self.config is None:
            return 0
        return getattr(self.config, "n_layers", 0)