Coverage for transformer_lens/model_bridge/generalized_components/alibi_joint_qkv_attention.py: 83%
69 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""ALiBi joint QKV attention bridge component.
3Handles models that use ALiBi (Attention with Linear Biases) with fused QKV projections.
4Splits fused QKV, reimplements attention with ALiBi bias and hooks at each stage.
5"""
7from __future__ import annotations
9from typing import Any, Dict, Optional
11import torch
13from transformer_lens.model_bridge.generalized_components.alibi_utils import (
14 build_alibi_tensor,
15)
16from transformer_lens.model_bridge.generalized_components.base import (
17 GeneralizedComponent,
18)
19from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
20 JointQKVAttentionBridge,
21)
class ALiBiJointQKVAttentionBridge(JointQKVAttentionBridge):
    """Attention bridge for models using ALiBi position encoding with fused QKV.

    Splits fused QKV, reimplements attention with ALiBi bias fused into scores,
    and fires hooks at each stage (hook_q, hook_k, hook_v, hook_attn_scores,
    hook_pattern). ALiBi bias is added to raw attention scores before scaling.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Any = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs: Any,
    ):
        """Initialize the bridge.

        Args:
            name: Component name used for error messages and hook naming.
            config: Model configuration (expects ``n_heads``; optionally
                ``n_key_value_heads`` and ``d_model``).
            split_qkv_matrix: Callable/helper used by the parent to split the
                fused QKV projection.
            submodules: Optional mapping of child generalized components.
            **kwargs: Forwarded to the parent constructor.

        ALiBi models need neither learned position embeddings nor an
        externally supplied attention mask, so both flags are disabled here.
        """
        super().__init__(
            name=name,
            config=config,
            split_qkv_matrix=split_qkv_matrix,
            submodules=submodules,
            requires_position_embeddings=False,
            requires_attention_mask=False,
            **kwargs,
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass: split QKV, apply ALiBi, fire hooks.

        Returns:
            Tuple of ``(attn_output, attn_weights)``.

        Raises:
            RuntimeError: If the wrapped original component was never attached.
            ValueError: If no hidden-states tensor can be located.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # Accept the residual stream either positionally or as a keyword.
        if args and isinstance(args[0], torch.Tensor):
            residual = args[0]
        elif "hidden_states" in kwargs:
            residual = kwargs.pop("hidden_states")
        else:
            raise ValueError("Could not find hidden_states in args or kwargs")

        hooked = self.hook_in(residual)

        # The q/k/v submodules split the fused projection and fire their hooks.
        queries = self.q(hooked)
        keys = self.k(hooked)
        values = self.v(hooked)

        attn_out, pattern = self._reconstruct_attention(queries, keys, values, **kwargs)

        return self.hook_out(attn_out), pattern

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Reconstruct attention with ALiBi bias fused into scores.

        Order matches HF Falcon: raw QK^T scores, fp32 upcast, ALiBi bias,
        THEN 1/sqrt(d_head) scaling, then the additive attention mask.
        """
        n_heads = self.config.n_heads if self.config else 32
        n_kv_heads = getattr(self.config, "n_key_value_heads", None) or n_heads

        q, k, v, batch, seq, d_head = self._reshape_qkv_to_heads(
            q, k, v, n_heads, n_kv_heads
        )

        # GQA/MQA: replicate K/V heads so they line up one-to-one with Q heads.
        if n_kv_heads != n_heads:
            repeats = n_heads // n_kv_heads
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)

        # Raw attention scores, shaped [batch, heads, q_len, kv_len].
        scores = torch.matmul(q, k.transpose(-2, -1)).view(batch, n_heads, seq, -1)

        # Upcast half-precision scores for numerical stability.
        if scores.dtype in (torch.float16, torch.bfloat16):
            scores = scores.float()

        # Fuse the ALiBi positional bias into the raw scores.
        alibi = kwargs.get("alibi", None)
        if alibi is not None:
            kv_len = scores.shape[-1]
            bias = alibi.view(batch, n_heads, 1, -1)
            if bias.shape[-1] > kv_len:
                bias = bias[..., :kv_len]
            scores = scores + bias

        # Scale AFTER adding ALiBi (matches HF Falcon).
        scores = scores * d_head**-0.5

        # Additive attention mask, truncated to the key length.
        mask = kwargs.get("attention_mask", None)
        if mask is not None:
            scores = scores + mask[:, :, :, : scores.shape[-1]]

        scores = self.hook_attn_scores(scores)

        pattern = self._softmax_dropout_pattern(
            scores, upcast_to_fp32=True, target_dtype=q.dtype
        )

        # Weighted sum of values, collapsing the batch and head axes for matmul.
        out = torch.matmul(
            pattern.view(batch * n_heads, seq, -1),
            v.reshape(batch * n_heads, -1, d_head),
        ).view(batch, n_heads, seq, d_head)
        out = self._reshape_attn_output(out, batch, seq, n_heads, d_head)
        out = self._apply_output_projection(out)

        return out, pattern

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate test inputs including ALiBi tensor and attention mask.

        Returns:
            Dict with ``hidden_states`` [batch, seq, d_model], ``alibi``
            [batch*heads, 1, seq], and a causal ``attention_mask``
            [batch, 1, seq, seq].
        """
        device = torch.device("cpu") if device is None else device
        dtype = torch.float32 if dtype is None else dtype

        cfg = self.config
        d_model = cfg.d_model if cfg and hasattr(cfg, "d_model") else 2048
        n_heads = cfg.n_heads if cfg and hasattr(cfg, "n_heads") else 32

        padding_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.long)
        # HF Falcon passes alibi as [batch*heads, 1, seq] — reshape to match
        alibi = build_alibi_tensor(padding_mask, n_heads, dtype).reshape(
            batch_size * n_heads, 1, seq_len
        )

        # Causal mask: [batch, 1, seq, seq], dtype-min above the diagonal.
        neg_inf = torch.full(
            (seq_len, seq_len), torch.finfo(dtype).min, device=device, dtype=dtype
        )
        causal = torch.triu(neg_inf, diagonal=1).unsqueeze(0).unsqueeze(0)
        causal = causal.expand(batch_size, 1, -1, -1)

        return {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),
            "alibi": alibi,
            "attention_mask": causal,
        }