Coverage for transformer_lens/components/transformer_block.py: 71%

110 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Hooked Transformer Transformer Block Component. 

2 

3This module contains all the component :class:`TransformerBlock`. 

4""" 

from typing import Callable, Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.cache.key_value_cache_entry import (
    TransformerLensKeyValueCacheEntry,
)
from transformer_lens.components import (
    Attention,
    GroupedQueryAttention,
    LayerNorm,
    LayerNormPre,
    RMSNorm,
    RMSNormPre,
)
from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.factories.mlp_factory import MLPFactory
from transformer_lens.hook_points import HookPoint
from transformer_lens.utilities import repeat_along_head_dimension


# Transformer Block
class TransformerBlock(nn.Module):
    ln1: nn.Module
    ln2: nn.Module
    mlp: CanBeUsedAsMLP

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        normalization_layer: Callable  # type: ignore
        normalization_layer_after: Callable  # type: ignore

        self.normalization_type = self.cfg.normalization_type

        if self.normalization_type == "LN":
            normalization_layer = LayerNorm
        elif self.normalization_type == "LNPre":  # coverage: condition always true, so the branches below were never reached
            # We've folded in LayerNorm weights, so just need the center + scale parts
            normalization_layer = LayerNormPre
        elif self.normalization_type == "RMS":
            normalization_layer = RMSNorm
        elif self.normalization_type == "RMSPre":
            normalization_layer = RMSNormPre
        elif self.normalization_type is None:
            # This should just be the identity.
            # We need to make this a lambda so we can call it on the config, just like the others
            normalization_layer = lambda cfg: nn.Identity()
        else:
            raise ValueError(f"Invalid normalization_type passed in: {self.normalization_type}")

        if self.cfg.use_normalization_before_and_after:  # coverage: branch never taken
            # If we use LN before and after, we do *not* fold in the weights to the LN
            # after, though we can fold for the one before.
            if self.normalization_type is None:
                normalization_layer_after = lambda cfg: nn.Identity()
            elif self.normalization_type.startswith("RMS"):
                normalization_layer_after = RMSNorm
            elif self.normalization_type.startswith("LayerNorm"):
                normalization_layer_after = LayerNorm

        self.ln1 = normalization_layer(cfg)
        if self.cfg.use_normalization_before_and_after:  # coverage: branch never taken
            self.ln1_post = normalization_layer_after(cfg)
        if not self.cfg.attn_only:
            self.ln2 = normalization_layer(cfg)
            if self.cfg.use_normalization_before_and_after:  # coverage: branch never taken
                self.ln2_post = normalization_layer_after(cfg)

        attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention
        if not self.cfg.use_local_attn:
            self.attn = attention(self.cfg, "global", block_index)
        else:
            if self.cfg.attn_types is None:  # coverage: branch never taken
                raise ValueError("attn_types must be set when using local attention")
            attn_type = self.cfg.attn_types[block_index]
            self.attn = attention(self.cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            self.mlp = MLPFactory.create_mlp(self.cfg)

        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        past_kv_cache_entry: Optional[TransformerLensKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """A single Transformer block.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
            shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None.
            past_kv_cache_entry (TransformerLensKeyValueCacheEntry, optional): A cache of previous keys and values for this block, used only when generating text. Defaults to None.
            attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: The residual stream after this block - shape [batch, pos, d_model]
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            if shortformer_pos_embed is not None:  # coverage: branch never taken
                shortformer_pos_embed = repeat_along_head_dimension(
                    shortformer_pos_embed, n_heads=self.cfg.n_heads
                )
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                and not self.cfg.ungroup_grouped_query_attention
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        # Olmo2/Olmo3 are post-norm inside the residual branch: attention reads the raw residual
        # stream here, and normalization is applied to its output below rather than to its input.
        if self.cfg.original_architecture in ("Olmo2ForCausalLM", "Olmo3ForCausalLM"):  # coverage: branch never taken
            attn_out = self.attn(
                query_input=query_input,
                key_input=key_input,
                value_input=value_input,
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        else:
            attn_out = (
                # hook the residual stream states that are used to calculate the
                # queries, keys and values, independently.
                # Then take the layer norm of these inputs, and pass these to the attention module.
                self.attn(
                    query_input=self.ln1(query_input)
                    + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                    key_input=self.ln1(key_input)
                    + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                    value_input=self.ln1(value_input),
                    past_kv_cache_entry=past_kv_cache_entry,
                    attention_mask=attention_mask,
                )
            )  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:  # coverage: branch never taken
            # If we use LayerNorm both before and after, then apply the second LN after the layer
            # and before the hook. We do it before the hook so hook_attn_out captures "that which
            # is added to the residual stream".
            attn_out = self.ln1_post(attn_out)
        attn_out = self.hook_attn_out(attn_out)

        if self.cfg.original_architecture in ("Olmo2ForCausalLM", "Olmo3ForCausalLM"):  # coverage: branch never taken
            attn_out = self.ln1(attn_out)

        if resid_pre.device != attn_out.device:  # coverage: branch never taken
            resid_pre = resid_pre.to(attn_out.device)

        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
            mlp_in = (
                resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
            )
            if self.cfg.original_architecture in ("Olmo2ForCausalLM", "Olmo3ForCausalLM"):  # coverage: branch never taken
                mlp_out = self.apply_mlp(mlp_in)
                mlp_out = self.ln2(mlp_out)
            else:
                normalized_resid_mid = self.ln2(mlp_in)
                mlp_out = self.apply_mlp(normalized_resid_mid)
            resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J: both MLP and Attn read from resid_pre and write to resid_post; no resid_mid is used.
            # In GPT-J, LN1 and LN2 are tied; in GPT-NeoX they aren't.
            normalized_resid_pre_2 = self.ln2(
                resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.apply_mlp(normalized_resid_pre_2)
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            resid_post = self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        return resid_post

    def apply_mlp(
        self, normalized_resid: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Centralized point where the MLP is applied in the forward pass.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: The MLP output - shape [batch, pos, d_model]
        """
        mlp_out = self.mlp(normalized_resid)  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:  # coverage: branch never taken
            mlp_out = self.ln2_post(mlp_out)
        return self.hook_mlp_out(mlp_out)
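
As a usage note for the forward() interface documented above, the sketch below shows one plausible way to instantiate a standalone TransformerBlock with a toy HookedTransformerConfig, run a residual-stream tensor through it, and capture hook_attn_out. It is illustrative only and not part of the covered module: the config values and the save_attn_out helper are assumptions chosen for a small pre-LayerNorm block, and in normal use the block is constructed and wired up by HookedTransformer rather than by hand.

    import torch

    from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
    from transformer_lens.components.transformer_block import TransformerBlock

    # Toy config (assumed values, for illustration only).
    cfg = HookedTransformerConfig(
        n_layers=1,
        d_model=64,
        n_ctx=32,
        d_head=16,
        n_heads=4,
        d_mlp=256,
        d_vocab=50,
        act_fn="gelu",
        normalization_type="LNPre",
    )

    block = TransformerBlock(cfg, block_index=0)

    # Capture the attention output via the block's hook point.
    captured = {}

    def save_attn_out(tensor, hook):
        captured["attn_out"] = tensor.detach()

    block.hook_attn_out.add_hook(save_attn_out)

    resid_pre = torch.randn(2, 10, cfg.d_model)  # [batch, pos, d_model]
    resid_post = block(resid_pre)                # [batch, pos, d_model]
    print(resid_post.shape, captured["attn_out"].shape)

Because the block is used here outside a HookedRootModule, the hook is attached directly with HookPoint.add_hook rather than through run_with_hooks, and hook names are not populated, so the capture dictionary uses a fixed key.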