Coverage for transformer_lens/components/transformer_block.py: 78%

101 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Hooked Transformer Transformer Block Component. 

2 

3This module contains the :class:`TransformerBlock` component. 

4""" 

5 

6from typing import Callable, Dict, Optional, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float, Int 

11 

12from transformer_lens.components import ( 

13 Attention, 

14 GroupedQueryAttention, 

15 LayerNorm, 

16 LayerNormPre, 

17 RMSNorm, 

18 RMSNormPre, 

19) 

20from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

21from transformer_lens.factories.mlp_factory import MLPFactory 

22from transformer_lens.hook_points import HookPoint 

23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

24from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry 

25from transformer_lens.utils import repeat_along_head_dimension 

26 

27 

28# Transformer Block 

class TransformerBlock(nn.Module):
    """A single transformer block: an attention layer plus, unless ``cfg.attn_only``, an MLP.

    The block owns the normalization layers applied before attention (``ln1``) and
    before the MLP (``ln2``), optional post-normalization layers (``ln1_post`` /
    ``ln2_post``, only when ``cfg.use_normalization_before_and_after`` is set), and a
    set of :class:`HookPoint` modules wrapped around the intermediate values.
    """

    ln1: nn.Module
    ln2: nn.Module
    mlp: CanBeUsedAsMLP

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        """Build the normalization, attention, MLP and hook sub-modules.

        Args:
            cfg: The model configuration, either as a config object or a dict; it is
                passed through ``HookedTransformerConfig.unwrap``.
            block_index: Index of this block. Used to look up ``cfg.attn_types`` when
                ``cfg.use_local_attn`` is set and forwarded to the attention module.

        Raises:
            ValueError: If ``cfg.normalization_type`` is not one of ``"LN"``,
                ``"LNPre"``, ``"RMS"``, ``"RMSPre"`` or ``None``, or if
                ``cfg.use_local_attn`` is set while ``cfg.attn_types`` is ``None``.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        normalization_layer: Callable  # type: ignore
        normalization_layer_after: Callable  # type: ignore

        self.normalization_type = self.cfg.normalization_type

        if self.normalization_type == "LN":
            normalization_layer = LayerNorm
        elif self.normalization_type == "LNPre":
            # We've folded in LayerNorm weights, so just need the center + scale parts
            normalization_layer = LayerNormPre
        elif self.normalization_type == "RMS":
            normalization_layer = RMSNorm
        elif self.normalization_type == "RMSPre":
            normalization_layer = RMSNormPre
        elif self.normalization_type is None:
            # This should just be the identity.
            # We need to make this a lambda so we can call it on the config, just like the others
            normalization_layer = lambda cfg: nn.Identity()
        else:
            raise ValueError(f"Invalid normalization_type passed in: {self.normalization_type}")

        if self.cfg.use_normalization_before_and_after:
            # If we use LN before and after, we do *not* fold in the weights to the LN
            # after, though we can fold for the one before.
            if self.normalization_type is None:
                normalization_layer_after = lambda cfg: nn.Identity()
            elif self.normalization_type.startswith("RMS"):
                normalization_layer_after = RMSNorm
            else:
                # The validation above leaves only "LN" / "LNPre" here. The previous
                # check, startswith("LayerNorm"), could never match those values and
                # left normalization_layer_after unbound.
                normalization_layer_after = LayerNorm

        self.ln1 = normalization_layer(cfg)
        if self.cfg.use_normalization_before_and_after:
            self.ln1_post = normalization_layer_after(cfg)
        if not self.cfg.attn_only:
            self.ln2 = normalization_layer(cfg)
            if self.cfg.use_normalization_before_and_after:
                self.ln2_post = normalization_layer_after(cfg)

        # Grouped-query attention is selected purely by n_key_value_heads being set.
        attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention
        if not self.cfg.use_local_attn:
            self.attn = attention(self.cfg, "global", block_index)
        else:
            if self.cfg.attn_types is None:
                raise ValueError("attn_types must be set when using local attention")
            attn_type = self.cfg.attn_types[block_index]
            self.attn = attention(self.cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            self.mlp = MLPFactory.create_mlp(self.cfg)

        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        # resid_mid only exists when attention and MLP are applied sequentially.
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """A single Transformer block.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
            shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None.
            past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional): A cache of previous keys and values, used only when generating text. Passed through to the attention module. Defaults to None.
            attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: Our resulting tensor
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            if shortformer_pos_embed is not None:
                shortformer_pos_embed = repeat_along_head_dimension(
                    shortformer_pos_embed, n_heads=self.cfg.n_heads
                )
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            # Keys/values get one copy per KV head unless grouped-query attention is
            # absent or configured to be ungrouped, in which case n_heads is used.
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                and not self.cfg.ungroup_grouped_query_attention
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        attn_out = (
            # hook the residual stream states that are used to calculate the
            # queries, keys and values, independently.
            # Then take the layer norm of these inputs, and pass these to the attention module.
            self.attn(
                query_input=self.ln1(query_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                key_input=self.ln1(key_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        )  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:
            # If we use LayerNorm both before and after, then apply the second LN after the layer
            # and before the hook. We do it before the hook so hook_attn_out captures "that which
            # is added to the residual stream"
            attn_out = self.ln1_post(attn_out)
        attn_out = self.hook_attn_out(attn_out)
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
            mlp_in = (
                resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
            )
            normalized_resid_mid = self.ln2(mlp_in)
            mlp_out = self.apply_mlp(normalized_resid_mid)
            resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
            # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
            # NOTE(review): this branch is also taken when attn_only and parallel_attn_mlp
            # are both set, but __init__ creates ln2/mlp only when attn_only is false.
            # Presumably such configs are rejected elsewhere - confirm.
            normalized_resid_pre_2 = self.ln2(
                resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.apply_mlp(normalized_resid_pre_2)
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            resid_post = self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        return resid_post

    def apply_mlp(
        self, normalized_resid: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Centralized point where the MLP is applied to the forward pass

        Runs ``self.mlp``, then ``self.ln2_post`` when
        ``cfg.use_normalization_before_and_after`` is set, and finally passes the
        result through ``hook_mlp_out``.

        Args:
            normalized_resid (torch.Tensor): The already-normalized MLP input.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: Our resulting tensor
        """
        mlp_out = self.mlp(normalized_resid)  # [batch, pos, d_model]
        if self.cfg.use_normalization_before_and_after:
            mlp_out = self.ln2_post(mlp_out)
        return self.hook_mlp_out(mlp_out)