Coverage for transformer_lens/model_bridge/generalized_components/alibi_joint_qkv_attention.py: 83%

69 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""ALiBi joint QKV attention bridge component. 

2 

3Handles models that use ALiBi (Attention with Linear Biases) with fused QKV projections. 

4Splits fused QKV, reimplements attention with ALiBi bias and hooks at each stage. 

5""" 

6 

7from __future__ import annotations 

8 

9from typing import Any, Dict, Optional 

10 

11import torch 

12 

13from transformer_lens.model_bridge.generalized_components.alibi_utils import ( 

14 build_alibi_tensor, 

15) 

16from transformer_lens.model_bridge.generalized_components.base import ( 

17 GeneralizedComponent, 

18) 

19from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import ( 

20 JointQKVAttentionBridge, 

21) 

22 

23 

24class ALiBiJointQKVAttentionBridge(JointQKVAttentionBridge): 

25 """Attention bridge for models using ALiBi position encoding with fused QKV. 

26 

27 Splits fused QKV, reimplements attention with ALiBi bias fused into scores, 

28 and fires hooks at each stage (hook_q, hook_k, hook_v, hook_attn_scores, 

29 hook_pattern). ALiBi bias is added to raw attention scores before scaling. 

30 """ 

31 

32 def __init__( 

33 self, 

34 name: str, 

35 config: Any, 

36 split_qkv_matrix: Any = None, 

37 submodules: Optional[Dict[str, GeneralizedComponent]] = None, 

38 **kwargs: Any, 

39 ): 

40 super().__init__( 

41 name=name, 

42 config=config, 

43 split_qkv_matrix=split_qkv_matrix, 

44 submodules=submodules, 

45 requires_position_embeddings=False, 

46 requires_attention_mask=False, 

47 **kwargs, 

48 ) 

49 

50 def forward(self, *args: Any, **kwargs: Any) -> Any: 

51 """Forward pass: split QKV, apply ALiBi, fire hooks.""" 

52 if self.original_component is None: 52 ↛ 53line 52 didn't jump to line 53 because the condition on line 52 was never true

53 raise RuntimeError( 

54 f"Original component not set for {self.name}. " 

55 "Call set_original_component() first." 

56 ) 

57 

58 if len(args) > 0 and isinstance(args[0], torch.Tensor): 58 ↛ 60line 58 didn't jump to line 60 because the condition on line 58 was always true

59 hidden_states = args[0] 

60 elif "hidden_states" in kwargs: 

61 hidden_states = kwargs.pop("hidden_states") 

62 else: 

63 raise ValueError("Could not find hidden_states in args or kwargs") 

64 

65 hooked_input = self.hook_in(hidden_states) 

66 

67 # Split fused QKV via parent's split mechanism 

68 q_output = self.q(hooked_input) 

69 k_output = self.k(hooked_input) 

70 v_output = self.v(hooked_input) 

71 

72 attn_output, attn_weights = self._reconstruct_attention( 

73 q_output, k_output, v_output, **kwargs 

74 ) 

75 

76 output = self.hook_out(attn_output) 

77 return output, attn_weights 

78 

79 def _reconstruct_attention( 

80 self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any 

81 ) -> tuple[torch.Tensor, torch.Tensor]: 

82 """Reconstruct attention with ALiBi bias fused into scores.""" 

83 num_heads = self.config.n_heads if self.config else 32 

84 num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads 

85 q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads( 

86 q, k, v, num_heads, num_kv_heads 

87 ) 

88 

89 # GQA/MQA: expand K/V heads to match Q heads 
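        # (e.g. a multi-query model with a single K/V head repeats that one head num_heads times)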

        if num_kv_heads != num_heads:
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        inv_norm_factor = head_dim**-0.5

        # Raw attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        attn_scores = attn_scores.view(batch_size, num_heads, seq_len, -1)

        # Upcast for numerical stability
        input_dtype = attn_scores.dtype
        if input_dtype in (torch.float16, torch.bfloat16):  # coverage: condition never true in tests
            attn_scores = attn_scores.to(torch.float32)

        # Add ALiBi bias
        alibi = kwargs.get("alibi", None)
        if alibi is not None:
            kv_len = attn_scores.shape[-1]
            alibi_view = alibi.view(batch_size, num_heads, 1, -1)
            if alibi_view.shape[-1] > kv_len:  # coverage: condition never true in tests
                alibi_view = alibi_view[..., :kv_len]
            attn_scores = attn_scores + alibi_view

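        # Note: because the ALiBi bias is added before scaling, the quantity fed to
        # softmax is inv_norm_factor * (q @ k^T + alibi), i.e. the per-head slopes are
        # effectively scaled by 1/sqrt(head_dim) as well, matching HF Falcon's ordering.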

        # Scale after ALiBi (matches HF Falcon)
        attn_scores = attn_scores * inv_norm_factor

        # Add attention mask
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:  # coverage: condition always true in tests
            attn_scores = attn_scores + attention_mask[:, :, :, : attn_scores.shape[-1]]

        attn_scores = self.hook_attn_scores(attn_scores)

        attn_weights = self._softmax_dropout_pattern(
            attn_scores, upcast_to_fp32=True, target_dtype=q.dtype
        )

        # Weighted sum
        attn_output = torch.matmul(
            attn_weights.view(batch_size * num_heads, seq_len, -1),
            v.reshape(batch_size * num_heads, -1, head_dim),
        )
        attn_output = attn_output.view(batch_size, num_heads, seq_len, head_dim)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        attn_output = self._apply_output_projection(attn_output)

        return attn_output, attn_weights

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate test inputs including ALiBi tensor and attention mask."""
        if device is None:  # coverage: condition always true in tests
            device = torch.device("cpu")
        if dtype is None:  # coverage: condition always true in tests
            dtype = torch.float32

        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 2048
        num_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 32

        attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.long)
        # HF Falcon passes alibi as [batch*heads, 1, seq]; reshape to match
        alibi_4d = build_alibi_tensor(attention_mask, num_heads, dtype)
        alibi = alibi_4d.reshape(batch_size * num_heads, 1, seq_len)

        # Causal mask: [batch, 1, seq, seq]
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), torch.finfo(dtype).min, device=device, dtype=dtype),
            diagonal=1,
        )
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)

        return {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),
            "alibi": alibi,
            "attention_mask": causal_mask,
        }
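
For reference, the alibi tensor that get_random_inputs builds and reshapes to [batch * heads, 1, seq] follows the standard BLOOM/Falcon formulation of ALiBi. The sketch below is an assumption about what build_alibi_tensor in alibi_utils computes (the helper itself is not shown in this file); it only illustrates the per-head slopes and the bias layout the bridge expects.

# Sketch (assumption, not the project's actual helper): ALiBi slopes and bias
# in the standard BLOOM/Falcon form.
import math

import torch


def alibi_slopes(n_heads: int) -> torch.Tensor:
    """Per-head slopes: a geometric sequence with ratio 2 ** (-8 / n) for power-of-two n."""
    closest_pow2 = 2 ** math.floor(math.log2(n_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = torch.pow(base, torch.arange(1, closest_pow2 + 1, dtype=torch.float32))
    if closest_pow2 != n_heads:
        # Non-power-of-two head counts get extra, interleaved slopes.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        n_extra = min(closest_pow2, n_heads - closest_pow2)
        extra = torch.pow(extra_base, torch.arange(1, 2 * n_extra + 1, 2, dtype=torch.float32))
        slopes = torch.cat([slopes, extra])
    return slopes


batch, n_heads, seq = 2, 8, 8
padding_mask = torch.ones(batch, seq)
# Bias per (head, key position): slope_h * key_index, broadcast over query positions.
key_positions = (padding_mask.cumsum(dim=-1) - 1) * padding_mask              # [batch, seq]
alibi = alibi_slopes(n_heads)[None, :, None] * key_positions[:, None, :]      # [batch, heads, seq]
alibi = alibi.reshape(batch * n_heads, 1, seq)  # the [batch*heads, 1, seq] layout used above

For 8 heads this yields slopes 1/2, 1/4, ..., 1/256; the per-row constant offset cancels in the softmax, so the effective bias penalizes distant keys in proportion to each head's slope.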