Coverage for transformer_lens/model_bridge/generalized_components/mpt_alibi_attention.py: 71%
69 statements
1"""MPT ALiBi attention bridge — MPT uses ``position_bias`` kwarg + bool causal mask."""
3from __future__ import annotations
5import math
6from typing import Any, Dict, Optional
8import torch
9from packaging import version
11from transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention import (
12 ALiBiJointQKVAttentionBridge,
13)
15try:
16 import transformers as _transformers
18 _TRANSFORMERS_V5 = version.parse(_transformers.__version__) >= version.parse("5.0.0")
19except Exception:
20 _TRANSFORMERS_V5 = False


def _build_mpt_alibi_tensor(num_heads: int, seq_len: int, alibi_bias_max: int = 8) -> torch.Tensor:
    """MPT ALiBi bias [num_heads, 1, seq_len] — mirrors HF's ``build_mpt_alibi_tensor``."""
    alibi = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
    num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))

    base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64).float()
    base = base * (alibi_bias_max / num_heads_power_of_2)
    slopes = 1.0 / torch.pow(2, base)
    slopes = slopes.view(1, num_heads_power_of_2, 1, 1)

    if num_heads_power_of_2 != num_heads:  # coverage: never true in tests
        slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[
            :, :num_heads, ...
        ]

    alibi = alibi * slopes  # [1, n_heads, 1, seq_len]
    return alibi.squeeze(0)  # [n_heads, 1, seq_len]


class MPTALiBiAttentionBridge(ALiBiJointQKVAttentionBridge):
    """ALiBi bridge for MPT: overrides ALiBi kwarg name, bias shape, mask format, and clip_qkv."""

    _clip_qkv: Optional[float] = None

    def forward(
        self, *args: Any, **kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, None]:
        """2-tuple on transformers>=5, 3-tuple on <5 — MptBlock unpack arity changed in v5."""
        output, attn_weights = super().forward(*args, **kwargs)
        if _TRANSFORMERS_V5:  # coverage: always true in tests
            return output, attn_weights
        return output, attn_weights, None

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        super().set_original_component(original_component)
        if hasattr(self, "o") and hasattr(original_component, "out_proj"):  # coverage: always true
            self.o.set_original_component(original_component.out_proj)
        clip = getattr(original_component, "clip_qkv", None)
        self._clip_qkv = float(clip) if clip is not None else None

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # clip_qkv is post-projection, pre-head-split — must happen before reshape.
        if self._clip_qkv is not None:  # coverage: never true in tests
            q = q.clamp(min=-self._clip_qkv, max=self._clip_qkv)
            k = k.clamp(min=-self._clip_qkv, max=self._clip_qkv)
            v = v.clamp(min=-self._clip_qkv, max=self._clip_qkv)

        num_heads = self.config.n_heads if self.config else 32
        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_heads
        )

        softmax_scale = head_dim**-0.5
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale

        # position_bias is [n_heads, 1, max_seq_len]; slice trailing kv_len, broadcast over batch.
        position_bias = kwargs.get("position_bias", None)
        if position_bias is not None:  # coverage: always true in tests
            kv_len = attn_scores.shape[-1]
            pb = position_bias[:, :, -kv_len:]
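            # Shape trace: pb is [n_heads, 1, kv_len]; unsqueeze(0) gives
            # [1, n_heads, 1, kv_len], which broadcasts against attn_scores of
            # shape [batch, n_heads, q_len, kv_len].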
            attn_scores = attn_scores + pb.unsqueeze(0)

        # MPT passes a bool 4D mask (True = masked), not an additive float mask.
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:  # coverage: always true in tests
            attn_scores = attn_scores.masked_fill(
                attention_mask, torch.finfo(attn_scores.dtype).min
            )

        attn_scores = self.hook_attn_scores(attn_scores)

        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=q.dtype
        )

        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        attn_output = self._apply_output_projection(attn_output)
        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Test inputs using MPT's kwarg names: position_bias (no batch dim) + bool causal mask."""
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 2048
        num_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 32

        position_bias = _build_mpt_alibi_tensor(num_heads, seq_len).to(device=device, dtype=dtype)

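        # Bool causal mask, True = masked: entries strictly above the diagonal (future
        # positions) are True, and the final shape is [batch_size, 1, seq_len, seq_len].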
        causal = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
        )
        causal_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)

        return {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),
            "position_bias": position_bias,
            "attention_mask": causal_mask,
        }
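

# Minimal smoke-test sketch (illustrative; assumes only torch and the helper above).
# Instantiating the bridge itself requires a config and an original HF MPT attention
# module, so only the standalone ALiBi helper is exercised here.
if __name__ == "__main__":
    bias = _build_mpt_alibi_tensor(num_heads=4, seq_len=8)
    assert bias.shape == (4, 1, 8)
    # The most recent position gets zero bias for every head.
    assert bool(torch.all(bias[:, :, -1] == 0))
    print("ALiBi bias OK:", tuple(bias.shape))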