Coverage for transformer_lens/model_bridge/supported_architectures/falcon.py: 64%

109 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Falcon architecture adapter. 

2 

3Supports original Falcon models (7B, 40B, 180B) with: 

4- Parallel attention+MLP (both read same residual input) 

5- Multi-query or grouped-query attention (fused QKV) 

6- RoPE or ALiBi position embeddings 

7""" 

8 

9from typing import Any 

10 

11import torch 

12 

13from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

14from transformer_lens.conversion_utils.param_processing_conversion import ( 

15 ParamProcessingConversion, 

16) 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.generalized_components import ( 

19 ALiBiJointQKVAttentionBridge, 

20 BlockBridge, 

21 EmbeddingBridge, 

22 JointQKVPositionEmbeddingsAttentionBridge, 

23 LinearBridge, 

24 MLPBridge, 

25 NormalizationBridge, 

26 ParallelBlockBridge, 

27 RotaryEmbeddingBridge, 

28 UnembeddingBridge, 

29) 

30 

31 

32def _patch_decoder_inplace_add(layer: Any) -> None: 

33 """Patch FalconDecoderLayer.forward to use non-inplace addition. 

34 

35 The original does `mlp_output += attention_output` which modifies 

36 mlp_output inplace, conflicting with backward hooks on mlp.hook_out. 

37 We monkey-patch forward to wrap each call with a temporary hook on self.mlp that clones its output, so the += acts on the clone.

38 """ 

39 import inspect 

40 

41 src = inspect.getsource(type(layer).forward) 

42 

43 # Only patch if the inplace pattern exists 

44 if "mlp_output += attention_output" not in src: 44 ↛ 45line 44 didn't jump to line 45 because the condition on line 44 was never true

45 return 

46 

47 # Get the original forward and wrap it 

48 orig_forward = type(layer).forward 

49 

50 def patched_forward(self: Any, *args: Any, **kwargs: Any) -> Any: 

51 # Call original but intercept mlp_output before inplace add. 

52 # Since we can't modify the source, we use a different approach: 

53 # register a temporary hook on self.mlp that clones output. 

54 clone_handle = self.mlp.register_forward_hook( 

55 lambda _m, _i, o: o.clone() if isinstance(o, torch.Tensor) else o 

56 ) 

57 try: 

58 result = orig_forward(self, *args, **kwargs) 

59 finally: 

60 clone_handle.remove() 

61 return result 

62 

63 layer.forward = patched_forward.__get__(layer, type(layer)) # type: ignore[method-assign] 
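# Hedged sketch (illustration only, not used by the adapter) of the mechanism
# above: a forward hook that returns a tensor replaces the module's output, so
# a later in-place `+=` mutates the clone rather than the tensor that other
# hooks captured. `_clone_output_sketch` is a hypothetical helper name.
def _clone_output_sketch(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    handle = module.register_forward_hook(
        lambda _m, _i, out: out.clone() if isinstance(out, torch.Tensor) else out
    )
    try:
        out = module(x)  # `out` is already the clone returned by the hook
        out += 1.0  # in-place edit touches the clone; the hooked original is untouched
    finally:
        handle.remove()
    return out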

64 

65 

66class FalconArchitectureAdapter(ArchitectureAdapter): 

67 """Architecture adapter for Falcon models (FalconForCausalLM).""" 

68 

69 def __init__(self, cfg: Any) -> None: 

70 super().__init__(cfg) 

71 

72 self._is_alibi = getattr(cfg, "alibi", False) 

73 self._is_new_arch = getattr(cfg, "new_decoder_architecture", False) 

74 self._is_multi_query = getattr(cfg, "multi_query", False) 

75 is_parallel = getattr(cfg, "parallel_attn", True) 

76 

77 self.cfg.normalization_type = "LN" 

78 self.cfg.positional_embedding_type = "alibi" if self._is_alibi else "rotary" 

79 self.cfg.parallel_attn_mlp = is_parallel 

80 self.cfg.gated_mlp = False 

81 

82 if self._is_multi_query: 

83 self.cfg.n_key_value_heads = 1 

84 

85 n_kv_heads = self.cfg.n_key_value_heads or self.cfg.n_heads 

86 self.weight_processing_conversions = { 

87 "blocks.{i}.attn.q": ParamProcessingConversion( 

88 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), 

89 ), 

90 "blocks.{i}.attn.k": ParamProcessingConversion( 

91 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

92 ), 

93 "blocks.{i}.attn.v": ParamProcessingConversion( 

94 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), 

95 ), 

96 "blocks.{i}.attn.o": ParamProcessingConversion( 

97 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), 

98 ), 

99 } 
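# Illustration (hedged): assuming RearrangeTensorConversion applies an
# einops-style pattern, "(n h) m -> n m h" maps a fused HF projection weight
# of shape [n_heads * d_head, d_model] to the per-head [n_heads, d_model, d_head]
# layout TransformerLens expects. For example, with einops directly
# (w_hf, w_tl and d_head are illustrative names):
#     w_tl = einops.rearrange(w_hf, "(n h) m -> n m h", n=self.cfg.n_heads)
#     # w_hf: (n_heads * d_head, d_model) -> w_tl: (n_heads, d_model, d_head)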

100 

101 ln1_name = "ln_attn" if self._is_new_arch else "input_layernorm" 

102 

103 if self._is_alibi:    103 ↛ 106 (condition on line 103 was never true)

104 # ALiBi: reimplement attention with ALiBi bias fused into scores. 

105 # Splits fused QKV and fires hooks at each stage for mech interp. 

106 attn_bridge: Any = ALiBiJointQKVAttentionBridge( 

107 name="self_attention", 

108 config=self.cfg, 

109 split_qkv_matrix=self._split_falcon_qkv, 

110 submodules={ 

111 "qkv": LinearBridge(name="query_key_value"), 

112 "o": LinearBridge(name="dense"), 

113 }, 

114 ) 

115 else: 

116 # RoPE: reimplement with position embeddings for hook access 

117 attn_bridge = JointQKVPositionEmbeddingsAttentionBridge( 

118 name="self_attention", 

119 config=self.cfg, 

120 split_qkv_matrix=self._split_falcon_qkv, 

121 submodules={ 

122 "qkv": LinearBridge(name="query_key_value"), 

123 "o": LinearBridge(name="dense"), 

124 }, 

125 ) 

126 

127 block_submodules: dict[str, Any] = { 

128 "ln1": NormalizationBridge(name=ln1_name, config=self.cfg), 

129 "attn": attn_bridge, 

130 "mlp": MLPBridge( 

131 name="mlp", 

132 config=self.cfg, 

133 submodules={ 

134 "in": LinearBridge(name="dense_h_to_4h"), 

135 "out": LinearBridge(name="dense_4h_to_h"), 

136 }, 

137 ), 

138 } 

139 

140 if not is_parallel:    140 ↛ 141 (condition on line 140 was never true)

141 block_submodules["ln2"] = NormalizationBridge( 

142 name="post_attention_layernorm", config=self.cfg 

143 ) 

144 elif self._is_new_arch and getattr(cfg, "num_ln_in_parallel_attn", None) == 2:    144 ↛ 145 (condition on line 144 was never true)

145 block_submodules["ln2"] = NormalizationBridge(name="ln_mlp", config=self.cfg) 

146 

147 # Falcon has both parallel (most checkpoints) and sequential variants. 

148 block_cls = ParallelBlockBridge if is_parallel else BlockBridge 

149 self.component_mapping: dict[str, Any] = { 

150 "embed": EmbeddingBridge(name="transformer.word_embeddings"), 

151 "blocks": block_cls(name="transformer.h", submodules=block_submodules), 

152 "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg), 

153 "unembed": UnembeddingBridge(name="lm_head"), 

154 } 
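# For example (illustrative): the bridge path "blocks.0.attn" is expected to
# resolve to the HF module "transformer.h.0.self_attention", and "unembed"
# to "lm_head".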

155 

156 if not self._is_alibi:    156 ↛ exit (condition on line 156 was always true)

157 self.component_mapping["rotary_emb"] = RotaryEmbeddingBridge( 

158 name="transformer.rotary_emb", config=self.cfg 

159 ) 

160 

161 def prepare_model(self, hf_model: Any) -> None: 

162 """Patch Falcon modules to avoid backward hook conflicts. 

163 

164 Two issues: 

165 1. FalconLinear does `input @ self.weight.T` where .T is a view; we

166 materialize the transpose with .contiguous() to break the view chain.

167 2. FalconDecoderLayer does `mlp_output += attention_output` (inplace) — 

168 this modifies a tensor captured by mlp.hook_out's backward hook. 

169 Patch forward so the add no longer mutates the hooked tensor in place.

170 """ 

171 

172 def _make_patched_linear(mod: Any) -> Any: 

173 def patched_forward(input: torch.Tensor) -> torch.Tensor: 

174 hidden_states = input @ mod.weight.T.contiguous() 

175 if mod.bias is not None:    175 ↛ 176 (condition on line 175 was never true)

176 hidden_states = hidden_states + mod.bias 

177 return hidden_states 

178 

179 return patched_forward 

180 

181 for module in hf_model.modules(): 

182 if type(module).__name__ == "FalconLinear": 

183 module.forward = _make_patched_linear(module) # type: ignore[method-assign] 

184 

185 # Patch decoder layers to avoid `mlp_output += attention_output` (inplace). 

186 # The patched forward registers a temporary clone hook on self.mlp 

187 # around each forward call, so the inplace += gets a clone, not the 

188 # original tensor captured by backward hooks. 

189 for module in hf_model.modules(): 

190 if type(module).__name__ == "FalconDecoderLayer": 

191 _patch_decoder_inplace_add(module) 

192 

193 def _split_falcon_qkv( 

194 self, original_attention_component: Any 

195 ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]: 

196 """Split Falcon's fused query_key_value into separate Q, K, V projections.""" 

197 qkv = original_attention_component.query_key_value 

198 weight = qkv.weight.detach().clone() 

199 d_model = self.cfg.d_model 

200 head_dim = d_model // self.cfg.n_heads 

201 has_bias = qkv.bias is not None 

202 

203 if self._is_new_arch:    203 ↛ 204 (condition on line 203 was never true)

204 n_kv = self.cfg.n_key_value_heads or self.cfg.n_heads 

205 sizes = [self.cfg.n_heads * head_dim, n_kv * head_dim, n_kv * head_dim] 

206 W_Q, W_K, W_V = torch.split(weight, sizes, dim=0) 

207 b_Q: torch.Tensor | None 

208 b_K: torch.Tensor | None 

209 b_V: torch.Tensor | None 

210 if has_bias: 

211 b_Q, b_K, b_V = torch.split(qkv.bias.detach().clone(), sizes, dim=0) 

212 else: 

213 b_Q = b_K = b_V = None 

214 elif self._is_multi_query:    214 ↛ 225 (condition on line 214 was always true)

215 sizes = [d_model, head_dim, head_dim] 

216 W_Q, W_K, W_V = torch.split(weight, sizes, dim=0) 

217 if has_bias:    217 ↛ 218 (condition on line 217 was never true)

218 b_Q, b_K, b_V = torch.split(qkv.bias.detach().clone(), sizes, dim=0) 

219 else: 

220 b_Q = b_K = b_V = None 

221 else: 

222 # Non-multi-query, non-new-arch: QKV is interleaved per head. 

223 # Weight layout: [Q_h0, K_h0, V_h0, Q_h1, K_h1, V_h1, ...] 

224 # Each chunk is head_dim rows. Deinterleave to [Q_all, K_all, V_all]. 
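# e.g. view(n_heads, 3, head_dim, d_model) groups the rows as
# [head][q|k|v][row-within-head]; selecting index 0, 1 or 2 along dim 1 and
# reshaping to (d_model, d_model) recovers the stacked Q, K and V matrices.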

225 n_heads = self.cfg.n_heads 

226 weight_heads = weight.view(n_heads, 3, head_dim, d_model) 

227 W_Q = weight_heads[:, 0, :, :].reshape(d_model, d_model) 

228 W_K = weight_heads[:, 1, :, :].reshape(d_model, d_model) 

229 W_V = weight_heads[:, 2, :, :].reshape(d_model, d_model) 

230 if has_bias: 

231 bias = qkv.bias.detach().clone() 

232 bias_heads = bias.view(n_heads, 3, head_dim) 

233 b_Q = bias_heads[:, 0, :].reshape(d_model) 

234 b_K = bias_heads[:, 1, :].reshape(d_model) 

235 b_V = bias_heads[:, 2, :].reshape(d_model) 

236 else: 

237 b_Q = b_K = b_V = None 

238 

239 def build_linear( 

240 w: torch.Tensor, b: torch.Tensor | None, out_features: int 

241 ) -> torch.nn.Linear: 

242 linear = torch.nn.Linear( 

243 d_model, out_features, bias=b is not None, device=w.device, dtype=w.dtype 

244 ) 

245 linear.weight = torch.nn.Parameter(w.contiguous()) 

246 if b is not None:    246 ↛ 247 (condition on line 246 was never true)

247 linear.bias = torch.nn.Parameter(b.contiguous()) 

248 return linear 

249 

250 return ( 

251 build_linear(W_Q, b_Q, W_Q.shape[0]), 

252 build_linear(W_K, b_K, W_K.shape[0]), 

253 build_linear(W_V, b_V, W_V.shape[0]), 

254 ) 

255 

256 def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: 

257 """Set up rotary embedding references for component testing.""" 

258 if self._is_alibi: 

259 return # ALiBi handled by HF natively 

260 

261 rotary_emb = hf_model.transformer.rotary_emb 

262 

263 if bridge_model is not None and hasattr(bridge_model, "blocks"): 

264 for block in bridge_model.blocks: 

265 if hasattr(block, "attn"): 

266 block.attn.set_rotary_emb(rotary_emb) 

267 

268 attn_bridge = self.get_generalized_component("blocks.0.attn") 

269 attn_bridge.set_rotary_emb(rotary_emb)