Coverage for transformer_lens/components/t5_block.py: 88%
64 statements
coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1from typing import Optional
3import torch
4import torch.nn as nn
5from jaxtyping import Float
7from transformer_lens.components import RMSNorm, T5Attention
8from transformer_lens.factories.mlp_factory import MLPFactory
9from transformer_lens.hook_points import HookPoint
10from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
11from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
12from transformer_lens.utils import repeat_along_head_dimension
class T5Block(nn.Module):
    """T5 transformer block, usable in either the encoder or the decoder stack.

    Uses RMSNorm and T5Attention instead of the usual LayerNorm / attention
    components. Every sublayer is pre-normalized and added back onto the
    residual stream:

    * encoder (``is_decoder=False``): ``ln1 -> self-attention``, then
      ``ln2 -> MLP``.
    * decoder (``is_decoder=True``): ``ln1 -> self-attention``, then
      ``ln2 -> cross-attention`` over the encoder hidden states, then
      ``ln3 -> MLP``.

    Only the block constructed with ``block_index == 0`` creates its
    self-attention with ``has_relative_attention_bias=True``.
    """

    def __init__(self, cfg: HookedTransformerConfig, block_index: int, is_decoder: bool):
        super().__init__()
        self.cfg = cfg
        self.is_decoder = is_decoder

        self.ln1 = RMSNorm(cfg)
        # Only the first block gets has_relative_attention_bias=True; presumably the
        # other blocks rely on the position_bias argument of forward() instead.
        # NOTE(review): confirm against the caller that builds the block stack.
        self.attn = T5Attention(cfg, has_relative_attention_bias=block_index == 0)
        self.ln2 = RMSNorm(cfg)
        if self.is_decoder:
            self.cross_attn = T5Attention(cfg)
            self.ln3 = RMSNorm(cfg)
        self.mlp = MLPFactory.create_mlp(self.cfg)

        # Per-head copies of resid_pre; only used when cfg.use_split_qkv_input is set.
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_kv_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_kv_heads, d_model]

        # Only used when cfg.use_attn_in is set; receives resid_pre repeated per head.
        self.hook_attn_in = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        if self.is_decoder:
            self.hook_cross_attn_in = HookPoint()  # [batch, pos, d_model]
            self.hook_cross_attn_out = HookPoint()  # [batch, pos, d_model]
            self.hook_resid_mid_cross = HookPoint()  # [batch, pos, d_model]

        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
        encoder_additive_attention_mask: Optional[
            Float[torch.Tensor, "batch 1 1 encoder_pos"]
        ] = None,
        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
        encoder_hidden_states: Optional[Float[torch.Tensor, "batch encoder_pos d_model"]] = None,
        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run one T5 block over the residual stream.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model].
            additive_attention_mask (torch.Tensor, optional): Additive mask for the
                self-attention sublayer - shape [batch, 1, 1, pos]. Defaults to None.
            encoder_additive_attention_mask (torch.Tensor, optional): Additive mask for
                the cross-attention sublayer - shape [batch, 1, 1, encoder_pos]. Only
                used when this is a decoder block. Defaults to None.
            position_bias (torch.Tensor, optional): Position bias passed through to the
                self-attention module - shape [1, head_index, pos, kv_pos].
                Defaults to None.
            encoder_hidden_states (torch.Tensor, optional): Encoder output used as keys
                and values for cross attention - shape [batch, encoder_pos, d_model].
                Required when this is a decoder block. Defaults to None.
            past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional): Cache
                entry passed to the self-attention module only (not to cross
                attention). Defaults to None.

        Returns:
            torch.Tensor: The residual stream after the block - shape [batch, pos, d_model].

        Raises:
            ValueError: If this is a decoder block and ``encoder_hidden_states`` is None.
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]

        # ---- Self-attention sublayer ----
        attn_in = resid_pre
        if self.cfg.use_attn_in:
            attn_in = self.hook_attn_in(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )

        if self.cfg.use_split_qkv_input:
            # Hook the residual-stream inputs used for queries, keys and values
            # independently; keys/values follow the (possibly grouped) KV head count.
            n_kv_heads = (
                self.cfg.n_key_value_heads
                if self.cfg.n_key_value_heads is not None
                else self.cfg.n_heads
            )
            query_input = self.hook_q_input(
                repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads)
            )
            key_input = self.hook_k_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
            value_input = self.hook_v_input(
                repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads)
            )
        else:
            query_input = key_input = value_input = attn_in

        # Normalize each input separately, then run self-attention.
        attn_out = self.hook_attn_out(
            self.attn(
                query_input=self.ln1(query_input),
                key_input=self.ln1(key_input),
                value_input=self.ln1(value_input),
                past_kv_cache_entry=past_kv_cache_entry,
                additive_attention_mask=additive_attention_mask,
                position_bias=position_bias,
            )
        )  # [batch, pos, d_model]
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]

        # ---- Cross-attention sublayer (decoder only) ----
        if self.is_decoder:
            # Fail fast, before any cross-attention hook fires.
            if encoder_hidden_states is None:
                raise ValueError("Encoder hidden states must be provided for cross attention!")

            cross_attn_in = (
                self.hook_cross_attn_in(resid_mid.clone())
                if self.cfg.use_attn_in
                else resid_mid
            )
            cross_attn_out = self.hook_cross_attn_out(
                self.cross_attn(
                    query_input=self.ln2(cross_attn_in),
                    key_input=encoder_hidden_states,
                    value_input=encoder_hidden_states,
                    additive_attention_mask=encoder_additive_attention_mask,
                )
            )
            # As with cross_attn_in, the residual is the un-hooked resid_mid.
            resid_before_mlp = self.hook_resid_mid_cross(resid_mid + cross_attn_out)
            mlp_norm = self.ln3
        else:
            resid_before_mlp = resid_mid
            mlp_norm = self.ln2

        # ---- MLP sublayer ----
        mlp_in = (
            self.hook_mlp_in(resid_before_mlp.clone())
            if self.cfg.use_hook_mlp_in
            else resid_before_mlp
        )
        mlp_out = self.hook_mlp_out(self.mlp(mlp_norm(mlp_in)))  # [batch, pos, d_model]
        # Add onto the un-hooked residual so edits made through hook_mlp_in only
        # change what the MLP sees, not the residual stream itself.
        resid_post = self.hook_resid_post(resid_before_mlp + mlp_out)  # [batch, pos, d_model]
        return resid_post