Coverage for transformer_lens/model_bridge/sources/native/model.py: 93%
240 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""TL-native transformer for TransformerBridge — minimal, no HF/HT dependency.
3Cfg-driven features: ``normalization_type`` (LN / RMS / RMSPre), ``final_rms``,
4``gated_mlp``, ``attn_only``, ``n_key_value_heads`` (GQA), ``attn_scores_soft_cap``,
5``output_logits_soft_cap``, ``positional_embedding_type`` (standard / rotary),
6``rotary_dim`` / ``rotary_base`` / ``rope_scaling`` (linear PI, dynamic/NTK,
7llama3 by-parts).
8"""
9from __future__ import annotations
11import math
12from typing import Callable, Optional, cast
14import torch
15import torch.nn as nn
16import torch.nn.functional as F
18from transformer_lens.config import TransformerBridgeConfig
# ``gelu_new`` is the tanh-approximated GELU used by HF GPT-2 / HookedTransformer;
# ``F.gelu(x, approximate="tanh")`` evaluates exactly that formula.
_Activation = Callable[[torch.Tensor], torch.Tensor]


def _gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    """Tanh-approximated GELU (the ``gelu_new`` variant)."""
    return F.gelu(x, approximate="tanh")


# Activation lookup table; MLP constructors index it with the lower-cased
# ``cfg.act_fn``. ``swish`` is an alias for SiLU.
_ACTIVATIONS: dict[str, _Activation] = dict(
    gelu=F.gelu,
    gelu_new=_gelu_tanh,
    relu=F.relu,
    silu=F.silu,
    swish=F.silu,
)
def _uses_rms_norm(cfg: TransformerBridgeConfig) -> bool:
    """Return True when ``cfg.normalization_type`` selects an RMS-style norm.

    A falsy value (``None`` / ``""``) falls back to ``"LN"``. The comparison is
    case-insensitive, so any casing of ``"RMS"`` or ``"RMSPre"`` selects RMSNorm.
    """
    norm_type = cfg.normalization_type or "LN"
    return norm_type.upper() in {"RMS", "RMSPRE"}
def _positional_kind(cfg: TransformerBridgeConfig) -> str:
    """Return the lower-cased ``positional_embedding_type``.

    A missing attribute, ``None`` or an empty string all resolve to ``"standard"``.
    """
    kind = getattr(cfg, "positional_embedding_type", None)
    if not kind:
        kind = "standard"
    return kind.lower()
class NativeRMSNorm(nn.Module):
    """RMSNorm modelled on HF ``LlamaRMSNorm``.

    The mean of squares and its reciprocal square root are computed in fp32
    whatever the input dtype. The normalized tensor is cast back to the input
    dtype *before* being multiplied by the learned per-channel ``weight``.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        # Learned per-channel scale, initialised to ones (identity).
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        upcast = x.float()
        mean_sq = upcast.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        # Downcast first, then scale (same op order as LlamaRMSNorm).
        return self.weight * (upcast * inv_rms).to(orig_dtype)
def _make_norm(cfg: TransformerBridgeConfig, *, force_rms: bool = False) -> nn.Module:
    """Build a ``cfg.d_model``-wide norm layer using ``cfg.eps``.

    Returns ``NativeRMSNorm`` when ``force_rms`` is set or the config's
    ``normalization_type`` is RMS-style, and ``nn.LayerNorm`` otherwise.
    """
    if not (force_rms or _uses_rms_norm(cfg)):
        return nn.LayerNorm(cfg.d_model, eps=cfg.eps)
    return NativeRMSNorm(cfg.d_model, eps=cfg.eps)
