Coverage for transformer_lens/components/t5_attention.py: 90% (50 statements), coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
import math
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.components.abstract_attention import AbstractAttention
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class T5Attention(AbstractAttention):
    r"""
    T5 attention, with relative attention bias and cross-attention support.

    This implementation expects you to precompute the relative positional bias and then
    pass it to forward, e.g.:

    ```python
    attn = T5Attention(cfg, has_relative_attention_bias=True)
    positional_bias = attn.compute_relative_attention_bias(query_len, key_len, device=device)
    result = attn(query, key, value, position_bias=positional_bias)
    ```
    """

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        has_relative_attention_bias: bool = False,
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        super().__init__(cfg, attn_type, layer_id)
        if isinstance(cfg, Dict):  # coverage: this branch was never taken in tests
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.has_relative_attention_bias: bool = has_relative_attention_bias

        if self.has_relative_attention_bias:
            if (  # coverage: the error branch was never taken in tests
                cfg.relative_attention_num_buckets is None
                or cfg.relative_attention_max_distance is None
            ):
                raise ValueError(
                    "You need to specify relative_attention_num_buckets and relative_attention_max_distance in config to use relative attention bias"
                )

            self.relative_attention_num_buckets = cfg.relative_attention_num_buckets
            self.relative_attention_max_distance = cfg.relative_attention_max_distance
            self.rel_pos_bias = nn.Embedding(self.relative_attention_num_buckets, self.cfg.n_heads)
            self.rel_pos_hook = HookPoint()

        self.W_K = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
        )
        self.W_V = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
        )
        self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))
        self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))

    @staticmethod
    def _relative_position_bucket(
        relative_position: Int[torch.Tensor, "query_pos kv_pos"],
        bidirectional=True,
        num_buckets=32,
        max_distance=128,
    ) -> Int[torch.Tensor, "query_pos kv_pos"]:
        """
        Adapted from
        https://github.com/huggingface/transformers/blob/e0c3cee17085914bbe505c159beeb8ae39bc37dd/src/transformers/models/t5/modeling_t5.py#L382
        which is in turn adapted from
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate a relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute values of relative_position and larger buckets for larger absolute values. All relative
        positions >= max_distance map to the same bucket, and all relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model was trained on.
        (A worked example follows this method.)

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = torch.zeros_like(relative_position)

        if bidirectional:  # coverage: the else-branch was never taken in tests
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
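
    # Worked example of the bucketing above (illustration only, not part of the module).
    # With the defaults bidirectional=True, num_buckets=32, max_distance=128:
    #
    #   rel = torch.tensor([[-2, 0, 2, 200, -200]])
    #   T5Attention._relative_position_bucket(rel)
    #   # -> tensor([[ 2,  0, 18, 31, 15]])
    #
    # Negative (past) offsets land in buckets [0, 15] and positive (future) offsets in
    # [16, 31]; offsets with |distance| >= max_distance are clamped to the last bucket
    # of their half.
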
    def compute_relative_attention_bias(
        self, query_length: int, key_length: int, device=None
    ) -> Float[torch.Tensor, "1 head_index pos kv_pos"]:
        """Compute binned relative position bias."""
        if device is None:
            device = self.rel_pos_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.rel_pos_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values
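
# Usage sketch (illustration only, not part of the module). The config values below are
# assumptions chosen to satisfy HookedTransformerConfig; the relative-attention fields
# are the ones this class reads:
#
#   cfg = HookedTransformerConfig(
#       n_layers=1, d_model=512, d_head=64, n_heads=8, n_ctx=128, act_fn="relu",
#       relative_attention_num_buckets=32, relative_attention_max_distance=128,
#   )
#   attn = T5Attention(cfg, has_relative_attention_bias=True)
#   bias = attn.compute_relative_attention_bias(query_length=10, key_length=10)
#   # bias.shape == (1, cfg.n_heads, 10, 10); it is then passed to forward as
#   # position_bias, per the usage shown in the class docstring.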