Coverage for transformer_lens/model_bridge/supported_architectures/falcon.py: 64%
109 statements
"""Falcon architecture adapter.

Supports original Falcon models (7B, 40B, 180B) with:
- Parallel attention+MLP (both read same residual input)
- Multi-query or grouped-query attention (fused QKV)
- RoPE or ALiBi position embeddings
"""

from typing import Any

import torch

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    ALiBiJointQKVAttentionBridge,
    BlockBridge,
    EmbeddingBridge,
    JointQKVPositionEmbeddingsAttentionBridge,
    LinearBridge,
    MLPBridge,
    NormalizationBridge,
    ParallelBlockBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)


def _patch_decoder_inplace_add(layer: Any) -> None:
    """Patch FalconDecoderLayer.forward so the residual addition is not destructive.

    The original does `mlp_output += attention_output`, which modifies
    mlp_output in place, conflicting with backward hooks on mlp.hook_out.
    Since we cannot rewrite the addition itself, we wrap the forward so the MLP
    output is cloned first and the in-place add mutates the copy.
    """
    import inspect

    src = inspect.getsource(type(layer).forward)

    # Only patch if the inplace pattern exists
    if "mlp_output += attention_output" not in src:
        return

    # Get the original forward and wrap it
    orig_forward = type(layer).forward

    def patched_forward(self: Any, *args: Any, **kwargs: Any) -> Any:
        # Call original but intercept mlp_output before inplace add.
        # Since we can't modify the source, we use a different approach:
        # register a temporary hook on self.mlp that clones output.
        clone_handle = self.mlp.register_forward_hook(
            lambda _m, _i, o: o.clone() if isinstance(o, torch.Tensor) else o
        )
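        # Returning a tensor from a forward hook replaces the module's output,
        # so the decoder layer's in-place `+=` now mutates the clone rather than
        # the tensor captured for mlp.hook_out's backward pass.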
        try:
            result = orig_forward(self, *args, **kwargs)
        finally:
            clone_handle.remove()
        return result
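
    # Binding via __get__ attaches the wrapper to this instance only, so the
    # class-level forward stays untouched for unpatched layers.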
    layer.forward = patched_forward.__get__(layer, type(layer))  # type: ignore[method-assign]


class FalconArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Falcon models (FalconForCausalLM)."""

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self._is_alibi = getattr(cfg, "alibi", False)
        self._is_new_arch = getattr(cfg, "new_decoder_architecture", False)
        self._is_multi_query = getattr(cfg, "multi_query", False)
        is_parallel = getattr(cfg, "parallel_attn", True)

        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "alibi" if self._is_alibi else "rotary"
        self.cfg.parallel_attn_mlp = is_parallel
        self.cfg.gated_mlp = False
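
        # Multi-query attention is the special case of grouped-query attention
        # with a single shared key/value head.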
        if self._is_multi_query:
            self.cfg.n_key_value_heads = 1

        n_kv_heads = self.cfg.n_key_value_heads or self.cfg.n_heads
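        # HF stores each projection weight as (out_features, in_features); the
        # rearrangements below reshape them into TransformerLens' per-head layout,
        # e.g. W_Q: (n_heads * d_head, d_model) -> (n_heads, d_model, d_head) and
        # W_O: (d_model, n_heads * d_head) -> (n_heads, d_head, d_model).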
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

        ln1_name = "ln_attn" if self._is_new_arch else "input_layernorm"

        if self._is_alibi:
            # ALiBi: reimplement attention with ALiBi bias fused into scores.
            # Splits fused QKV and fires hooks at each stage for mech interp.
            attn_bridge: Any = ALiBiJointQKVAttentionBridge(
                name="self_attention",
                config=self.cfg,
                split_qkv_matrix=self._split_falcon_qkv,
                submodules={
                    "qkv": LinearBridge(name="query_key_value"),
                    "o": LinearBridge(name="dense"),
                },
            )
        else:
            # RoPE: reimplement with position embeddings for hook access
            attn_bridge = JointQKVPositionEmbeddingsAttentionBridge(
                name="self_attention",
                config=self.cfg,
                split_qkv_matrix=self._split_falcon_qkv,
                submodules={
                    "qkv": LinearBridge(name="query_key_value"),
                    "o": LinearBridge(name="dense"),
                },
            )

        block_submodules: dict[str, Any] = {
            "ln1": NormalizationBridge(name=ln1_name, config=self.cfg),
            "attn": attn_bridge,
            "mlp": MLPBridge(
                name="mlp",
                config=self.cfg,
                submodules={
                    "in": LinearBridge(name="dense_h_to_4h"),
                    "out": LinearBridge(name="dense_4h_to_h"),
                },
            ),
        }

        if not is_parallel:
            block_submodules["ln2"] = NormalizationBridge(
                name="post_attention_layernorm", config=self.cfg
            )
        elif self._is_new_arch and getattr(cfg, "num_ln_in_parallel_attn", None) == 2:
            block_submodules["ln2"] = NormalizationBridge(name="ln_mlp", config=self.cfg)

        # Falcon has both parallel (most checkpoints) and sequential variants.
        block_cls = ParallelBlockBridge if is_parallel else BlockBridge
        self.component_mapping: dict[str, Any] = {
            "embed": EmbeddingBridge(name="transformer.word_embeddings"),
            "blocks": block_cls(name="transformer.h", submodules=block_submodules),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

        if not self._is_alibi:
            self.component_mapping["rotary_emb"] = RotaryEmbeddingBridge(
                name="transformer.rotary_emb", config=self.cfg
            )

    def prepare_model(self, hf_model: Any) -> None:
        """Patch Falcon modules to avoid backward hook conflicts.

        Two issues:
        1. FalconLinear does `input @ self.weight.T` where .T is a view —
           materialize the transpose to break the view chain.
        2. FalconDecoderLayer does `mlp_output += attention_output` (inplace) —
           this modifies a tensor captured by mlp.hook_out's backward hook.
           Patch so the addition operates on a cloned MLP output instead.
        """

        def _make_patched_linear(mod: Any) -> Any:
            def patched_forward(input: torch.Tensor) -> torch.Tensor:
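                # .contiguous() materializes a copy of the transposed weight, so
                # the matmul does not run against a view of mod.weight.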
                hidden_states = input @ mod.weight.T.contiguous()
                if mod.bias is not None:
                    hidden_states = hidden_states + mod.bias
                return hidden_states

            return patched_forward

        for module in hf_model.modules():
            if type(module).__name__ == "FalconLinear":
                module.forward = _make_patched_linear(module)  # type: ignore[method-assign]

        # Patch decoder layers to avoid `mlp_output += attention_output` (inplace).
        # The patched forward registers a temporary clone hook on self.mlp
        # around each forward call, so the inplace += gets a clone, not the
        # original tensor captured by backward hooks.
        for module in hf_model.modules():
            if type(module).__name__ == "FalconDecoderLayer":
                _patch_decoder_inplace_add(module)

    def _split_falcon_qkv(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split Falcon's fused query_key_value into separate Q, K, V projections."""
        qkv = original_attention_component.query_key_value
        weight = qkv.weight.detach().clone()
        d_model = self.cfg.d_model
        head_dim = d_model // self.cfg.n_heads
        has_bias = qkv.bias is not None

        if self._is_new_arch:
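            # New decoder architecture (grouped-query attention): fused rows are
            # stacked as [all Q heads | n_kv K heads | n_kv V heads].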
            n_kv = self.cfg.n_key_value_heads or self.cfg.n_heads
            sizes = [self.cfg.n_heads * head_dim, n_kv * head_dim, n_kv * head_dim]
            W_Q, W_K, W_V = torch.split(weight, sizes, dim=0)
            b_Q: torch.Tensor | None
            b_K: torch.Tensor | None
            b_V: torch.Tensor | None
            if has_bias:
                b_Q, b_K, b_V = torch.split(qkv.bias.detach().clone(), sizes, dim=0)
            else:
                b_Q = b_K = b_V = None
        elif self._is_multi_query:
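            # Multi-query: all query heads share a single K head and a single V
            # head, so the fused rows are [Q (d_model) | K (head_dim) | V (head_dim)].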
            sizes = [d_model, head_dim, head_dim]
            W_Q, W_K, W_V = torch.split(weight, sizes, dim=0)
            if has_bias:
                b_Q, b_K, b_V = torch.split(qkv.bias.detach().clone(), sizes, dim=0)
            else:
                b_Q = b_K = b_V = None
        else:
            # Non-multi-query, non-new-arch: QKV is interleaved per head.
            # Weight layout: [Q_h0, K_h0, V_h0, Q_h1, K_h1, V_h1, ...]
            # Each chunk is head_dim rows. Deinterleave to [Q_all, K_all, V_all].
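            # Viewing the (n_heads * 3 * head_dim, d_model) weight as
            # (n_heads, 3, head_dim, d_model) puts Q/K/V for head h at
            # weight_heads[h, 0], weight_heads[h, 1] and weight_heads[h, 2].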
            n_heads = self.cfg.n_heads
            weight_heads = weight.view(n_heads, 3, head_dim, d_model)
            W_Q = weight_heads[:, 0, :, :].reshape(d_model, d_model)
            W_K = weight_heads[:, 1, :, :].reshape(d_model, d_model)
            W_V = weight_heads[:, 2, :, :].reshape(d_model, d_model)
            if has_bias:
                bias = qkv.bias.detach().clone()
                bias_heads = bias.view(n_heads, 3, head_dim)
                b_Q = bias_heads[:, 0, :].reshape(d_model)
                b_K = bias_heads[:, 1, :].reshape(d_model)
                b_V = bias_heads[:, 2, :].reshape(d_model)
            else:
                b_Q = b_K = b_V = None
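
        # Wrap each split weight in a fresh nn.Linear so the attention bridge can
        # treat Q, K and V as independent projections.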
        def build_linear(
            w: torch.Tensor, b: torch.Tensor | None, out_features: int
        ) -> torch.nn.Linear:
            linear = torch.nn.Linear(
                d_model, out_features, bias=b is not None, device=w.device, dtype=w.dtype
            )
            linear.weight = torch.nn.Parameter(w.contiguous())
            if b is not None:
                linear.bias = torch.nn.Parameter(b.contiguous())
            return linear

        return (
            build_linear(W_Q, b_Q, W_Q.shape[0]),
            build_linear(W_K, b_K, W_K.shape[0]),
            build_linear(W_V, b_V, W_V.shape[0]),
        )

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for component testing."""
        if self._is_alibi:
            return  # ALiBi handled by HF natively
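
        # HF Falcon keeps one rotary embedding module on the transformer and
        # shares it across layers; hand that same module to every attention
        # bridge so Q/K rotation matches the HF forward pass.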
        rotary_emb = hf_model.transformer.rotary_emb

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)