Coverage for transformer_lens/components/transformer_block.py: 82%

91 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer Transformer Block Component. 

2 

3This module contains all the component :class:`TransformerBlock`. 

4""" 

import logging
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.components import (
    MLP,
    Attention,
    GatedMLP,
    GroupedQueryAttention,
    LayerNorm,
    LayerNormPre,
    MoE,
    RMSNorm,
    RMSNormPre,
)
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
from transformer_lens.utils import repeat_along_head_dimension


# Transformer Block
class TransformerBlock(nn.Module):
    ln1: nn.Module
    ln2: nn.Module
    mlp: nn.Module

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        if self.cfg.normalization_type == "LN":
            self.ln1 = LayerNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNorm(cfg)
        elif self.cfg.normalization_type == "LNPre":
            # We've folded in LayerNorm weights, so just need the center + scale parts
            self.ln1 = LayerNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = LayerNormPre(cfg)
        elif self.cfg.normalization_type == "RMS":  # coverage: branch never taken (condition never true)
            self.ln1 = RMSNorm(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNorm(cfg)
        elif self.cfg.normalization_type == "RMSPre":  # coverage: branch never taken (condition never true)
            self.ln1 = RMSNormPre(cfg)
            if not self.cfg.attn_only:
                self.ln2 = RMSNormPre(cfg)
        elif self.cfg.normalization_type is None:  # coverage: condition never false (the else below was never reached)
            self.ln1 = nn.Identity()
            if not self.cfg.attn_only:  # coverage: condition never true
                self.ln2 = nn.Identity()
        else:
            logging.warning(f"Invalid normalization_type passed in {self.cfg.normalization_type}")

        attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention
        if not self.cfg.use_local_attn:
            self.attn = attention(cfg, "global", block_index)
        else:
            if self.cfg.attn_types is None:  # coverage: condition never true
                raise ValueError("attn_types must be set when using local attention")
            attn_type = self.cfg.attn_types[block_index]
            self.attn = attention(cfg, attn_type, block_index)
        if not self.cfg.attn_only:
            if self.cfg.num_experts:  # coverage: condition never true
                self.mlp = MoE(cfg)
            elif self.cfg.gated_mlp:  # coverage: condition never true
                self.mlp = GatedMLP(cfg)
            else:
                self.mlp = MLP(cfg)

        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]

        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """A single Transformer block.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
            shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None.
            past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional): A cache of previous keys and values, used only when generating text. Defaults to None.
            attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None.

        Returns:
            torch.Tensor: The updated residual stream - shape [batch, pos, d_model]
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
            # We're adding a head dimension
            if shortformer_pos_embed is not None:  # coverage: condition never true
                shortformer_pos_embed = repeat_along_head_dimension(
                    shortformer_pos_embed, n_heads=self.cfg.n_heads
                )
        else:
            attn_in = resid_pre

        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        attn_out = self.hook_attn_out(
            # hook the residual stream states that are used to calculate the
            # queries, keys and values, independently.
            # Then take the layer norm of these inputs, and pass these to the attention module.
            self.attn(
                query_input=self.ln1(query_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                key_input=self.ln1(key_input)
                + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                attention_mask=attention_mask,
            )
        )  # [batch, pos, d_model]
        if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
            resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
            mlp_in = (
                resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
            )
            normalized_resid_mid = self.ln2(mlp_in)
            mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
        elif self.cfg.parallel_attn_mlp:
            # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
            # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
            normalized_resid_pre_2 = self.ln2(
                resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone())
            )
            mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_pre_2))  # [batch, pos, d_model]
            resid_post = self.hook_resid_post(
                resid_pre + attn_out + mlp_out
            )  # [batch, pos, d_model]
        else:
            resid_post = self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        return resid_post
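
Usage sketch (not part of the covered module): a minimal, hypothetical way to exercise a single block directly, importing it via the module path shown in this report. The config fields and their values below are illustrative assumptions, not taken from any particular model or from the report itself.

import torch

from transformer_lens.components.transformer_block import TransformerBlock
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Hypothetical toy config; the field values are illustrative only.
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=64,
    n_ctx=32,
    d_head=16,
    n_heads=4,
    act_fn="gelu",
    normalization_type="LN",
)

block = TransformerBlock(cfg, block_index=0)
resid_pre = torch.randn(2, 10, cfg.d_model)  # [batch, pos, d_model]
resid_post = block(resid_pre)  # [batch, pos, d_model]

Configs with normalization_type="RMS" or "RMSPre", gated_mlp=True, or num_experts set would be one way to reach the branches marked above as never taken.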