Coverage for transformer_lens/model_bridge/generalized_components/t5gemma_decoder_block.py: 36%
64 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""T5Gemma-specific decoder block bridge.
3T5GemmaDecoderLayer uses Gemma-style flat attribute access (not T5's .layer[] indexing).
4It has: self-attention + cross-attention + MLP, each with pre/post norms.
5This bridge monkey-patches the layer forward to insert intermediate hook points.
6"""
7from __future__ import annotations
9import types
10from typing import Any, Callable, Dict, Optional
12import torch
14from transformer_lens.hook_points import HookPoint
15from transformer_lens.model_bridge.generalized_components.base import (
16 GeneralizedComponent,
17)
class T5GemmaDecoderBlockBridge(GeneralizedComponent):
    """Bridge for T5Gemma decoder layers.

    Inserts hook points between the three sub-components of each decoder layer:
    - hook_in (hook_resid_pre): residual before self-attention pre-norm
    - hook_resid_mid: residual after self-attention + residual add, before cross-attn pre-norm
    - hook_resid_mid2: residual after cross-attention + residual add, before MLP pre-norm
    - hook_out (hook_resid_post): residual after MLP + residual add

    The hooks are fired from a replacement ``forward`` bound onto the wrapped
    layer instance, so they run whenever the layer module is called, whether
    through this bridge's ``forward`` or directly by its parent module.
    """

    is_list_item: bool = True
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Create the bridge and register the two intermediate residual hooks.

        Args:
            name: Name of the wrapped component.
            config: Optional model config; ``n_layers`` is read from it in
                ``get_list_size``.
            submodules: Optional mapping of sub-bridge name -> bridge.
        """
        super().__init__(name, config, submodules=submodules or {})
        self.hook_resid_mid = HookPoint()
        self._register_hook("hook_resid_mid", self.hook_resid_mid)
        self.hook_resid_mid2 = HookPoint()
        self._register_hook("hook_resid_mid2", self.hook_resid_mid2)
        # The layer's forward as it was just before patching; None until
        # set_original_component() has been called.
        self._original_block_forward: Optional[Callable[..., Any]] = None

    def set_original_component(self, component: torch.nn.Module) -> None:
        """Attach the HF decoder layer and patch its forward to fire the hooks."""
        super().set_original_component(component)
        self._patch_decoder_layer_forward()

    def _patch_decoder_layer_forward(self) -> None:
        """Monkey-patch T5GemmaDecoderLayer.forward to insert hook points.

        The patched forward preserves the original residual-stream semantics but
        fires hook_in, hook_resid_mid, hook_resid_mid2, and hook_out at the
        canonical HookedTransformer positions. The replacement is bound to the
        layer instance only (via ``types.MethodType``); the class is untouched.
        """
        if self.original_component is None:
            return
        self._original_block_forward = self.original_component.forward

        # The closure captures only the hook points, not the bridge itself.
        hook_in = self.hook_in  # fires at hook_resid_pre
        hook_resid_mid = self.hook_resid_mid
        hook_resid_mid2 = self.hook_resid_mid2
        hook_out = self.hook_out  # fires at hook_resid_post

        def patched_forward(
            layer_self: Any,
            hidden_states: torch.Tensor,
            position_embeddings: Any = None,
            attention_mask: Any = None,
            position_ids: Any = None,
            past_key_values: Any = None,
            use_cache: Any = None,
            cache_position: Any = None,
            encoder_hidden_states: Any = None,
            encoder_attention_mask: Any = None,
            **kwargs: Any,
        ) -> torch.Tensor:
            # NOTE: parameter order is meant to mirror the upstream
            # T5GemmaDecoderLayer.forward (cache_position sits between
            # use_cache and encoder_hidden_states) so that positional calls
            # from the parent decoder bind to the right names. Verify against
            # the installed transformers version if the upstream signature
            # changes.
            hidden_states = hook_in(hidden_states)

            # --- self-attention sub-layer ---
            residual = hidden_states
            hidden_states = layer_self.pre_self_attn_layernorm(hidden_states)
            sa_out, _ = layer_self.self_attn(
                hidden_states=hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=(
                    past_key_values.self_attention_cache if past_key_values is not None else None
                ),
                use_cache=use_cache,
                # cache_position is a self-attention concern only; it is
                # deliberately not forwarded to the cross-attention call.
                cache_position=cache_position,
                **kwargs,
            )
            hidden_states = layer_self.post_self_attn_layernorm(sa_out)
            hidden_states = residual + layer_self.dropout(hidden_states)
            hidden_states = hook_resid_mid(hidden_states)

            # --- cross-attention sub-layer ---
            residual = hidden_states
            hidden_states = layer_self.pre_cross_attn_layernorm(hidden_states)
            ca_out, _ = layer_self.cross_attn(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )
            hidden_states = layer_self.post_cross_attn_layernorm(ca_out)
            hidden_states = residual + layer_self.dropout(hidden_states)
            hidden_states = hook_resid_mid2(hidden_states)

            # --- MLP sub-layer ---
            residual = hidden_states
            hidden_states = layer_self.pre_feedforward_layernorm(hidden_states)
            hidden_states = layer_self.mlp(hidden_states)
            hidden_states = layer_self.post_feedforward_layernorm(hidden_states)
            hidden_states = residual + layer_self.dropout(hidden_states)
            hidden_states = hook_out(hidden_states)

            return hidden_states

        self.original_component.forward = types.MethodType(patched_forward, self.original_component)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped (patched) decoder layer.

        Raises:
            RuntimeError: If no original component has been attached yet.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )
        return self.original_component(*args, **kwargs)

    def get_expected_parameter_names(self, prefix: str = "") -> list[str]:
        """Collect expected parameter names from all sub-bridges.

        Each sub-bridge's names are namespaced as ``"{prefix}.{sub_name}"``
        (or just ``sub_name`` when ``prefix`` is empty).
        """
        param_names: list[str] = []
        for sub_name, sub_component in self.submodules.items():
            sub_prefix = f"{prefix}.{sub_name}" if prefix else sub_name
            param_names.extend(sub_component.get_expected_parameter_names(sub_prefix))
        return param_names

    def get_list_size(self) -> int:
        """Return ``config.n_layers``, or 0 when there is no config or no such field."""
        if self.config is None:
            return 0
        return getattr(self.config, "n_layers", 0)