Coverage for transformer_lens/model_bridge/generalized_components/mpt_alibi_attention.py: 71%

69 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""MPT ALiBi attention bridge — MPT uses ``position_bias`` kwarg + bool causal mask.""" 

2 

3from __future__ import annotations 

4 

5import math 

6from typing import Any, Dict, Optional 

7 

8import torch 

9from packaging import version 

10 

11from transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention import ( 

12 ALiBiJointQKVAttentionBridge, 

13) 

14 

15try: 

16 import transformers as _transformers 

17 

18 _TRANSFORMERS_V5 = version.parse(_transformers.__version__) >= version.parse("5.0.0") 

19except Exception: 

20 _TRANSFORMERS_V5 = False 

21 

22 

23def _build_mpt_alibi_tensor(num_heads: int, seq_len: int, alibi_bias_max: int = 8) -> torch.Tensor: 

24 """MPT ALiBi bias [num_heads, 1, seq_len] — mirrors HF's ``build_mpt_alibi_tensor``.""" 

25 alibi = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len) 

26 num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads)) 

27 

28 base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64).float() 

29 base = base * (alibi_bias_max / num_heads_power_of_2) 

30 slopes = 1.0 / torch.pow(2, base) 

31 slopes = slopes.view(1, num_heads_power_of_2, 1, 1) 

32 

33 if num_heads_power_of_2 != num_heads: 33 ↛ 34line 33 didn't jump to line 34 because the condition on line 33 was never true

34 slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[ 

35 :, :num_heads, ... 

36 ] 

37 

38 alibi = alibi * slopes # [1, n_heads, 1, seq_len] 

39 return alibi.squeeze(0) # [n_heads, 1, seq_len] 

40 

41 
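A minimal shape check of the helper above (not part of the module source, and assuming ``_build_mpt_alibi_tensor`` is importable from this file as shown): the last key position always receives zero bias, while earlier positions receive increasingly negative bias per head.

from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import (
    _build_mpt_alibi_tensor,
)

bias = _build_mpt_alibi_tensor(num_heads=4, seq_len=5)
print(bias.shape)      # torch.Size([4, 1, 5]), i.e. [n_heads, 1, seq_len]
print(bias[:, 0, -1])  # all zeros: no penalty at the current (last) key position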

 42  class MPTALiBiAttentionBridge(ALiBiJointQKVAttentionBridge):
 43      """ALiBi bridge for MPT: overrides ALiBi kwarg name, bias shape, mask format, and clip_qkv."""
 44
 45      _clip_qkv: Optional[float] = None
 46
 47      def forward(
 48          self, *args: Any, **kwargs: Any
 49      ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, None]:
 50          """2-tuple on transformers>=5, 3-tuple on <5 — MptBlock unpack arity changed in v5."""
 51          output, attn_weights = super().forward(*args, **kwargs)
 52          if _TRANSFORMERS_V5:  [52 ↛ 54: the condition on line 52 was always true]
 53              return output, attn_weights
 54          return output, attn_weights, None
 55
 56      def set_original_component(self, original_component: torch.nn.Module) -> None:
 57          super().set_original_component(original_component)
 58          if hasattr(self, "o") and hasattr(original_component, "out_proj"):  [58 ↛ 60: the condition on line 58 was always true]
 59              self.o.set_original_component(original_component.out_proj)
 60          clip = getattr(original_component, "clip_qkv", None)
 61          self._clip_qkv = float(clip) if clip is not None else None
 62
 63      def _reconstruct_attention(
 64          self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
 65      ) -> tuple[torch.Tensor, torch.Tensor]:
 66          # clip_qkv is post-projection, pre-head-split — must happen before reshape.
 67          if self._clip_qkv is not None:  [67 ↛ 68: the condition on line 67 was never true]
 68              q = q.clamp(min=-self._clip_qkv, max=self._clip_qkv)
 69              k = k.clamp(min=-self._clip_qkv, max=self._clip_qkv)
 70              v = v.clamp(min=-self._clip_qkv, max=self._clip_qkv)
 71
 72          num_heads = self.config.n_heads if self.config else 32
 73          q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
 74              q, k, v, num_heads, num_heads
 75          )
 76
 77          softmax_scale = head_dim**-0.5
 78          attn_scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
 79
 80          # position_bias is [n_heads, 1, max_seq_len]; slice trailing kv_len, broadcast over batch.
 81          position_bias = kwargs.get("position_bias", None)
 82          if position_bias is not None:  [82 ↛ 88: the condition on line 82 was always true]
 83              kv_len = attn_scores.shape[-1]
 84              pb = position_bias[:, :, -kv_len:]
 85              attn_scores = attn_scores + pb.unsqueeze(0)
 86
 87          # MPT passes a bool 4D mask (True = masked), not an additive float mask.
 88          attention_mask = kwargs.get("attention_mask", None)
 89          if attention_mask is not None:  [89 ↛ 94: the condition on line 89 was always true]
 90              attn_scores = attn_scores.masked_fill(
 91                  attention_mask, torch.finfo(attn_scores.dtype).min
 92              )
 93
 94          attn_scores = self.hook_attn_scores(attn_scores)
 95
 96          attn_weights = self._softmax_dropout_pattern(
 97              attn_scores, upcast_to_fp32=True, target_dtype=q.dtype
 98          )
 99
100          attn_output = torch.matmul(attn_weights, v)
101          attn_output = self._reshape_attn_output(
102              attn_output, batch_size, seq_len, num_heads, head_dim
103          )
104          attn_output = self._apply_output_projection(attn_output)
105          return attn_output, attn_weights
106
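A standalone sketch (not part of the module) of the bias-add and mask-fill steps in ``_reconstruct_attention``, with toy shapes chosen here for illustration: the ``[n_heads, 1, kv_len]`` ALiBi bias broadcasts over batch and query positions once unsqueezed, and ``True`` mask entries are pushed to the dtype minimum before softmax.

import torch

from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import (
    _build_mpt_alibi_tensor,
)

batch, n_heads, q_len, kv_len = 2, 4, 8, 8  # toy sizes, chosen arbitrarily
scores = torch.randn(batch, n_heads, q_len, kv_len)

# Add the ALiBi bias: [n_heads, 1, kv_len] -> [1, n_heads, 1, kv_len] broadcasts over batch/query.
position_bias = _build_mpt_alibi_tensor(n_heads, kv_len)
scores = scores + position_bias[:, :, -kv_len:].unsqueeze(0)

# Apply the boolean causal mask (True = masked) the same way the method does.
mask = torch.triu(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)

weights = scores.softmax(dim=-1)  # [batch, n_heads, q_len, kv_len]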

107      def get_random_inputs(
108          self,
109          batch_size: int = 2,
110          seq_len: int = 8,
111          device: Optional[torch.device] = None,
112          dtype: Optional[torch.dtype] = None,
113      ) -> Dict[str, Any]:
114          """Test inputs using MPT's kwarg names: position_bias (no batch dim) + bool causal mask."""
115          if device is None:
116              device = torch.device("cpu")
117          if dtype is None:
118              dtype = torch.float32
119
120          d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 2048
121          num_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 32
122
123          position_bias = _build_mpt_alibi_tensor(num_heads, seq_len).to(device=device, dtype=dtype)
124
125          causal = torch.triu(
126              torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
127          )
128          causal_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)
129
130          return {
131              "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),
132              "position_bias": position_bias,
133              "attention_mask": causal_mask,
134          }
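For reference, a toy illustration (seq_len of 4 and batch_size of 1 chosen arbitrarily, not part of the module) of the boolean mask format ``get_random_inputs`` builds: ``True`` marks future positions to mask out, and the ``[seq_len, seq_len]`` pattern is expanded to ``[batch_size, 1, seq_len, seq_len]``.

import torch

seq_len, batch_size = 4, 1
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
causal_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)
print(causal_mask.shape)  # torch.Size([1, 1, 4, 4])
print(causal)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])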