Coverage for transformer_lens/model_bridge/generalized_components/t5_block.py: 68%

92 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

  1  """T5-specific block bridge component.
  2
  3  This module contains the bridge component for T5 blocks, which have a different
  4  structure than standard transformer blocks (3 layers in decoder vs 2 layers).
  5  """
  6  from __future__ import annotations
  7
  8  import types
  9  from typing import Any, Callable, Dict, Optional
 10
 11  import torch
 12
 13  from transformer_lens.hook_points import HookPoint
 14  from transformer_lens.model_bridge.generalized_components.base import (
 15      GeneralizedComponent,
 16  )
 17
 18
 19  class T5BlockBridge(GeneralizedComponent):
 20      """Bridge component for T5 transformer blocks.
 21
 22      T5 has two types of blocks:
 23      - Encoder blocks: 2 layers (self-attention, feed-forward)
 24      - Decoder blocks: 3 layers (self-attention, cross-attention, feed-forward)
 25
 26      This bridge handles both types based on the presence of cross-attention.
 27      """
 28
 29      is_list_item: bool = True
 30      hook_aliases = {"hook_resid_pre": "hook_in", "hook_resid_post": "hook_out"}
 31
 32      def __init__(
 33          self,
 34          name: str,
 35          config: Optional[Any] = None,
 36          submodules: Optional[Dict[str, GeneralizedComponent]] = None,
 37          is_decoder: bool = False,
 38      ):
 39          """Initialize the T5 block bridge.
 40
 41          Args:
 42              name: The name of the component in the model
 43              config: Optional configuration
 44              submodules: Dictionary of submodules to register
 45              is_decoder: Whether this is a decoder block (has cross-attention)
 46          """
 47          super().__init__(name, config, submodules=submodules or {})
 48          self.is_decoder = is_decoder
 49          self.hook_resid_mid = HookPoint()
 50          self._register_hook("hook_resid_mid", self.hook_resid_mid)
 51          if is_decoder:
 52              self.hook_resid_mid2 = HookPoint()
 53              self._register_hook("hook_resid_mid2", self.hook_resid_mid2)
 54          self._original_block_forward: Optional[Callable[..., Any]] = None
 55
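As a quick usage sketch (not part of t5_block.py, and assuming the GeneralizedComponent base accepts a None config and an empty submodule dict, as the defaults above suggest), the constructor registers one extra residual hook for encoder blocks and two for decoder blocks:

from transformer_lens.model_bridge.generalized_components.t5_block import T5BlockBridge

# Assumption: the bridge can be built standalone, with no config and no submodules.
encoder_block = T5BlockBridge("block.0", is_decoder=False)
decoder_block = T5BlockBridge("block.0", is_decoder=True)

print(hasattr(encoder_block, "hook_resid_mid"), hasattr(encoder_block, "hook_resid_mid2"))  # True False
print(hasattr(decoder_block, "hook_resid_mid"), hasattr(decoder_block, "hook_resid_mid2"))  # True True

# hook_resid_pre / hook_resid_post are aliases for the base class's hook_in / hook_out.
print(T5BlockBridge.hook_aliases)

hook_resid_mid fires after self-attention and, on decoder blocks, hook_resid_mid2 fires after cross-attention, mirroring the extra layer that only decoders have (see patched_forward below).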

 56      def set_original_component(self, component: torch.nn.Module):
 57          """Set the original component and monkey-patch its forward method.
 58
 59          Args:
 60              component: The original PyTorch module to wrap
 61          """
 62          super().set_original_component(component)
 63          self._patch_t5_block_forward()
 64
 65      def _patch_t5_block_forward(self):
 66          """Monkey-patch the T5 block's forward method to insert hooks."""
 67          if self.original_component is None:  [67 ↛ 68: condition never true]
 68              return
 69          self._original_block_forward = self.original_component.forward
 70
 71          def patched_forward(
 72              block_self,
 73              hidden_states,
 74              attention_mask=None,
 75              position_bias=None,
 76              encoder_hidden_states=None,
 77              encoder_attention_mask=None,
 78              encoder_decoder_position_bias=None,
 79              layer_head_mask=None,
 80              cross_attn_layer_head_mask=None,
 81              past_key_value=None,
 82              use_cache=False,
 83              output_attentions=False,
 84              return_dict=True,
 85              cache_position=None,
 86              **kwargs,
 87          ):
 88              """Patched T5 block forward with hooks."""
 89              import inspect
 90
 91              hidden_states = self.hook_in(hidden_states)
 92              if not hasattr(block_self, "layer"):  [92 ↛ 93: condition never true]
 93                  raise RuntimeError(f"T5 block {block_self} does not have 'layer' attribute")
 94              layers = block_self.layer
 95              is_decoder_block = len(layers) == 3
 96
 97              # Check which parameters are accepted by the layer forward methods
 98              # (Transformers v5 removed past_key_value, use_cache, layer_head_mask)
 99              self_attn_params = set(inspect.signature(layers[0].forward).parameters.keys())
100
101              if "past_key_value" in self_attn_params and past_key_value is not None:  [101 ↛ 102: condition never true]
102                  if not is_decoder_block:
103                      expected_num_past_key_values = 0
104                  else:
105                      expected_num_past_key_values = 2
106                  if len(past_key_value) != expected_num_past_key_values:
107                      raise ValueError(
108                          f"There should be {expected_num_past_key_values} past states. Got {len(past_key_value)}."
109                      )
110                  self_attn_past_key_value = past_key_value[:2] if is_decoder_block else None
111                  cross_attn_past_key_value = past_key_value[2:4] if is_decoder_block else None
112              else:
113                  self_attn_past_key_value = None
114                  cross_attn_past_key_value = None
115              self_attn_kwargs = dict(
116                  hidden_states=hidden_states,
117                  attention_mask=attention_mask,
118                  position_bias=position_bias,
119                  output_attentions=output_attentions,
120                  cache_position=cache_position,
121              )
122              # Conditionally pass parameters removed in Transformers v5
123              if "past_key_value" in self_attn_params:  [123 ↛ 124: condition never true]
124                  self_attn_kwargs["past_key_value"] = self_attn_past_key_value
125              if "use_cache" in self_attn_params:  [125 ↛ 127: condition always true]
126                  self_attn_kwargs["use_cache"] = use_cache
127              if "layer_head_mask" in self_attn_params:  [127 ↛ 128: condition never true]
128                  self_attn_kwargs["layer_head_mask"] = layer_head_mask
129              self_attention_outputs = layers[0](**self_attn_kwargs)
130              hidden_states = self_attention_outputs[0]
131              # Keep self-attention outputs and relative position weights
132              # attention_outputs contains: (position_bias,) or (position_bias, attn_weights)
133              attention_outputs = self_attention_outputs[1:]
134              hidden_states = self.hook_resid_mid(hidden_states)
135              if is_decoder_block and encoder_hidden_states is not None:
136                  cross_attn_params = set(inspect.signature(layers[1].forward).parameters.keys())
137                  cross_attn_kwargs = dict(
138                      hidden_states=hidden_states,
139                      key_value_states=encoder_hidden_states,
140                      attention_mask=encoder_attention_mask,
141                      position_bias=encoder_decoder_position_bias,
142                      output_attentions=output_attentions,
143                      cache_position=cache_position,
144                  )
145                  if "past_key_value" in cross_attn_params:  [145 ↛ 146: condition never true]
146                      cross_attn_kwargs["past_key_value"] = cross_attn_past_key_value
147                  if "use_cache" in cross_attn_params:  [147 ↛ 149: condition always true]
148                      cross_attn_kwargs["use_cache"] = use_cache
149                  if "layer_head_mask" in cross_attn_params:  [149 ↛ 150: condition never true]
150                      cross_attn_kwargs["layer_head_mask"] = cross_attn_layer_head_mask
151                  cross_attention_outputs = layers[1](**cross_attn_kwargs)
152                  hidden_states = cross_attention_outputs[0]
153                  if hasattr(self, "hook_resid_mid2"):  [153 ↛ 156: condition always true]
154                      hidden_states = self.hook_resid_mid2(hidden_states)
155                  # Keep cross-attention outputs and relative position weights
156                  attention_outputs = attention_outputs + cross_attention_outputs[1:]
157              ff_layer_idx = 2 if is_decoder_block else 1
158              feed_forward_outputs = layers[ff_layer_idx](hidden_states)
159              # T5LayerFF returns a tensor, not a tuple
160              if isinstance(feed_forward_outputs, tuple):  [160 ↛ 161: condition never true]
161                  hidden_states = feed_forward_outputs[0]
162              else:
163                  hidden_states = feed_forward_outputs
164              hidden_states = self.hook_out(hidden_states)
165              outputs: tuple[Any, ...] = (hidden_states,)
166              # Return: hidden-states, (self-attention position bias), (self-attention weights),
167              # (cross-attention position bias), (cross-attention weights)
168              return outputs + attention_outputs
169
170          self.original_component.forward = types.MethodType(patched_forward, self.original_component)
171
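The patch above relies on a standard bound-method replacement idiom: keep a handle to the original forward, then bind a wrapper onto the instance with types.MethodType. A minimal sketch of the same idiom on a toy module (all names here are illustrative, not TransformerLens code):

import types

import torch


class ToyBlock(torch.nn.Module):
    def forward(self, x):
        return x + 1


block = ToyBlock()
original_forward = block.forward  # keep a handle, as _original_block_forward does above


def patched_forward(module_self, x):
    # module_self is the ToyBlock instance because MethodType binds the function to it,
    # just as block_self is the wrapped T5 block inside _patch_t5_block_forward.
    return original_forward(x) * 2


# The instance attribute shadows the class-level forward, so Module.__call__ picks it up.
block.forward = types.MethodType(patched_forward, block)
print(block(torch.zeros(1)))  # tensor([2.])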

172      def forward(self, *args: Any, **kwargs: Any) -> Any:
173          """Forward pass through the block bridge.
174
175          Args:
176              *args: Input arguments
177              **kwargs: Input keyword arguments
178
179          Returns:
180              The output from the original component
181          """
182          if self.original_component is None:  [182 ↛ 183: condition never true]
183              raise RuntimeError(
184                  f"Original component not set for {self.name}. Call set_original_component() first."
185              )
186          output = self.original_component(*args, **kwargs)
187          return output
188
189      def get_expected_parameter_names(self, prefix: str = "") -> list[str]:
190          """Get the expected TransformerLens parameter names for this block.
191
192          Args:
193              prefix: Prefix to add to parameter names (e.g., "blocks.0")
194
195          Returns:
196              List of expected parameter names in TransformerLens format
197          """
198          param_names = []
199          for sub_name, sub_component in self.submodules.items():
200              sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
201              param_names.extend(sub_component.get_expected_parameter_names(sub_prefix))
202          return param_names
203
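A small sketch of the prefixing behaviour of get_expected_parameter_names, with the loop mirrored inline; the submodule names ("attn", "mlp") and parameter names are hypothetical stand-ins for whatever sub-bridges the block is registered with:

submodule_params = {
    "attn": ["W_Q", "W_K", "W_V", "W_O"],  # hypothetical names a sub-bridge might report
    "mlp": ["W_in", "W_out"],
}


def expected_names(prefix: str = "") -> list[str]:
    # Mirrors the loop above: each submodule's names are reported under "<prefix>.<sub_name>".
    names = []
    for sub_name, params in submodule_params.items():
        sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
        names.extend(f"{sub_prefix}.{p}" for p in params)
    return names


print(expected_names("blocks.0"))
# ['blocks.0.attn.W_Q', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V', 'blocks.0.attn.W_O',
#  'blocks.0.mlp.W_in', 'blocks.0.mlp.W_out']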

204      def get_list_size(self) -> int:
205          """Get the number of transformer blocks.
206
207          Returns:
208              Number of layers in the model
209          """
210          if self.config is None:
211              return 0
212          return getattr(self.config, "n_layers", 0)
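Finally, a minimal sketch of get_list_size, assuming the base class stores the config object unchanged, so that any object exposing n_layers (here a SimpleNamespace stand-in, not a real bridge config) will do:

from types import SimpleNamespace

from transformer_lens.model_bridge.generalized_components.t5_block import T5BlockBridge

# Assumption: the base class keeps `config` as-is; get_list_size only reads `n_layers` from it.
blocks = T5BlockBridge("blocks", config=SimpleNamespace(n_layers=6))
print(blocks.get_list_size())  # 6
print(T5BlockBridge("blocks").get_list_size())  # 0 when no config is set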