Coverage for transformer_lens/model_bridge/generalized_components/t5_block.py: 68%
92 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""T5-specific block bridge component.
3This module contains the bridge component for T5 blocks, which have a different
4structure than standard transformer blocks (3 layers in decoder vs 2 layers).
5"""
6from __future__ import annotations
8import types
9from typing import Any, Callable, Dict, Optional
11import torch
13from transformer_lens.hook_points import HookPoint
14from transformer_lens.model_bridge.generalized_components.base import (
15 GeneralizedComponent,
16)
class T5BlockBridge(GeneralizedComponent):
    """Bridge component for T5 transformer blocks.

    T5 has two types of blocks:
    - Encoder blocks: 2 layers (self-attention, feed-forward)
    - Decoder blocks: 3 layers (self-attention, cross-attention, feed-forward)

    This bridge handles both types based on the presence of cross-attention.
    """

    is_list_item: bool = True
    hook_aliases = {"hook_resid_pre": "hook_in", "hook_resid_post": "hook_out"}
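
    # A minimal usage sketch (hypothetical setup: `bridge` is a TransformerBridge
    # around a T5 model whose decoder blocks are T5BlockBridge instances, and
    # `tokens` is an input batch; the exact hook path depends on how the bridge
    # names its children):
    #
    #     def ablate_mid(resid, hook):
    #         return torch.zeros_like(resid)
    #
    #     logits = bridge.run_with_hooks(
    #         tokens,
    #         fwd_hooks=[("decoder.blocks.0.hook_resid_mid", ablate_mid)],
    #     )
    #
    # `hook_resid_pre` / `hook_resid_post` resolve to `hook_in` / `hook_out`
    # through the `hook_aliases` mapping above.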
    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        is_decoder: bool = False,
    ):
        """Initialize the T5 block bridge.

        Args:
            name: The name of the component in the model
            config: Optional configuration
            submodules: Dictionary of submodules to register
            is_decoder: Whether this is a decoder block (has cross-attention)
        """
        super().__init__(name, config, submodules=submodules or {})
        self.is_decoder = is_decoder

        # Residual hook after self-attention; decoder blocks get a second one
        # after cross-attention.
        self.hook_resid_mid = HookPoint()
        self._register_hook("hook_resid_mid", self.hook_resid_mid)
        if is_decoder:
            self.hook_resid_mid2 = HookPoint()
            self._register_hook("hook_resid_mid2", self.hook_resid_mid2)

        self._original_block_forward: Optional[Callable[..., Any]] = None
    def set_original_component(self, component: torch.nn.Module):
        """Set the original component and monkey-patch its forward method.

        Args:
            component: The original PyTorch module to wrap
        """
        super().set_original_component(component)
        self._patch_t5_block_forward()
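
    # Note: `_original_block_forward` (set in `_patch_t5_block_forward` below)
    # keeps a reference to the unpatched bound method, so the patch could be
    # undone if needed (a sketch; `block_bridge` is a hypothetical instance):
    #
    #     block_bridge.original_component.forward = block_bridge._original_block_forward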
    def _patch_t5_block_forward(self):
        """Monkey-patch the T5 block's forward method to insert hooks."""
        if self.original_component is None:
            return
        # Keep a reference to the unpatched forward so the patch is reversible.
        self._original_block_forward = self.original_component.forward
        # The signature mirrors transformers' T5Block.forward so the patched
        # method is a drop-in replacement; `block_self` is the T5 block module.
        def patched_forward(
            block_self,
            hidden_states,
            attention_mask=None,
            position_bias=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            encoder_decoder_position_bias=None,
            layer_head_mask=None,
            cross_attn_layer_head_mask=None,
            past_key_value=None,
            use_cache=False,
            output_attentions=False,
            return_dict=True,
            cache_position=None,
            **kwargs,
        ):
88 """Patched T5 block forward with hooks."""
89 import inspect
91 hidden_states = self.hook_in(hidden_states)
92 if not hasattr(block_self, "layer"): 92 ↛ 93line 92 didn't jump to line 93 because the condition on line 92 was never true
93 raise RuntimeError(f"T5 block {block_self} does not have 'layer' attribute")
94 layers = block_self.layer
95 is_decoder_block = len(layers) == 3
97 # Check which parameters are accepted by the layer forward methods
98 # (Transformers v5 removed past_key_value, use_cache, layer_head_mask)
99 self_attn_params = set(inspect.signature(layers[0].forward).parameters.keys())
101 if "past_key_value" in self_attn_params and past_key_value is not None: 101 ↛ 102line 101 didn't jump to line 102 because the condition on line 101 was never true
102 if not is_decoder_block:
103 expected_num_past_key_values = 0
104 else:
105 expected_num_past_key_values = 2
106 if len(past_key_value) != expected_num_past_key_values:
107 raise ValueError(
108 f"There should be {expected_num_past_key_values} past states. Got {len(past_key_value)}."
109 )
110 self_attn_past_key_value = past_key_value[:2] if is_decoder_block else None
111 cross_attn_past_key_value = past_key_value[2:4] if is_decoder_block else None
112 else:
113 self_attn_past_key_value = None
114 cross_attn_past_key_value = None
            self_attn_kwargs = dict(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_bias=position_bias,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            # Conditionally pass parameters removed in Transformers v5.
            if "past_key_value" in self_attn_params:
                self_attn_kwargs["past_key_value"] = self_attn_past_key_value
            if "use_cache" in self_attn_params:
                self_attn_kwargs["use_cache"] = use_cache
            if "layer_head_mask" in self_attn_params:
                self_attn_kwargs["layer_head_mask"] = layer_head_mask

            self_attention_outputs = layers[0](**self_attn_kwargs)
            hidden_states = self_attention_outputs[0]
            # Keep self-attention outputs and relative position weights;
            # attention_outputs contains (position_bias,) or (position_bias, attn_weights).
            attention_outputs = self_attention_outputs[1:]

            hidden_states = self.hook_resid_mid(hidden_states)
            if is_decoder_block and encoder_hidden_states is not None:
                cross_attn_params = set(inspect.signature(layers[1].forward).parameters.keys())
                cross_attn_kwargs = dict(
                    hidden_states=hidden_states,
                    key_value_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    position_bias=encoder_decoder_position_bias,
                    output_attentions=output_attentions,
                    cache_position=cache_position,
                )
                if "past_key_value" in cross_attn_params:
                    cross_attn_kwargs["past_key_value"] = cross_attn_past_key_value
                if "use_cache" in cross_attn_params:
                    cross_attn_kwargs["use_cache"] = use_cache
                if "layer_head_mask" in cross_attn_params:
                    cross_attn_kwargs["layer_head_mask"] = cross_attn_layer_head_mask

                cross_attention_outputs = layers[1](**cross_attn_kwargs)
                hidden_states = cross_attention_outputs[0]
                if hasattr(self, "hook_resid_mid2"):
                    hidden_states = self.hook_resid_mid2(hidden_states)
                # Keep cross-attention outputs and relative position weights.
                attention_outputs = attention_outputs + cross_attention_outputs[1:]
            # The feed-forward layer comes last: index 2 in decoder blocks, 1 in encoder blocks.
            ff_layer_idx = 2 if is_decoder_block else 1
            feed_forward_outputs = layers[ff_layer_idx](hidden_states)
            # T5LayerFF returns a tensor, not a tuple.
            if isinstance(feed_forward_outputs, tuple):
                hidden_states = feed_forward_outputs[0]
            else:
                hidden_states = feed_forward_outputs

            hidden_states = self.hook_out(hidden_states)

            outputs: tuple[Any, ...] = (hidden_states,)
            # Return: hidden-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            return outputs + attention_outputs

        self.original_component.forward = types.MethodType(patched_forward, self.original_component)
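
    # The `inspect.signature` checks above are what keep `patched_forward`
    # working on both sides of the Transformers v5 API change. The same pattern
    # in isolation (a sketch; `layer`, `hidden_states`, and `use_cache` are
    # hypothetical stand-ins):
    #
    #     import inspect
    #
    #     accepted = set(inspect.signature(layer.forward).parameters)
    #     kwargs = {"hidden_states": hidden_states}
    #     if "use_cache" in accepted:  # dropped from T5 layers in Transformers v5
    #         kwargs["use_cache"] = use_cache
    #     outputs = layer(**kwargs)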
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the block bridge.

        Args:
            *args: Input arguments
            **kwargs: Input keyword arguments

        Returns:
            The output from the original component
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        return self.original_component(*args, **kwargs)
    def get_expected_parameter_names(self, prefix: str = "") -> list[str]:
        """Get the expected TransformerLens parameter names for this block.

        Args:
            prefix: Prefix to add to parameter names (e.g., "blocks.0")

        Returns:
            List of expected parameter names in TransformerLens format
        """
        param_names: list[str] = []
        for sub_name, sub_component in self.submodules.items():
            sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
            param_names.extend(sub_component.get_expected_parameter_names(sub_prefix))
        return param_names
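
    # E.g. with prefix "blocks.0" and submodules {"attn": ..., "mlp": ...}, this
    # delegates as get_expected_parameter_names("blocks.0.attn") and so on, so
    # the result looks like ["blocks.0.attn.W_Q", ...] (the exact names depend
    # on each submodule bridge).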
    def get_list_size(self) -> int:
        """Get the number of transformer blocks.

        Returns:
            Number of layers in the model, or 0 if no config is set
        """
        if self.config is None:
            return 0
        return getattr(self.config, "n_layers", 0)
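
# A construction sketch (hypothetical wiring; in practice the T5 architecture
# adapter registers these bridges and fills in the submodules):
#
#     block = T5BlockBridge(
#         name="decoder.block.0",
#         config=cfg,          # expected to expose `n_layers` for get_list_size
#         submodules={},       # e.g. attention / cross-attention / MLP bridges
#         is_decoder=True,     # decoder blocks get hook_resid_mid2
#     )
#     block.set_original_component(hf_t5.decoder.block[0])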