Coverage for transformer_lens/components/transformer_block.py: 71%
110 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Hooked Transformer Transformer Block Component.
3This module contains all the component :class:`TransformerBlock`.
4"""
from typing import Callable, Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.cache.key_value_cache_entry import (
    TransformerLensKeyValueCacheEntry,
)
from transformer_lens.components import (
    Attention,
    GroupedQueryAttention,
    LayerNorm,
    LayerNormPre,
    RMSNorm,
    RMSNormPre,
)
from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.factories.mlp_factory import MLPFactory
from transformer_lens.hook_points import HookPoint
from transformer_lens.utilities import repeat_along_head_dimension
# Transformer Block
class TransformerBlock(nn.Module):
    ln1: nn.Module
    ln2: nn.Module
    mlp: CanBeUsedAsMLP

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        normalization_layer: Callable  # type: ignore
        normalization_layer_after: Callable  # type: ignore

        self.normalization_type = self.cfg.normalization_type

        if self.normalization_type == "LN":
            normalization_layer = LayerNorm
        elif self.normalization_type == "LNPre":  # coverage: always true whenever reached in the measured run
            # We've folded in LayerNorm weights, so just need the center + scale parts
            normalization_layer = LayerNormPre
        elif self.normalization_type == "RMS":
            normalization_layer = RMSNorm
        elif self.normalization_type == "RMSPre":
            normalization_layer = RMSNormPre
        elif self.normalization_type is None:
            # This should just be the identity.
            # We need to make this a lambda so we can call it on the config, just like the others
            normalization_layer = lambda cfg: nn.Identity()
        else:
            raise ValueError(f"Invalid normalization_type passed in: {self.normalization_type}")

        if self.cfg.use_normalization_before_and_after:  # coverage: never true in the measured run
            # If we use LN before and after, we do *not* fold in the weights to the LN
            # after, though we can fold for the one before.
            if self.normalization_type is None:
                normalization_layer_after = lambda cfg: nn.Identity()
            elif self.normalization_type.startswith("RMS"):
                normalization_layer_after = RMSNorm
            elif self.normalization_type.startswith("LayerNorm"):
                normalization_layer_after = LayerNorm

        self.ln1 = normalization_layer(cfg)
        if self.cfg.use_normalization_before_and_after:  # coverage: never true in the measured run
            self.ln1_post = normalization_layer_after(cfg)
        if not self.cfg.attn_only:
            self.ln2 = normalization_layer(cfg)
            if self.cfg.use_normalization_before_and_after:  # coverage: never true in the measured run
                self.ln2_post = normalization_layer_after(cfg)

        attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention
        if not self.cfg.use_local_attn:
            self.attn = attention(self.cfg, "global", block_index)
        else:
            if self.cfg.attn_types is None:  # coverage: never true in the measured run
                raise ValueError("attn_types must be set when using local attention")
            attn_type = self.cfg.attn_types[block_index]
            self.attn = attention(self.cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            self.mlp = MLPFactory.create_mlp(self.cfg)

        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        past_kv_cache_entry: Optional[TransformerLensKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
110 """A single Transformer block.
112 Args:
113 resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
114 cache (TransformerLensKeyValueCache): A cache of previous keys and values, used only when generating text. Defaults to None.
115 shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None.
116 attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None.
118 Returns:
119 Float[torch.Tensor, "batch pos d_model"]: Our resulting tensor
120 """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            if shortformer_pos_embed is not None:  # coverage: never true in the measured run
                shortformer_pos_embed = repeat_along_head_dimension(
                    shortformer_pos_embed, n_heads=self.cfg.n_heads
                )
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                and not self.cfg.ungroup_grouped_query_attention
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        if self.cfg.original_architecture in ("Olmo2ForCausalLM", "Olmo3ForCausalLM"):  # coverage: never true in the measured run
            attn_out = self.attn(
                query_input=query_input,
                key_input=key_input,
                value_input=value_input,
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        else:
            attn_out = (
                # hook the residual stream states that are used to calculate the
                # queries, keys and values, independently.
                # Then take the layer norm of these inputs, and pass these to the attention module.
                self.attn(
                    query_input=self.ln1(query_input)
                    + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                    key_input=self.ln1(key_input)
                    + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                    value_input=self.ln1(value_input),
                    past_kv_cache_entry=past_kv_cache_entry,
                    attention_mask=attention_mask,
                )
            )  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:  # coverage: never true in the measured run
            # If we use LayerNorm both before and after, then apply the second LN after the layer
            # and before the hook. We do it before the hook so hook_attn_out captures "that which
            # is added to the residual stream"
            attn_out = self.ln1_post(attn_out)
        attn_out = self.hook_attn_out(attn_out)

        if self.cfg.original_architecture in ("Olmo2ForCausalLM", "Olmo3ForCausalLM"):  # coverage: never true in the measured run
            attn_out = self.ln1(attn_out)

        if resid_pre.device != attn_out.device:  # coverage: never true in the measured run
            resid_pre = resid_pre.to(attn_out.device)

        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
            mlp_in = (
                resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
            )
            if self.cfg.original_architecture in ("Olmo2ForCausalLM", "Olmo3ForCausalLM"):  # coverage: never true in the measured run
                mlp_out = self.apply_mlp(mlp_in)
                mlp_out = self.ln2(mlp_out)
            else:
                normalized_resid_mid = self.ln2(mlp_in)
                mlp_out = self.apply_mlp(normalized_resid_mid)
            resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J: both MLP and Attn read from resid_pre and write to resid_post; no resid_mid is used.
            # In GPT-J, LN1 and LN2 are tied; in GPT-NeoX they aren't.
            normalized_resid_pre_2 = self.ln2(
                resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.apply_mlp(normalized_resid_pre_2)
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            resid_post = self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        return resid_post
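
    # A minimal sketch of reading one activation out of this block via its hook
    # points. It assumes a `block` built as in the __main__ sketch at the end of
    # this file, and that HookPoint.add_hook registers a forward hook whose
    # callback receives (tensor, hook):
    #
    #     captured = {}
    #     block.hook_attn_out.add_hook(lambda tensor, hook: captured.setdefault("attn_out", tensor))
    #     block(torch.randn(1, 4, block.cfg.d_model))
    #     captured["attn_out"].shape  # -> torch.Size([1, 4, block.cfg.d_model])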

    def apply_mlp(
        self, normalized_resid: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Centralized point where the MLP is applied in the forward pass.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: The MLP output, as returned through the hook_mlp_out hook point.
        """
        mlp_out = self.mlp(normalized_resid)  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:  # coverage: never true in the measured run
            mlp_out = self.ln2_post(mlp_out)
        return self.hook_mlp_out(mlp_out)
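

# A minimal sketch of constructing and running a single block, assuming a small
# illustrative HookedTransformerConfig (the values below are arbitrary and not
# taken from any particular model).
if __name__ == "__main__":
    cfg = HookedTransformerConfig(
        n_layers=1,
        d_model=64,
        n_ctx=32,
        d_head=16,
        n_heads=4,
        d_mlp=256,
        d_vocab=100,
        act_fn="gelu",
        normalization_type="LNPre",
    )
    block = TransformerBlock(cfg, block_index=0)
    resid_pre = torch.randn(2, 8, cfg.d_model)  # [batch, pos, d_model]
    resid_post = block(resid_pre)  # [batch, pos, d_model]
    print(resid_post.shape)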