Coverage for transformer_lens/components/t5_attention.py: 90%

50 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1import math 

2from typing import Dict, Optional, Union 

3 

4import torch 

5import torch.nn as nn 

6from jaxtyping import Float, Int 

7 

8from transformer_lens.components.abstract_attention import AbstractAttention 

9from transformer_lens.hook_points import HookPoint 

10from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

11 

12 

class T5Attention(AbstractAttention):
    r"""
    T5 attention, with relative attention bias and cross-attention support.

    This implementation expects the relative positional bias to be precomputed and then
    passed to ``forward``, like
    ```python
    attn = T5Attention(cfg, has_relative_attention_bias=True)
    positional_bias = attn.compute_relative_attention_bias(query_len, key_len, device=device)
    result = attn(query, key, value, position_bias=positional_bias)
    ```

    Only layers constructed with ``has_relative_attention_bias=True`` own a bias table
    (``rel_pos_bias``) and may call ``compute_relative_attention_bias``.

    Note: Hugging Face's T5 buckets decoder self-attention offsets unidirectionally
    (``bidirectional=not is_decoder``). Pass ``bidirectional=False`` to
    ``compute_relative_attention_bias`` to reproduce that bucketing.
    """

    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        has_relative_attention_bias: bool = False,
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Create a T5 attention layer.

        Args:
            cfg: Model config, either a ``HookedTransformerConfig`` or a dict accepted by
                ``HookedTransformerConfig.from_dict``.
            has_relative_attention_bias: Whether this layer owns a learned relative
                position bias table. Requires ``cfg.relative_attention_num_buckets`` and
                ``cfg.relative_attention_max_distance`` to be set.
            attn_type: Forwarded to ``AbstractAttention.__init__``.
            layer_id: Forwarded to ``AbstractAttention.__init__``.

        Raises:
            ValueError: If ``has_relative_attention_bias`` is True but the config does not
                define the relative-attention bucket settings.
        """
        super().__init__(cfg, attn_type, layer_id)
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.has_relative_attention_bias: bool = has_relative_attention_bias

        if self.has_relative_attention_bias:
            if (
                cfg.relative_attention_num_buckets is None
                or cfg.relative_attention_max_distance is None
            ):
                raise ValueError(
                    "You need to specify relative_attention_num_buckets and relative_attention_max_distance in config to use relative attention bias"
                )

            self.relative_attention_num_buckets = cfg.relative_attention_num_buckets
            self.relative_attention_max_distance = cfg.relative_attention_max_distance
            # Lookup table: one learned scalar per (bucket, head).
            self.rel_pos_bias = nn.Embedding(self.relative_attention_num_buckets, self.cfg.n_heads)
            # NOTE(review): this hook point is not applied anywhere in this class; presumably
            # the caller routes the position bias through it — confirm.
            self.rel_pos_hook = HookPoint()

        # Key/value projections of shape (n_heads, d_model, d_head). Created with
        # torch.empty, so they are uninitialised here (presumably filled by weight
        # loading / init elsewhere — TODO confirm).
        self.W_K = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
        )
        self.W_V = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype)
        )
        # Key/value biases of shape (n_heads, d_head), initialised to zero.
        self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))
        self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype))

    @staticmethod
    def _relative_position_bucket(
        relative_position: Int[torch.Tensor, "query_pos kv_pos"],
        bidirectional=True,
        num_buckets=32,
        max_distance=128,
    ) -> Int[torch.Tensor, "query_pos kv_pos"]:
        """Map relative positions to bucket indices for relative attention.

        Added from
        https://github.com/huggingface/transformers/blob/e0c3cee17085914bbe505c159beeb8ae39bc37dd/src/transformers/models/t5/modeling_t5.py#L382
        which is adapted from
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        The relative position is ``memory_position - query_position``, i.e. the signed
        distance from the attending position to the attended-to position. Small distances
        get one bucket each; larger distances share logarithmically wider buckets. All
        distances with magnitude >= ``max_distance`` fall into the last bucket for their
        sign.

        With ``bidirectional=True``, half of the buckets are used for positive offsets and
        half for non-positive ones. With ``bidirectional=False``, only non-positive offsets
        are distinguished: every positive offset is mapped to the same bucket as offset 0.

        Args:
            relative_position: Integer tensor of signed relative positions.
            bidirectional: Whether positive and negative offsets get separate buckets.
            num_buckets: Total number of buckets.
            max_distance: Distance beyond which all offsets share a bucket.

        Returns:
            A tensor with the same shape and integer dtype as ``relative_position``, with
            values in ``[0, num_buckets)``.
        """
        relative_buckets = torch.zeros_like(relative_position)

        if bidirectional:
            num_buckets //= 2
            # Positive offsets use the upper half of the bucket range.
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Clamp positive offsets to 0 and flip the sign of negative ones.
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to
        # max_distance. For positions < max_exact (including 0, where log gives -inf) this
        # value is garbage, but torch.where below discards it.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_relative_attention_bias(
        self,
        query_length: int,
        key_length: int,
        device: Optional[Union[str, torch.device]] = None,
        bidirectional: bool = True,
    ) -> Float[torch.Tensor, "1 head_index pos kv_pos"]:
        """Compute the binned relative position bias for every (query, key) pair.

        Args:
            query_length: Number of query positions.
            key_length: Number of key positions.
            device: Device of the returned tensor. Defaults to the device of the bias
                table (``rel_pos_bias.weight``).
            bidirectional: Whether positive and negative offsets are bucketed separately.
                Defaults to True (the original behaviour). Hugging Face's T5 uses False
                for decoder self-attention.

        Returns:
            Bias tensor of shape ``(1, n_heads, query_length, key_length)``, intended to be
            passed to ``forward`` as ``position_bias``.

        Raises:
            ValueError: If this layer was constructed without
                ``has_relative_attention_bias=True`` and so has no bias table.
        """
        if not self.has_relative_attention_bias:
            raise ValueError(
                "compute_relative_attention_bias requires a layer constructed with "
                "has_relative_attention_bias=True"
            )

        weight_device = self.rel_pos_bias.weight.device
        if device is None:
            device = weight_device

        # Build the bucket indices on the embedding's own device so the lookup succeeds
        # even when the caller asks for the result on a different device.
        context_position = torch.arange(query_length, dtype=torch.long, device=weight_device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=weight_device)[
            None, :
        ]
        relative_position = memory_position - context_position  # (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.rel_pos_bias(relative_position_bucket)  # (query_length, key_length, n_heads)
        # Reorder to (1, n_heads, query_length, key_length); the leading singleton dim
        # presumably broadcasts over the batch.
        return values.permute(2, 0, 1).unsqueeze(0).to(device)