Coverage for transformer_lens/components/t5_block.py: 88%

64 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1from typing import Optional 

2 

3import torch 

4import torch.nn as nn 

5from jaxtyping import Float 

6 

7from transformer_lens.components import RMSNorm, T5Attention 

8from transformer_lens.factories.mlp_factory import MLPFactory 

9from transformer_lens.hook_points import HookPoint 

10from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

11from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry 

12from transformer_lens.utils import repeat_along_head_dimension 

13 

14 

class T5Block(nn.Module):
    """A single T5 transformer block, used for both the encoder and the decoder stacks.

    Uses RMSNorm and T5Attention instead of the usual LayerNorm / Attention components.

    Encoder block (``is_decoder=False``)::

        resid_pre -> ln1 -> self-attn -> (+) resid_mid -> ln2 -> mlp -> (+) resid_post

    Decoder block (``is_decoder=True``) additionally cross-attends to the encoder output::

        resid_pre -> ln1 -> self-attn -> (+) resid_mid
                  -> ln2 -> cross-attn -> (+) resid_mid_cross
                  -> ln3 -> mlp -> (+) resid_post
    """

    def __init__(self, cfg: HookedTransformerConfig, block_index: int, is_decoder: bool):
        """Build the block's sub-layers and hook points.

        Args:
            cfg: Model configuration shared with the sub-modules.
            block_index: Position of this block in its stack. Only block 0 creates its
                self-attention with ``has_relative_attention_bias=True``.
            is_decoder: If True, also create the cross-attention sub-layer, the extra
                norm ``ln3`` and the cross-attention hook points.
        """
        super().__init__()
        self.cfg = cfg
        self.is_decoder = is_decoder

        self.ln1 = RMSNorm(cfg)  # normalizes the self-attention inputs
        # Only the first block's self-attention is constructed with a relative attention
        # bias; presumably later blocks receive it through `position_bias` in forward
        # (NOTE(review): confirm against the calling model).
        self.attn = T5Attention(cfg, has_relative_attention_bias=block_index == 0)
        # Decoder: normalizes the cross-attention query input. Encoder: normalizes the MLP input.
        self.ln2 = RMSNorm(cfg)
        if self.is_decoder:
            self.cross_attn = T5Attention(cfg)
            self.ln3 = RMSNorm(cfg)  # decoder only: normalizes the MLP input
        self.mlp = MLPFactory.create_mlp(self.cfg)

        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]

        self.hook_attn_in = HookPoint()  # [batch, pos, d_model]
        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        if self.is_decoder:
            self.hook_cross_attn_in = HookPoint()  # [batch, pos, d_model]
            self.hook_cross_attn_out = HookPoint()  # [batch, pos, d_model]
            self.hook_resid_mid_cross = HookPoint()  # [batch, pos, d_model]

        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
        encoder_additive_attention_mask: Optional[
            Float[torch.Tensor, "batch 1 1 encoder_pos"]
        ] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
        encoder_hidden_states: Optional[Float[torch.Tensor, "batch encoder_pos d_model"]] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the block on the residual stream.

        Args:
            resid_pre: The residual stream entering the block - shape [batch, pos, d_model].
            additive_attention_mask: Additive self-attention mask (e.g. for padded tokens),
                shape [batch, 1, 1, pos]. Defaults to None.
            encoder_additive_attention_mask: Additive mask passed to cross-attention over
                the encoder positions, shape [batch, 1, 1, encoder_pos]. Only used by
                decoder blocks. Defaults to None.
            position_bias: Position bias forwarded to the self-attention module, shape
                [1, head_index, pos, kv_pos]. Defaults to None.
            encoder_hidden_states: Encoder output used as the keys and values of
                cross-attention, shape [batch, encoder_pos, d_model]. Required for decoder
                blocks, ignored by encoder blocks.
            past_kv_cache_entry: Cache of previous self-attention keys and values, used
                when generating text. Defaults to None.

        Returns:
            The residual stream leaving the block (``resid_post``) - shape
            [batch, pos, d_model].

        Raises:
            ValueError: If this is a decoder block and ``encoder_hidden_states`` is None.
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        attn_in = resid_pre
        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            # Hook the residual stream states used for queries, keys and values independently.
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = attn_in
            key_input = attn_in
            value_input = attn_in

        # Layer-norm each (possibly separately hooked) input, then run self-attention.
        attn_out = self.hook_attn_out(
            self.attn(
                query_input=self.ln1(query_input),
                key_input=self.ln1(key_input),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                additive_attention_mask=additive_attention_mask,
                position_bias=position_bias,
            )
        )  # [batch, pos, d_model]
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]

        if self.is_decoder:
            if encoder_hidden_states is None:
                raise ValueError("Encoder hidden states must be provided for cross attention!")

            # Clone before hooking, so an edit made through hook_cross_attn_in changes only
            # the cross-attention query input and not the residual stream.
            cross_attn_in = (
                self.hook_cross_attn_in(resid_mid.clone()) if self.cfg.use_attn_in else resid_mid
            )
            cross_attn_out = self.hook_cross_attn_out(
                self.cross_attn(
                    query_input=self.ln2(cross_attn_in),
                    # The encoder states are used as keys/values without normalization here.
                    key_input=encoder_hidden_states,
                    value_input=encoder_hidden_states,
                    additive_attention_mask=encoder_additive_attention_mask,
                )
            )
            resid_mid_cross = self.hook_resid_mid_cross(resid_mid + cross_attn_out)
            residual = resid_mid_cross
            mlp_norm = self.ln3
        else:
            residual = resid_mid
            mlp_norm = self.ln2

        # As with cross-attention, clone before hooking, so an edit made through hook_mlp_in
        # affects only the MLP input. The MLP output is added to the un-hooked residual
        # stream; previously it was added to the hooked mlp_in, leaking such edits into
        # resid_post.
        mlp_in = self.hook_mlp_in(residual.clone()) if self.cfg.use_hook_mlp_in else residual
        mlp_out = self.hook_mlp_out(self.mlp(mlp_norm(mlp_in)))  # [batch, pos, d_model]
        resid_post = self.hook_resid_post(residual + mlp_out)  # [batch, pos, d_model]
        return resid_post