Coverage for transformer_lens/model_bridge/sources/native/model.py: 94%
253 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""TL-native transformer for TransformerBridge — minimal, no HF/HT dependency.
3Cfg-driven features: ``normalization_type`` (LN / RMS / RMSPre), ``final_rms``,
4``gated_mlp``, ``attn_only``, ``n_key_value_heads`` (GQA), ``attn_scores_soft_cap``,
5``output_logits_soft_cap``, ``positional_embedding_type`` (standard / rotary),
6``rotary_dim`` / ``rotary_base`` / ``rope_scaling`` (linear PI, dynamic/NTK,
7llama3 by-parts).
8"""
10from __future__ import annotations
12import math
13from typing import Callable, Optional, cast
15import torch
16import torch.nn as nn
17import torch.nn.functional as F
19from transformer_lens.config import TransformerBridgeConfig
20from transformer_lens.utilities import TypedModuleList
# Activation lookup keyed by lower-cased ``cfg.act_fn``.
# "gelu_new" is the tanh approximation of GELU, expressed here through
# F.gelu(approximate="tanh"); "swish" is an alias for the same callable as "silu".
_Activation = Callable[[torch.Tensor], torch.Tensor]
_ACTIVATIONS: dict[str, _Activation] = dict(
    gelu=F.gelu,
    gelu_new=lambda x: F.gelu(x, approximate="tanh"),
    relu=F.relu,
    silu=F.silu,
    swish=F.silu,
)
def _normalization_type(cfg: TransformerBridgeConfig) -> str | None:
    """Return ``cfg.normalization_type`` upper-cased, or ``None`` when it is unset."""
    raw = cfg.normalization_type
    if raw is None:
        return None
    return raw.upper()
def _uses_rms_norm(cfg: TransformerBridgeConfig) -> bool:
    """True when the (case-insensitive) normalization type is RMS or RMSPre."""
    kind = _normalization_type(cfg)
    return kind in ("RMS", "RMSPRE")
def _uses_no_norm(cfg: TransformerBridgeConfig) -> bool:
    """True when the config declares no normalization type at all."""
    kind = _normalization_type(cfg)
    return kind is None
def _positional_kind(cfg: TransformerBridgeConfig) -> str:
    """Return the lower-cased positional embedding type.

    A missing, ``None`` or empty ``positional_embedding_type`` falls back to
    ``"standard"``.
    """
    kind = getattr(cfg, "positional_embedding_type", None)
    if not kind:
        kind = "standard"
    return kind.lower()
class NativeRMSNorm(nn.Module):
    """RMS normalization with a learned per-channel scale.

    The mean square is always computed in float32; the normalized activations
    are cast back to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        # Scale starts at one, so a fresh module is a pure normalization.
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the learned scale."""
        orig_dtype = x.dtype
        upcast = x.to(torch.float32)
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        scaled = upcast * torch.rsqrt(mean_square + self.eps)
        return self.weight * scaled.to(orig_dtype)
def _make_norm(cfg: TransformerBridgeConfig, *, force_rms: bool = False) -> nn.Module:
    """Build the normalization layer selected by ``cfg``.

    ``force_rms`` overrides the config and always yields a ``NativeRMSNorm``.
    Otherwise: RMS-type configs get ``NativeRMSNorm``, configs with no
    normalization type get ``nn.Identity``, and everything else gets
    ``nn.LayerNorm``.
    """
    if not force_rms and not _uses_rms_norm(cfg):
        if _uses_no_norm(cfg):
            return nn.Identity()
        return nn.LayerNorm(cfg.d_model, eps=cfg.eps)
    return NativeRMSNorm(cfg.d_model, eps=cfg.eps)
def _uses_causal_attention(cfg: TransformerBridgeConfig) -> bool:
    """True only when ``cfg.attention_dir`` is exactly ``"causal"``."""
    direction = cfg.attention_dir
    return direction == "causal"
def _resolve_rope_scaling(
    cfg: TransformerBridgeConfig, rotary_dim: int
) -> tuple[float, float, torch.Tensor]:
    """Resolve ``cfg.rope_scaling`` into the parameters needed to build RoPE tables.

    Args:
        cfg: Config providing ``rotary_base`` and, optionally, a ``rope_scaling``
            dict. The scaling kind is read from ``"rope_type"`` (falling back to
            ``"type"``) and the strength from ``"factor"``.
        rotary_dim: Number of rotated channels; one inverse frequency is produced
            per channel pair.

    Returns:
        ``(effective_base, position_scale, inv_freq)``. ``position_scale`` is the
        divisor to apply to positions (only ``"linear"`` sets it above 1);
        ``inv_freq`` has ``len(range(0, rotary_dim, 2))`` entries.

    Raises:
        NotImplementedError: If the scaling type is not one of ``linear``,
            ``dynamic`` / ``ntk`` or ``llama3``.
    """
    base = float(cfg.rotary_base)
    rope_scaling = getattr(cfg, "rope_scaling", None)
    # Per-pair exponents i / rotary_dim for i = 0, 2, 4, ...; shared by the
    # unscaled and the NTK-scaled frequency tables.
    exponents = torch.arange(0, rotary_dim, 2).float() / rotary_dim
    inv_freq = 1.0 / (base**exponents)

    if not isinstance(rope_scaling, dict):
        return base, 1.0, inv_freq

    # Newer HF configs key on "rope_type"; older ones on "type".
    scale_type = str(rope_scaling.get("rope_type") or rope_scaling.get("type") or "").lower()
    factor = float(rope_scaling.get("factor", 1.0))

    # A missing/default type or a factor that does not stretch the context
    # leaves the tables untouched, whatever the declared type is.
    if scale_type in ("", "default") or factor <= 1.0:
        return base, 1.0, inv_freq

    if scale_type == "linear":
        # Position interpolation: frequencies unchanged, positions divided by factor.
        return base, factor, inv_freq

    if scale_type in ("dynamic", "ntk"):
        if rotary_dim == 2:
            # The NTK exponent rotary_dim / (rotary_dim - 2) is undefined here
            # (division by zero). With a single channel pair the only exponent
            # is 0, so the frequency does not depend on the base and scaling
            # is a no-op: return the unscaled values.
            return base, 1.0, inv_freq
        scaled_base = base * (factor ** (rotary_dim / (rotary_dim - 2)))
        new_inv_freq = 1.0 / (scaled_base**exponents)
        return scaled_base, 1.0, new_inv_freq

    if scale_type == "llama3":
        low_freq_factor = float(rope_scaling.get("low_freq_factor", 1.0))
        high_freq_factor = float(rope_scaling.get("high_freq_factor", 4.0))
        original_ctx = float(
            rope_scaling.get("original_max_position_embeddings")
            or rope_scaling.get("original_context_length")
            or 8192
        )
        low_wavelen = original_ctx / low_freq_factor
        high_wavelen = original_ctx / high_freq_factor
        wavelens = 2 * math.pi / inv_freq
        # Three regimes: low-freq → divide by factor; high-freq → unchanged;
        # in-between → smooth linear interpolation between the two.
        smooth = (original_ctx / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor)
        new_inv_freq = torch.where(
            wavelens > low_wavelen,
            inv_freq / factor,
            torch.where(
                wavelens < high_wavelen,
                inv_freq,
                (1 - smooth) * inv_freq / factor + smooth * inv_freq,
            ),
        )
        return base, 1.0, new_inv_freq

    raise NotImplementedError(
        f"rope_scaling type {scale_type!r} is not supported. "
        f"Supported: 'linear', 'dynamic'/'ntk', 'llama3'."
    )
class NativeRotary(nn.Module):
    """Precomputed RoPE cos/sin lookup tables, built once per model.

    Tables cover positions ``0 .. cfg.n_ctx - 1`` and follow ``cfg.rope_scaling``.
    """

    # Annotated for mypy: register_buffer on its own is typed as Module | Tensor.
    cos_cached: torch.Tensor
    sin_cached: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        # rotary_dim defaults to the full head width when the config leaves it unset.
        rotary_dim = cfg.d_head if cfg.rotary_dim is None else cfg.rotary_dim
        if not (rotary_dim > 0 and rotary_dim % 2 == 0):
            raise ValueError(f"rotary_dim must be a positive even integer, got {rotary_dim!r}")
        self.rotary_dim = rotary_dim

        effective_base, pos_scale, inv_freq = _resolve_rope_scaling(cfg, rotary_dim)

        scaled_positions = torch.arange(cfg.n_ctx).float() / pos_scale
        angles = torch.outer(scaled_positions, inv_freq)
        # Adjacent-pair layout: every angle is duplicated so channels
        # (2i, 2i+1) share one frequency and rotate together.
        for buffer_name, table in (("cos_cached", angles.cos()), ("sin_cached", angles.sin())):
            self.register_buffer(buffer_name, table.repeat_interleave(2, dim=-1), persistent=False)
        self.effective_base = effective_base
        self.position_scale = pos_scale

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Map each adjacent channel pair ``(x0, x1)`` to ``(-x1, x0)``."""
        even_channels = x[..., 0::2]
        odd_channels = x[..., 1::2]
        pairs = torch.stack((-odd_channels, even_channels), dim=-1)
        return pairs.flatten(-2)

    def apply_rope(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        *,
        position_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate the first ``rotary_dim`` channels of Q and K.

        Q/K are laid out as [batch, heads, seq, d_head]. Without ``position_ids``
        the first ``seq`` table rows are used; with them, rows are gathered per
        token. Channels past ``rotary_dim`` pass through untouched.

        The method is deliberately not called ``apply``: that name belongs to
        ``nn.Module.apply(fn)``, which must stay usable on this module.
        """
        rd = self.rotary_dim
        if position_ids is not None:
            # [batch, seq] -> [batch, 1, seq, rd]; the singleton broadcasts over heads.
            cos = self.cos_cached[position_ids].to(q.dtype).unsqueeze(1)
            sin = self.sin_cached[position_ids].to(q.dtype).unsqueeze(1)
        else:
            n_positions = q.shape[-2]
            cos = self.cos_cached[:n_positions].to(q.dtype)
            sin = self.sin_cached[:n_positions].to(q.dtype)

        def _rope(x: torch.Tensor) -> torch.Tensor:
            rotated = x[..., :rd]
            untouched = x[..., rd:]
            rotated = rotated * cos + self._rotate_half(rotated) * sin
            if untouched.shape[-1]:
                return torch.cat([rotated, untouched], dim=-1)
            return rotated

        return _rope(q), _rope(k)
class NativeAttention(nn.Module):
    """Split-QKV self-attention with optional RoPE, GQA and score soft-capping.

    ``forward`` returns ``(out, pattern)``; AttentionBridge fires
    ``hook_pattern`` off the second element. Masking is causal when
    ``cfg.attention_dir == "causal"`` and is combined with any caller-supplied
    ``attention_mask``.
    """

    causal_mask: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        """Build the Q/K/V/O projections and the causal mask buffer.

        Args:
            cfg: Model config. Reads ``n_heads``, ``d_head``, ``d_model``,
                ``n_key_value_heads``, ``n_ctx``, ``use_attn_scale``,
                ``attn_scale``, ``attn_scores_soft_cap`` and ``attention_dir``.
            rotary: RoPE tables to apply to Q/K, or ``None`` for no rotation.

        Raises:
            ValueError: If ``n_heads`` is not a multiple of the KV head count,
                or if an explicit ``attn_scale`` of 1.0 is requested with
                ``d_head > 1``.
        """
        super().__init__()
        self.cfg = cfg
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_head
        self.d_model = cfg.d_model
        # A falsy n_key_value_heads (None / 0) means plain multi-head attention.
        self.n_kv_heads = cfg.n_key_value_heads or cfg.n_heads
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_key_value_heads "
                f"({self.n_kv_heads}) for GQA."
            )
        self.kv_repeats = self.n_heads // self.n_kv_heads

        q_dim = self.n_heads * self.d_head
        kv_dim = self.n_kv_heads * self.d_head
        self.q = nn.Linear(cfg.d_model, q_dim, bias=True)
        self.k = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.v = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.o = nn.Linear(q_dim, cfg.d_model, bias=True)

        # True above the diagonal = "query i may not see key j > i".
        mask = torch.triu(torch.ones(cfg.n_ctx, cfg.n_ctx, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)

        # attn_scale=1.0 reads like "standard scaling" but is "divide by 1" —
        # i.e. unscaled scores, which saturate softmax for d_head>1.
        if cfg.use_attn_scale and cfg.attn_scale > 0:
            if self.d_head > 1 and math.isclose(cfg.attn_scale, 1.0, abs_tol=1e-9):
                raise ValueError(
                    f"attn_scale=1.0 with d_head={self.d_head} (>1) is unscaled "
                    f"attention; softmax will saturate. For standard scaling "
                    f"leave attn_scale at -1 (sentinel for sqrt(d_head))."
                )
            scale = cfg.attn_scale
        else:
            # Either scaling is disabled in cfg or attn_scale is non-positive:
            # both fall back to the sqrt(d_head) divisor.
            scale = math.sqrt(cfg.d_head)
        self.scale = scale
        self.rotary = rotary
        self.attn_scores_soft_cap = float(cfg.attn_scores_soft_cap)
        self.causal = _uses_causal_attention(cfg)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention over ``hidden_states`` of shape [batch, seq, d_model].

        Args:
            hidden_states: Input activations.
            attention_mask: Optional extra mask; accepted layouts are described
                on ``_combine_attention_mask``.
            position_ids: Optional [batch, seq] positions forwarded to RoPE.
            **kwargs: Accepted and ignored.

        Returns:
            ``(out, pattern)`` where ``out`` is [batch, seq, d_model] and
            ``pattern`` is the post-softmax attention [batch, heads, seq, seq].
            Query rows whose keys are all masked get an all-zero pattern.
        """
        batch, seq, _ = hidden_states.shape

        q = self.q(hidden_states).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = self.v(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.rotary is not None:
            q, k = self.rotary.apply_rope(q, k, position_ids=position_ids)

        # GQA: repeat_interleave matches HF Llama's repeat_kv group ordering.
        if self.kv_repeats > 1:
            k = k.repeat_interleave(self.kv_repeats, dim=1)
            v = v.repeat_interleave(self.kv_repeats, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        # Gemma2 soft-cap before the causal mask so masked positions stay -inf.
        if self.attn_scores_soft_cap > 0:
            c = self.attn_scores_soft_cap
            scores = c * torch.tanh(scores / c)

        if self.causal:
            block_mask = self.causal_mask[:seq, :seq]
        else:
            block_mask = torch.zeros(seq, seq, dtype=torch.bool, device=scores.device)

        # Only an external mask can blank out an entire query row (e.g. the
        # leading pad tokens of a left-padded batch): the causal mask always
        # keeps the diagonal and the non-causal mask keeps everything.
        fully_masked: Optional[torch.Tensor] = None
        if attention_mask is not None:
            block_mask = self._combine_attention_mask(block_mask, attention_mask, batch=batch)
            fully_masked = block_mask.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(block_mask, float("-inf"))
        if fully_masked is not None:
            # softmax over an all -inf row is NaN, and that NaN would leak into
            # every other position in later layers via 0 * NaN in the value
            # matmul. Give such rows finite scores here...
            scores = scores.masked_fill(fully_masked, 0.0)

        pattern = F.softmax(scores, dim=-1)
        if fully_masked is not None:
            # ...and make them attend to nothing.
            pattern = pattern.masked_fill(fully_masked, 0.0)

        attn = torch.matmul(pattern, v).transpose(1, 2).contiguous().view(batch, seq, -1)
        out = self.o(attn)
        return out, pattern

    @staticmethod
    def _combine_attention_mask(
        block_mask: torch.Tensor, attention_mask: torch.Tensor, *, batch: int
    ) -> torch.Tensor:
        """Combine an external attention_mask with the causal mask.

        Accepts 2D HF padding mask ``[batch, seq]`` (1=keep, 0=mask), 4D bool
        mask (True=mask), or 4D additive float mask (HF generation style; values
        below -1 treated as masked).

        Args:
            block_mask: Boolean mask built so far (True = masked).
            attention_mask: External mask in one of the layouts above.
            batch: Batch size; currently unused, kept for the call signature.

        Returns:
            Boolean mask (True = masked) broadcastable against the scores.

        Raises:
            ValueError: If ``attention_mask`` is neither 2D nor 4D.
        """
        if attention_mask.dim() == 2:
            pad_mask = ~attention_mask.bool()
            # Broadcast over heads and query positions: a padded key is hidden
            # from every query.
            return block_mask | pad_mask[:, None, None, :]
        if attention_mask.dim() == 4:
            if attention_mask.dtype is torch.bool:
                return block_mask | attention_mask
            # HF additive masks use -inf or large negatives; benign biases bounded.
            extra = attention_mask < -1.0
            return block_mask | extra
        raise ValueError(
            f"attention_mask must be 2D [batch, seq] or 4D [batch, *, seq, seq], "
            f"got shape {tuple(attention_mask.shape)}."
        )
class NativeMLP(nn.Module):
    """Plain feed-forward block: linear up-projection, activation, linear down-projection."""

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        hidden_width: int = cfg.d_mlp
        self.fc_in = nn.Linear(cfg.d_model, hidden_width, bias=True)
        self.fc_out = nn.Linear(hidden_width, cfg.d_model, bias=True)
        # An unset act_fn defaults to "gelu"; unknown names are rejected.
        act_name = (cfg.act_fn or "gelu").lower()
        activation = _ACTIVATIONS.get(act_name)
        if activation is None:
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = activation

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply ``fc_out(act(fc_in(hidden_states)))``; extra kwargs are ignored."""
        projected = self.fc_in(hidden_states)
        activated = self.act(projected)
        return self.fc_out(activated)
class NativeGatedMLP(nn.Module):
    """Gated feed-forward block; ``cfg.act_fn`` picks the gate activation
    (SwiGLU / ReGLU / GeGLU).

    The three projections live under the submodule names ``gate`` / ``in`` /
    ``out``, which are the slots GatedMLPBridge looks up.
    """

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        hidden_width: int = cfg.d_mlp
        # All three projections are bias-free (Llama convention).
        self.gate = nn.Linear(cfg.d_model, hidden_width, bias=False)
        # "in" is a reserved word, so it cannot be set with attribute syntax;
        # add_module registers it and getattr(self, "in") reads it back, the
        # same way the bridge resolves LinearBridge(name="in").
        self.add_module("in", nn.Linear(cfg.d_model, hidden_width, bias=False))
        self.out = nn.Linear(hidden_width, cfg.d_model, bias=False)
        # Unset act_fn means SiLU (SwiGLU). Unknown names raise, as in
        # NativeMLP, rather than silently altering the model.
        act_name = (cfg.act_fn or "silu").lower()
        activation = _ACTIVATIONS.get(act_name)
        if activation is None:
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = activation

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return ``out(act(gate(x)) * in(x))``; extra kwargs are ignored."""
        gated = self.act(self.gate(hidden_states))
        up_projection = cast(nn.Linear, getattr(self, "in"))
        return self.out(gated * up_projection(hidden_states))
class NativeBlock(nn.Module):
    """Pre-norm residual block: attention, then (unless ``cfg.attn_only``) an MLP.

    ``cfg.gated_mlp`` selects ``NativeGatedMLP`` over ``NativeMLP``.
    """

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        super().__init__()
        self.cfg = cfg
        self.ln1 = _make_norm(cfg)
        self.attn = NativeAttention(cfg, rotary=rotary)
        if cfg.attn_only:
            # Attention-only blocks carry no ln2 / mlp submodules at all.
            return
        self.ln2 = _make_norm(cfg)
        if cfg.gated_mlp:
            self.mlp = NativeGatedMLP(cfg)
        else:
            self.mlp = NativeMLP(cfg)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """Apply the block and return the new residual stream as a 1-tuple.

        The tuple wrapper follows the HF block convention that BlockBridge's
        parser expects. The attention pattern is computed but discarded here;
        extra kwargs are ignored.
        """
        normed = self.ln1(hidden_states)
        attn_out, _unused_pattern = self.attn(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        residual = hidden_states + attn_out
        if self.cfg.attn_only:
            return (residual,)
        residual = residual + self.mlp(self.ln2(residual))
        return (residual,)
class NativeModel(nn.Module):
    """TL-native transformer: token embedding, optional learned positions,
    a stack of ``NativeBlock`` layers, a final norm and an unembedding head.

    The module docstring lists the cfg-driven features.
    """

    pos: Optional[nn.Embedding]
    rotary: Optional[NativeRotary]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        # A missing / falsy d_mlp is resolved to 4 * d_model and written back
        # onto cfg so later readers see a concrete value. This mutates the
        # caller's cfg; deep-copy beforehand if that matters.
        if not getattr(cfg, "d_mlp", None):
            cfg.d_mlp = 4 * cfg.d_model
        self.cfg = cfg

        self.tok_embed = nn.Embedding(cfg.d_vocab, cfg.d_model)

        # Exactly one of ``pos`` (learned absolute positions) and ``rotary``
        # (shared RoPE tables) is populated; the other stays None.
        kind = _positional_kind(cfg)
        if kind not in ("standard", "rotary"):
            raise ValueError(
                f"Unsupported positional_embedding_type={kind!r}. "
                f"NativeModel supports 'standard' and 'rotary'."
            )
        if kind == "standard":
            self.pos = nn.Embedding(cfg.n_ctx, cfg.d_model)
            self.rotary = None
        else:
            self.pos = None
            self.rotary = NativeRotary(cfg)

        blocks = [NativeBlock(cfg, rotary=self.rotary) for _ in range(cfg.n_layers)]
        self.layers = TypedModuleList(blocks)
        # final_rms makes the last norm RMS even when the blocks use another
        # kind — the TL config semantic Llama relies on.
        self.ln_out = _make_norm(cfg, force_rms=cfg.final_rms)
        # A non-positive d_vocab_out means "same as the input vocabulary".
        if cfg.d_vocab_out > 0:
            d_vocab_out = cfg.d_vocab_out
        else:
            d_vocab_out = cfg.d_vocab
        self.head = nn.Linear(cfg.d_model, d_vocab_out, bias=False)
        self.output_logits_soft_cap = float(cfg.output_logits_soft_cap)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Map [batch, seq] token ids to logits; extra kwargs are ignored.

        Raises:
            ValueError: If the sequence is longer than ``cfg.n_ctx``.
        """
        # Length is validated first so the absolute and rotary paths both fail
        # with a readable message instead of an IndexError / shape mismatch.
        n_tokens = input_ids.shape[-1]
        if n_tokens > self.cfg.n_ctx:
            raise ValueError(
                f"input length {n_tokens} exceeds n_ctx={self.cfg.n_ctx}; "
                f"position embeddings and rotary tables are pre-baked at n_ctx."
            )

        # Positions are settled before the layer loop so every block (and
        # therefore rotary) receives the caller's positions when given, and
        # dense 0..seq-1 positions otherwise.
        batch, seq = input_ids.shape
        if position_ids is None:
            dense = torch.arange(seq, device=input_ids.device)
            position_ids = dense.unsqueeze(0).expand(batch, -1)

        residual = self.tok_embed(input_ids)
        if self.pos is not None:
            residual = residual + self.pos(position_ids)

        for layer in self.layers:
            (residual,) = layer(residual, attention_mask=attention_mask, position_ids=position_ids)

        logits = self.head(self.ln_out(residual))
        cap = self.output_logits_soft_cap
        if cap > 0:
            # Squash logits smoothly into (-cap, cap).
            logits = cap * torch.tanh(logits / cap)
        return logits