def _resolve_rope_scaling(
    cfg: TransformerBridgeConfig, rotary_dim: int
) -> tuple[float, float, torch.Tensor]:
    """Resolve ``cfg.rope_scaling`` into concrete RoPE parameters.

    Args:
        cfg: Config providing ``rotary_base`` and, optionally, a
            ``rope_scaling`` dict (HF style, keyed on ``rope_type`` or ``type``).
        rotary_dim: Number of rotated channels per head; expected to be a
            positive even integer.

    Returns:
        ``(effective_base, position_scale, inv_freq)``. ``position_scale``
        divides positions before the frequency outer product (linear
        interpolation). ``inv_freq`` holds one inverse frequency per rotated
        channel pair.

    Raises:
        ValueError: dynamic/NTK scaling with ``rotary_dim == 2``, where the NTK
            exponent ``rotary_dim / (rotary_dim - 2)`` is undefined.
        NotImplementedError: an unrecognised ``rope_scaling`` type.
    """
    base = float(cfg.rotary_base)
    # Shared exponent grid 2i / d for i in [0, d/2); reused by the NTK branch.
    exponents = torch.arange(0, rotary_dim, 2).float() / rotary_dim
    inv_freq = 1.0 / (base**exponents)

    rope_scaling = getattr(cfg, "rope_scaling", None)
    if not isinstance(rope_scaling, dict):
        return base, 1.0, inv_freq

    # Newer HF configs key on "rope_type"; older ones on "type".
    scale_type = str(rope_scaling.get("rope_type") or rope_scaling.get("type") or "").lower()
    # A missing or None factor means "no scaling" (previously float(None) crashed).
    factor = float(rope_scaling.get("factor") or 1.0)

    # factor <= 1 short-circuits to unscaled RoPE regardless of the type.
    if scale_type in ("", "default") or factor <= 1.0:
        return base, 1.0, inv_freq

    if scale_type == "linear":
        # Position interpolation: frequencies unchanged, positions divided by factor.
        return base, factor, inv_freq

    if scale_type in ("dynamic", "ntk"):
        if rotary_dim == 2:
            raise ValueError(
                f"dynamic/NTK rope_scaling requires rotary_dim > 2 (the exponent "
                f"rotary_dim / (rotary_dim - 2) is undefined), got rotary_dim={rotary_dim}."
            )
        # NTK-aware: enlarge the base instead of rescaling positions.
        scaled_base = base * (factor ** (rotary_dim / (rotary_dim - 2)))
        return scaled_base, 1.0, 1.0 / (scaled_base**exponents)

    if scale_type == "llama3":
        low_freq_factor = float(rope_scaling.get("low_freq_factor", 1.0))
        high_freq_factor = float(rope_scaling.get("high_freq_factor", 4.0))
        original_ctx = float(
            rope_scaling.get("original_max_position_embeddings")
            or rope_scaling.get("original_context_length")
            or 8192
        )
        low_wavelen = original_ctx / low_freq_factor
        high_wavelen = original_ctx / high_freq_factor
        wavelens = 2 * math.pi / inv_freq
        # Three regimes: low-freq (long wavelength) -> divide by factor;
        # high-freq -> unchanged; in-between -> linear blend of the two.
        smooth = (original_ctx / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor)
        new_inv_freq = torch.where(
            wavelens > low_wavelen,
            inv_freq / factor,
            torch.where(
                wavelens < high_wavelen,
                inv_freq,
                (1 - smooth) * inv_freq / factor + smooth * inv_freq,
            ),
        )
        return base, 1.0, new_inv_freq

    raise NotImplementedError(
        f"rope_scaling type {scale_type!r} is not supported. "
        f"Supported: 'linear', 'dynamic'/'ntk', 'llama3'."
    )
