Coverage for transformer_lens/components/transformer_block.py: 82% (91 statements)
coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

"""Hooked Transformer Transformer Block Component.

This module contains the :class:`TransformerBlock` component.
"""
import logging
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.components import (
    MLP,
    Attention,
    GatedMLP,
    GroupedQueryAttention,
    LayerNorm,
    LayerNormPre,
    MoE,
    RMSNorm,
    RMSNormPre,
)
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
from transformer_lens.utils import repeat_along_head_dimension


# Transformer Block
class TransformerBlock(nn.Module):
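    """A single transformer block: attention plus an optional MLP, with configurable
    normalization and hook points on the intermediate activations."""
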
    ln1: nn.Module
    ln2: nn.Module
    mlp: nn.Module

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
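        # Select the normalization layers for the attention and MLP sub-blocks based on
        # cfg.normalization_type.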
        if self.cfg.normalization_type == "LN":
            self.ln1 = LayerNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNorm(cfg)
        elif self.cfg.normalization_type == "LNPre":
            # We've folded in LayerNorm weights, so just need the center + scale parts
            self.ln1 = LayerNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNormPre(cfg)
        elif self.cfg.normalization_type == "RMS":  # coverage: branch never taken
            self.ln1 = RMSNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNorm(cfg)
        elif self.cfg.normalization_type == "RMSPre":  # coverage: branch never taken
            self.ln1 = RMSNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNormPre(cfg)
        elif self.cfg.normalization_type is None:  # coverage: condition never false when reached
            self.ln1 = nn.Identity()
            if not self.cfg.attn_only:  # coverage: branch never taken
                self.ln2 = nn.Identity()
        else:
            logging.warning(f"Invalid normalization_type passed in {self.cfg.normalization_type}")
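
        # Use GroupedQueryAttention when the config specifies separate key/value heads,
        # otherwise the standard Attention module.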
        attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention
        if not self.cfg.use_local_attn:
            self.attn = attention(cfg, "global", block_index)
        else:
            if self.cfg.attn_types is None:  # coverage: branch never taken
                raise ValueError("attn_types must be set when using local attention")
            attn_type = self.cfg.attn_types[block_index]
            self.attn = attention(cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            if self.cfg.num_experts:  # coverage: branch never taken
                self.mlp = MoE(cfg)
            elif self.cfg.gated_mlp:  # coverage: branch never taken
                self.mlp = GatedMLP(cfg)
            else:
                self.mlp = MLP(cfg)
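
        # Hook points expose intermediate activations for caching and interventions;
        # the expected shape of each is noted alongside it.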
        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """A single Transformer block.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
            shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None.
            past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional): A cache of previous keys and values, used only when generating text. Defaults to None.
            attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None.

        Returns:
            torch.Tensor: The residual stream after the block - shape [batch, pos, d_model]
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            if shortformer_pos_embed is not None:  # coverage: branch never taken
                shortformer_pos_embed = repeat_along_head_dimension(
                    shortformer_pos_embed, n_heads=self.cfg.n_heads
                )
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        attn_out = self.hook_attn_out(
            # hook the residual stream states that are used to calculate the
            # queries, keys and values, independently.
            # Then take the layer norm of these inputs, and pass these to the attention module.
            self.attn(
                query_input=self.ln1(query_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                key_input=self.ln1(key_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        )  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
            mlp_in = (
                resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
            )
            normalized_resid_mid = self.ln2(mlp_in)
            mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
            # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
            normalized_resid_pre_2 = self.ln2(
                resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_pre_2))  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            resid_post = self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        return resid_post
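

For reference, here is a minimal usage sketch of the block above. It is not part of the module: the config field values below are illustrative assumptions, and it assumes TransformerBlock is importable from transformer_lens.components alongside the other components imported at the top of the file.

import torch

from transformer_lens.components import TransformerBlock
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Small illustrative config; the specific field values are arbitrary assumptions.
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=64,
    n_ctx=32,
    d_head=16,
    n_heads=4,
    d_mlp=256,
    act_fn="relu",
    normalization_type="LN",
)

block = TransformerBlock(cfg, block_index=0)
resid_pre = torch.randn(2, 8, cfg.d_model)  # [batch, pos, d_model]
resid_post = block(resid_pre)  # residual stream after attention + MLP
print(resid_post.shape)  # torch.Size([2, 8, 64])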