Coverage for transformer_lens/components/transformer_block.py: 78% (101 statements)
1"""Hooked Transformer Transformer Block Component.
3This module contains all the component :class:`TransformerBlock`.
4"""

from typing import Callable, Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.components import (
    Attention,
    GroupedQueryAttention,
    LayerNorm,
    LayerNormPre,
    RMSNorm,
    RMSNormPre,
)
from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.factories.mlp_factory import MLPFactory
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
from transformer_lens.utils import repeat_along_head_dimension


# Transformer Block
class TransformerBlock(nn.Module):
    ln1: nn.Module
    ln2: nn.Module
    mlp: CanBeUsedAsMLP
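
    # Illustrative usage sketch (not part of the original file; the config values
    # below are assumptions chosen only to make the shapes concrete). A block is
    # normally constructed and called by HookedTransformer, but it can be driven
    # directly for quick experiments:
    #
    #   cfg = HookedTransformerConfig(
    #       n_layers=1, d_model=64, n_ctx=32, d_head=16, n_heads=4, act_fn="relu"
    #   )
    #   block = TransformerBlock(cfg, block_index=0)
    #   resid = torch.randn(2, cfg.n_ctx, cfg.d_model)  # [batch, pos, d_model]
    #   resid_post = block(resid)                       # [batch, pos, d_model]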
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        normalization_layer: Callable  # type: ignore
        normalization_layer_after: Callable  # type: ignore

        self.normalization_type = self.cfg.normalization_type
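
        # Map the config's normalization_type onto a layer constructor. The "Pre"
        # variants are used once the learned normalization parameters have been
        # folded into neighbouring weights, so only the parameter-free
        # normalization itself is applied here; None falls back to an identity.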
        if self.normalization_type == "LN":
            normalization_layer = LayerNorm
        elif self.normalization_type == "LNPre":
            # We've folded in LayerNorm weights, so just need the center + scale parts
            normalization_layer = LayerNormPre
        elif self.normalization_type == "RMS":
            normalization_layer = RMSNorm
        elif self.normalization_type == "RMSPre":
            normalization_layer = RMSNormPre
        elif self.normalization_type is None:
            # This should just be the identity.
            # We need to make this a lambda so we can call it on the config, just like the others
            normalization_layer = lambda cfg: nn.Identity()
        else:
            raise ValueError(f"Invalid normalization_type passed in: {self.normalization_type}")

        if self.cfg.use_normalization_before_and_after:
            # If we use LN before and after, we do *not* fold in the weights to the LN
            # after, though we can fold for the one before.
            if self.normalization_type is None:
                normalization_layer_after = lambda cfg: nn.Identity()
            elif self.normalization_type.startswith("RMS"):
                normalization_layer_after = RMSNorm
            elif self.normalization_type.startswith("LayerNorm"):
                normalization_layer_after = LayerNorm

        self.ln1 = normalization_layer(cfg)
        if self.cfg.use_normalization_before_and_after:
            self.ln1_post = normalization_layer_after(cfg)
        if not self.cfg.attn_only:
            self.ln2 = normalization_layer(cfg)
            if self.cfg.use_normalization_before_and_after:
                self.ln2_post = normalization_layer_after(cfg)
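
        # Use GroupedQueryAttention when the config specifies a separate (smaller)
        # number of key/value heads; otherwise use the standard Attention module.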
        attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention
        if not self.cfg.use_local_attn:
            self.attn = attention(self.cfg, "global", block_index)
        else:
            if self.cfg.attn_types is None:
                raise ValueError("attn_types must be set when using local attention")
            attn_type = self.cfg.attn_types[block_index]
            self.attn = attention(self.cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            self.mlp = MLPFactory.create_mlp(self.cfg)

        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]
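
        # Hedged sketch of how these hook points are typically used (the usual entry
        # point is HookedTransformer.run_with_hooks; attaching directly to a bare
        # block, continuing the construction sketch above, is shown only for
        # illustration):
        #
        #   def record_shape(tensor, hook):
        #       print("attn_out shape:", tuple(tensor.shape))
        #       return tensor
        #
        #   block.hook_attn_out.add_hook(record_shape)
        #   block(torch.randn(2, cfg.n_ctx, cfg.d_model))  # prints the attn_out shape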

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
108 """A single Transformer block.
110 Args:
111 resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
112 cache (HookedTransformerKeyValueCache): A cache of previous keys and values, used only when generating text. Defaults to None.
113 shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None.
114 attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None.
116 Returns:
117 Float[torch.Tensor, "batch pos d_model"]: Our resulting tensor
118 """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            if shortformer_pos_embed is not None:
                shortformer_pos_embed = repeat_along_head_dimension(
                    shortformer_pos_embed, n_heads=self.cfg.n_heads
                )
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
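            # With split Q/K/V inputs, each stream gets its own head dimension. The
            # key/value streams are repeated per key/value head, which is fewer than
            # n_heads under grouped-query attention (unless the config ungroups it).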
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                and not self.cfg.ungroup_grouped_query_attention
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        attn_out = (
            # hook the residual stream states that are used to calculate the
            # queries, keys and values, independently.
            # Then take the layer norm of these inputs, and pass these to the attention module.
            self.attn(
                query_input=self.ln1(query_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                key_input=self.ln1(key_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        )  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:
            # If we use LayerNorm both before and after, then apply the second LN after the layer
            # and before the hook. We do it before the hook so hook_attn_out captures "that which
            # is added to the residual stream"
            attn_out = self.ln1_post(attn_out)
        attn_out = self.hook_attn_out(attn_out)
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
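            # Standard (sequential) block: the MLP reads from
            # resid_mid = resid_pre + attn_out, and its output is added on top.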
            resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
            mlp_in = (
                resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
            )
            normalized_resid_mid = self.ln2(mlp_in)
            mlp_out = self.apply_mlp(normalized_resid_mid)
            resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
            # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
            normalized_resid_pre_2 = self.ln2(
                resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.apply_mlp(normalized_resid_pre_2)
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            resid_post = self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        return resid_post

    def apply_mlp(
        self, normalized_resid: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Centralized point where the MLP is applied in the forward pass.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: Our resulting tensor
        """
        mlp_out = self.mlp(normalized_resid)  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:
            mlp_out = self.ln2_post(mlp_out)
        return self.hook_mlp_out(mlp_out)