class NativeRotary(nn.Module):
    """Shared cos/sin tables for rotary position embeddings (RoPE).

    Tables are precomputed for positions ``0 .. cfg.n_ctx - 1`` and honor
    ``cfg.rope_scaling`` (via ``_resolve_rope_scaling``).
    """

    # Declared so mypy sees the buffer dtype; register_buffer alone reports Module|Tensor.
    cos_cached: torch.Tensor
    sin_cached: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig):
        """Build the cos/sin tables.

        Raises:
            ValueError: the resolved ``rotary_dim`` (``cfg.rotary_dim`` or
                ``cfg.d_head``) is not a positive even integer.
        """
        super().__init__()
        rotary_dim = cfg.rotary_dim if cfg.rotary_dim is not None else cfg.d_head
        if rotary_dim <= 0 or rotary_dim % 2 != 0:
            raise ValueError(f"rotary_dim must be a positive even integer, got {rotary_dim!r}")
        self.rotary_dim = rotary_dim

        base, position_scale, inv_freq = _resolve_rope_scaling(cfg, rotary_dim)

        positions = torch.arange(cfg.n_ctx).float() / position_scale
        freqs = torch.outer(positions, inv_freq)  # [n_ctx, rotary_dim // 2]
        # Adjacent-pair layout: channels (2i, 2i+1) share frequency i, which is
        # what ``_rotate_half`` below pairs up.
        # NOTE(review): HF's ``LlamaRotaryEmbedding`` uses the half-split layout
        # instead (channel i paired with i + rotary_dim // 2 via ``rotate_half``).
        # Confirm which layout any imported q/k weights were trained with.
        cos = freqs.cos().repeat_interleave(2, dim=-1)
        sin = freqs.sin().repeat_interleave(2, dim=-1)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)
        self.effective_base = base
        self.position_scale = position_scale

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Adjacent-pair rotation: (x0, x1) -> (-x1, x0) for every channel pair.
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        rot = torch.stack((-x2, x1), dim=-1)
        return rot.flatten(-2)

    def apply_rope(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        *,
        position_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to Q/K of shape [batch, heads, seq, d_head].

        Only the first ``rotary_dim`` channels are rotated. Any remaining
        channels pass through unchanged (partial rotary).

        Args:
            q, k: Query / key tensors, ``[batch, heads, seq, d_head]``.
            position_ids: ``None`` (positions ``0 .. seq-1``), ``[batch, seq]``,
                or ``[seq]`` (shared across the batch).

        Named ``apply_rope`` rather than ``apply`` so ``nn.Module.apply(fn)``
        — PyTorch's recursive function-application utility used by
        ``bridge.apply(init_fn)`` — isn't shadowed.
        """
        seq = q.shape[-2]
        rd = self.rotary_dim
        if position_ids is None:
            cos = self.cos_cached[:seq].to(q.dtype)
            sin = self.sin_cached[:seq].to(q.dtype)
        else:
            # A 1D [seq] vector is shared across the batch. Without this, the
            # gathered [seq, rd] table would unsqueeze to [seq, 1, rd] and
            # broadcast against the wrong axes of [batch, heads, seq, rd].
            if position_ids.dim() == 1:
                position_ids = position_ids.unsqueeze(0)
            # [batch, seq] -> [batch, 1, seq, rd] (head dim for broadcast).
            cos = self.cos_cached[position_ids].to(q.dtype).unsqueeze(1)
            sin = self.sin_cached[position_ids].to(q.dtype).unsqueeze(1)

        def _rope(x: torch.Tensor) -> torch.Tensor:
            x_rot, x_pass = x[..., :rd], x[..., rd:]
            x_rot = x_rot * cos + self._rotate_half(x_rot) * sin
            return torch.cat([x_rot, x_pass], dim=-1) if x_pass.shape[-1] else x_rot

        return _rope(q), _rope(k)
class NativeAttention(nn.Module):
    """Split-QKV causal self-attention with optional GQA, RoPE and score soft-cap.

    ``forward`` returns ``(out, pattern)``; AttentionBridge fires ``hook_pattern``
    off the second element. Query rows in which every key is masked (e.g. the
    leading pad positions of a left-padded batch) get an all-zero pattern row
    instead of NaN.
    """

    causal_mask: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        """Build Q/K/V/O projections, the causal-mask buffer and the score scale.

        Raises:
            ValueError: ``n_heads`` is not divisible by the K/V head count, or
                ``attn_scale`` is an explicit 1.0 with ``d_head > 1``.
        """
        super().__init__()
        self.cfg = cfg
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_head
        self.d_model = cfg.d_model
        # GQA: a falsy n_key_value_heads falls back to one K/V head per query head.
        self.n_kv_heads = cfg.n_key_value_heads or cfg.n_heads
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_key_value_heads "
                f"({self.n_kv_heads}) for GQA."
            )
        self.kv_repeats = self.n_heads // self.n_kv_heads

        q_dim = self.n_heads * self.d_head
        kv_dim = self.n_kv_heads * self.d_head
        self.q = nn.Linear(cfg.d_model, q_dim, bias=True)
        self.k = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.v = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.o = nn.Linear(q_dim, cfg.d_model, bias=True)

        # True strictly above the diagonal = future key, masked out.
        mask = torch.triu(torch.ones(cfg.n_ctx, cfg.n_ctx, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)

        # attn_scale=1.0 reads like "standard scaling" but is "divide by 1" —
        # i.e. unscaled scores, which saturate softmax for d_head>1.
        if cfg.use_attn_scale and cfg.attn_scale > 0:
            if self.d_head > 1 and math.isclose(cfg.attn_scale, 1.0, abs_tol=1e-9):
                raise ValueError(
                    f"attn_scale=1.0 with d_head={self.d_head} (>1) is unscaled "
                    f"attention; softmax will saturate. For standard scaling "
                    f"leave attn_scale at -1 (sentinel for sqrt(d_head))."
                )
            scale = cfg.attn_scale
        else:
            scale = math.sqrt(cfg.d_head)
        self.scale = scale
        self.rotary = rotary
        self.attn_scores_soft_cap = float(cfg.attn_scores_soft_cap)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``hidden_states`` ``[batch, seq, d_model]``.

        Returns:
            ``(out, pattern)``: ``out`` is ``[batch, seq, d_model]``; ``pattern``
            is the post-softmax attention ``[batch, n_heads, seq, seq]``.
        """
        batch, seq, _ = hidden_states.shape

        # [batch, seq, heads * d_head] -> [batch, heads, seq, d_head]
        q = self.q(hidden_states).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = self.v(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.rotary is not None:
            q, k = self.rotary.apply_rope(q, k, position_ids=position_ids)

        # GQA: repeat_interleave matches HF Llama's repeat_kv group ordering.
        if self.kv_repeats > 1:
            k = k.repeat_interleave(self.kv_repeats, dim=1)
            v = v.repeat_interleave(self.kv_repeats, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        # Gemma2 soft-cap before the causal mask so masked positions stay -inf.
        if self.attn_scores_soft_cap > 0:
            c = self.attn_scores_soft_cap
            scores = c * torch.tanh(scores / c)

        block_mask = self.causal_mask[:seq, :seq]
        if attention_mask is not None:
            block_mask = self._combine_attention_mask(block_mask, attention_mask, batch=batch)
        scores = scores.masked_fill(block_mask, float("-inf"))

        pattern = F.softmax(scores, dim=-1)
        if attention_mask is not None:
            # A row with every key masked softmaxes all -inf into NaN. If those
            # NaN outputs later reach attention as values, 0 * NaN poisons every
            # query. Zero such rows instead. (The causal mask alone always leaves
            # the diagonal open, so this only matters with an external mask.)
            fully_masked = block_mask.all(dim=-1, keepdim=True)
            pattern = pattern.masked_fill(fully_masked, 0.0)

        attn = torch.matmul(pattern, v).transpose(1, 2).contiguous().view(batch, seq, -1)
        out = self.o(attn)
        return out, pattern

    @staticmethod
    def _combine_attention_mask(
        block_mask: torch.Tensor, attention_mask: torch.Tensor, *, batch: int
    ) -> torch.Tensor:
        """Combine an external attention_mask with the causal mask.

        Accepts 2D HF padding mask ``[batch, seq]`` (1=keep, 0=mask), 4D bool
        mask (True=mask), or 4D additive float mask (HF generation style; values
        below -1 treated as masked).
        """
        if attention_mask.dim() == 2:
            pad_mask = ~attention_mask.bool()
            return block_mask | pad_mask[:, None, None, :]
        if attention_mask.dim() == 4:
            if attention_mask.dtype is torch.bool:
                return block_mask | attention_mask
            # HF additive masks use -inf or large negatives; benign biases bounded.
            extra = attention_mask < -1.0
            return block_mask | extra
        raise ValueError(
            f"attention_mask must be 2D [batch, seq] or 4D [batch, *, seq, seq], "
            f"got shape {tuple(attention_mask.shape)}."
        )
class NativeMLP(nn.Module):
    """Non-gated two-layer MLP computing ``fc_out(act(fc_in(x)))``, both with bias.

    The activation is ``_ACTIVATIONS[cfg.act_fn.lower()]``, with ``"gelu"`` used
    when ``act_fn`` is unset. Unknown names raise ``ValueError``.
    """

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        d_mlp: int = cfg.d_mlp
        self.fc_in = nn.Linear(cfg.d_model, d_mlp, bias=True)
        self.fc_out = nn.Linear(d_mlp, cfg.d_model, bias=True)

        act_name = (cfg.act_fn or "gelu").lower()
        activation = _ACTIVATIONS.get(act_name)
        if activation is None:
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = activation

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        pre_act = self.fc_in(hidden_states)
        return self.fc_out(self.act(pre_act))
class NativeGatedMLP(nn.Module):
    """Gated MLP computing ``out(act(gate(x)) * in(x))``.

    ``cfg.act_fn`` picks the variant (SwiGLU / ReGLU / GeGLU), with SiLU, i.e.
    SwiGLU, used when it is unset. Submodules ``gate`` / ``in`` / ``out`` match
    GatedMLPBridge's expected slots.
    """

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        d_mlp: int = cfg.d_mlp
        d_model = cfg.d_model
        # Llama convention: gated-MLP projections carry no bias.
        self.gate = nn.Linear(d_model, d_mlp, bias=False)
        # ``in`` is a Python keyword, so it is registered and fetched by string
        # name; the bridge resolves LinearBridge(name="in") the same way.
        self.add_module("in", nn.Linear(d_model, d_mlp, bias=False))
        self.out = nn.Linear(d_mlp, d_model, bias=False)

        # Same dispatch as NativeMLP: an unknown act_fn raises rather than
        # silently swapping in a different activation.
        act_name = (cfg.act_fn or "silu").lower()
        if act_name not in _ACTIVATIONS:
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = _ACTIVATIONS[act_name]

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        up_proj = cast(nn.Linear, getattr(self, "in"))
        gated = self.act(self.gate(hidden_states))
        return self.out(gated * up_proj(hidden_states))
class NativeBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln1(x))``, followed (unless
    ``cfg.attn_only``) by ``+ mlp(ln2(.))``. The MLP is gated when
    ``cfg.gated_mlp``."""

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        super().__init__()
        self.cfg = cfg
        self.ln1 = _make_norm(cfg)
        self.attn = NativeAttention(cfg, rotary=rotary)
        if cfg.attn_only:
            # Attention-only blocks register neither ln2 nor mlp.
            return
        self.ln2 = _make_norm(cfg)
        self.mlp = NativeGatedMLP(cfg) if cfg.gated_mlp else NativeMLP(cfg)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        normed = self.ln1(hidden_states)
        attn_out, _ = self.attn(normed, attention_mask=attention_mask, position_ids=position_ids)
        resid_mid = hidden_states + attn_out
        # Tuple return matches HF block convention; BlockBridge's parser expects it.
        if self.cfg.attn_only:
            return (resid_mid,)
        resid_post = resid_mid + self.mlp(self.ln2(resid_mid))
        return (resid_post,)
class NativeModel(nn.Module):
    """TL-native transformer. See module docstring for the supported feature set.

    Layout:
        - token embedding;
        - learned absolute position embedding when ``positional_embedding_type``
          is "standard", or one ``NativeRotary`` shared by all blocks when it is
          "rotary";
        - ``cfg.n_layers`` × ``NativeBlock``;
        - final norm;
        - bias-free unembedding, with an optional tanh soft-cap on the logits.
    """

    pos: Optional[nn.Embedding]
    rotary: Optional[NativeRotary]

    def __init__(self, cfg: TransformerBridgeConfig):
        """Build all submodules from ``cfg``.

        Raises:
            ValueError: unsupported ``positional_embedding_type``.
        """
        super().__init__()
        # Write the resolved d_mlp back so downstream consumers see the real
        # value, not None. Mutates cfg; isolating callers should deep-copy first.
        if not getattr(cfg, "d_mlp", None):
            cfg.d_mlp = 4 * cfg.d_model
        self.cfg = cfg

        self.tok_embed = nn.Embedding(cfg.d_vocab, cfg.d_model)

        kind = _positional_kind(cfg)
        if kind == "standard":
            self.pos = nn.Embedding(cfg.n_ctx, cfg.d_model)
            self.rotary = None
        elif kind == "rotary":
            self.pos = None
            self.rotary = NativeRotary(cfg)
        else:
            raise ValueError(
                f"Unsupported positional_embedding_type={kind!r}. "
                f"NativeModel supports 'standard' and 'rotary'."
            )

        self.layers = nn.ModuleList(
            [NativeBlock(cfg, rotary=self.rotary) for _ in range(cfg.n_layers)]
        )
        # final_rms forces RMS on the final norm regardless of block-norm choice
        # — matches the TL config semantic Llama uses.
        self.ln_out = _make_norm(cfg, force_rms=cfg.final_rms)
        d_vocab_out = cfg.d_vocab_out if cfg.d_vocab_out > 0 else cfg.d_vocab
        self.head = nn.Linear(cfg.d_model, d_vocab_out, bias=False)
        self.output_logits_soft_cap = float(cfg.output_logits_soft_cap)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Returns logits directly: ``[batch, seq, d_vocab_out]``.

        Args:
            input_ids: Token ids, ``[batch, seq]``.
            attention_mask: Optional mask forwarded unchanged to every block.
            position_ids: Optional positions, ``[batch, seq]`` or ``[seq]``
                (shared across the batch). Defaults to ``0 .. seq-1``.

        Raises:
            ValueError: ``input_ids`` is not 2D, the sequence is longer than
                ``cfg.n_ctx``, or ``position_ids`` fall outside ``[0, n_ctx)``.
        """
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must be 2D [batch, seq], got shape {tuple(input_ids.shape)}."
            )
        # Bounds checks up front so both absolute and rotary paths produce a
        # self-explanatory error rather than IndexError / shape mismatch (or a
        # CUDA device-side assert from an out-of-range embedding gather).
        batch, seq = input_ids.shape
        if seq > self.cfg.n_ctx:
            raise ValueError(
                f"input length {seq} exceeds n_ctx={self.cfg.n_ctx}; "
                f"position embeddings and rotary tables are pre-baked at n_ctx."
            )

        # Resolve position_ids before the block loop so rotary sees the caller's
        # positions, not the dense default.
        if position_ids is None:
            position_ids = torch.arange(seq, device=input_ids.device).unsqueeze(0).expand(batch, -1)
        else:
            if position_ids.dim() == 1:
                # Shared [seq] vector -> [batch, seq], so the rotary gather
                # produces [batch, 1, seq, rd] rather than a mis-broadcast table.
                position_ids = position_ids.unsqueeze(0).expand(batch, -1)
            if position_ids.numel() > 0:
                lo, hi = int(position_ids.min()), int(position_ids.max())
                if lo < 0 or hi >= self.cfg.n_ctx:
                    raise ValueError(
                        f"position_ids must lie in [0, n_ctx={self.cfg.n_ctx}), got "
                        f"range [{lo}, {hi}]; position embeddings and rotary tables "
                        f"are pre-baked at n_ctx."
                    )

        hidden_states = self.tok_embed(input_ids)
        if self.pos is not None:
            hidden_states = hidden_states + self.pos(position_ids)

        for block in self.layers:
            (hidden_states,) = block(
                hidden_states, attention_mask=attention_mask, position_ids=position_ids
            )
        hidden_states = self.ln_out(hidden_states)
        logits = self.head(hidden_states)
        if self.output_logits_soft_cap > 0:
            c = self.output_logits_soft_cap
            logits = c * torch.tanh(logits / c)
        return logits