Coverage for transformer_lens/model_bridge/sources/native/model.py: 94%

253 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""TL-native transformer for TransformerBridge — minimal, no HF/HT dependency. 

2 

3Cfg-driven features: ``normalization_type`` (LN / RMS / RMSPre), ``final_rms``, 

4``gated_mlp``, ``attn_only``, ``n_key_value_heads`` (GQA), ``attn_scores_soft_cap``, 

5``output_logits_soft_cap``, ``positional_embedding_type`` (standard / rotary), 

6``rotary_dim`` / ``rotary_base`` / ``rope_scaling`` (linear PI, dynamic/NTK, 

7llama3 by-parts). 

8""" 

9 

10from __future__ import annotations 

11 

12import math 

13from typing import Callable, Optional, cast 

14 

15import torch 

16import torch.nn as nn 

17import torch.nn.functional as F 

18 

19from transformer_lens.config import TransformerBridgeConfig 

20from transformer_lens.utilities import TypedModuleList 

21 

22# gelu_new = the tanh-approximation HF GPT-2 / HT use; F.gelu(approximate="tanh") 

23# is the exact same formula. 

24_Activation = Callable[[torch.Tensor], torch.Tensor] 

25_ACTIVATIONS: dict[str, _Activation] = { 

26 "gelu": F.gelu, 

27 "gelu_new": lambda x: F.gelu(x, approximate="tanh"), 

28 "relu": F.relu, 

29 "silu": F.silu, 

30 "swish": F.silu, 

31} 

32 

33 

34def _normalization_type(cfg: TransformerBridgeConfig) -> str | None: 

35 normalization_type = cfg.normalization_type 

36 return None if normalization_type is None else normalization_type.upper() 

37 

38 

def _uses_rms_norm(cfg: TransformerBridgeConfig) -> bool:
    """Report whether the configured normalization is RMS or RMSPre (any casing)."""
    kind = _normalization_type(cfg)
    return kind == "RMS" or kind == "RMSPRE"

41 

42 

def _uses_no_norm(cfg: TransformerBridgeConfig) -> bool:
    """Report whether the config leaves ``normalization_type`` unset (``None``)."""
    kind = _normalization_type(cfg)
    return kind is None

45 

46 

47def _positional_kind(cfg: TransformerBridgeConfig) -> str: 

48 return (getattr(cfg, "positional_embedding_type", None) or "standard").lower() 

49 

50 

class NativeRMSNorm(nn.Module):
    """Llama-style RMSNorm.

    The mean square is always computed in fp32, whatever the input dtype; the
    normalized activations are cast back to the input dtype before the
    per-channel scale is applied (matches HF LlamaRMSNorm).
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the learned scale."""
        orig_dtype = x.dtype
        upcast = x.to(torch.float32)
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        scaled = upcast * torch.rsqrt(mean_square + self.eps)
        return self.weight * scaled.to(orig_dtype)

66 

67 

def _make_norm(cfg: TransformerBridgeConfig, *, force_rms: bool = False) -> nn.Module:
    """Build the normalization layer selected by ``cfg``.

    ``force_rms`` overrides the config and always yields a ``NativeRMSNorm``.
    Otherwise an RMS-type config gives ``NativeRMSNorm``, an unset
    normalization type gives ``nn.Identity``, and anything else gives
    ``nn.LayerNorm``.
    """
    wants_rms = force_rms or _uses_rms_norm(cfg)
    norm: nn.Module
    if wants_rms:
        norm = NativeRMSNorm(cfg.d_model, eps=cfg.eps)
    elif _uses_no_norm(cfg):
        norm = nn.Identity()
    else:
        norm = nn.LayerNorm(cfg.d_model, eps=cfg.eps)
    return norm

74 

75 

76def _uses_causal_attention(cfg: TransformerBridgeConfig) -> bool: 

77 return cfg.attention_dir == "causal" 

78 

79 

80def _resolve_rope_scaling( 

81 cfg: TransformerBridgeConfig, rotary_dim: int 

82) -> tuple[float, float, torch.Tensor]: 

83 """Returns (effective_base, position_scale, inv_freq) per cfg.rope_scaling.""" 

84 base = float(cfg.rotary_base) 

85 rope_scaling = getattr(cfg, "rope_scaling", None) 

86 inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) 

87 

88 if not isinstance(rope_scaling, dict): 

89 return base, 1.0, inv_freq 

90 

91 # Newer HF configs key on "rope_type"; older ones on "type". 

92 scale_type = str(rope_scaling.get("rope_type") or rope_scaling.get("type") or "").lower() 

93 factor = float(rope_scaling.get("factor", 1.0)) 

94 

95 if scale_type in ("", "default") or factor <= 1.0: 

96 return base, 1.0, inv_freq 

97 

98 if scale_type == "linear": 

99 return base, factor, inv_freq 

100 

101 if scale_type in ("dynamic", "ntk"): 

102 scaled_base = base * (factor ** (rotary_dim / (rotary_dim - 2))) 

103 new_inv_freq = 1.0 / (scaled_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) 

104 return scaled_base, 1.0, new_inv_freq 

105 

106 if scale_type == "llama3": 

107 low_freq_factor = float(rope_scaling.get("low_freq_factor", 1.0)) 

108 high_freq_factor = float(rope_scaling.get("high_freq_factor", 4.0)) 

109 original_ctx = float( 

110 rope_scaling.get("original_max_position_embeddings") 

111 or rope_scaling.get("original_context_length") 

112 or 8192 

113 ) 

114 low_wavelen = original_ctx / low_freq_factor 

115 high_wavelen = original_ctx / high_freq_factor 

116 wavelens = 2 * math.pi / inv_freq 

117 # Three regimes: low-freq → divide by factor; high-freq → unchanged; 

118 # in-between → smooth linear interpolation between the two. 

119 smooth = (original_ctx / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor) 

120 new_inv_freq = torch.where( 

121 wavelens > low_wavelen, 

122 inv_freq / factor, 

123 torch.where( 

124 wavelens < high_wavelen, 

125 inv_freq, 

126 (1 - smooth) * inv_freq / factor + smooth * inv_freq, 

127 ), 

128 ) 

129 return base, 1.0, new_inv_freq 

130 

131 raise NotImplementedError( 

132 f"rope_scaling type {scale_type!r} is not supported. " 

133 f"Supported: 'linear', 'dynamic'/'ntk', 'llama3'." 

134 ) 

135 

136 

class NativeRotary(nn.Module):
    """Precomputed RoPE cos/sin tables, built once and shared. Honors ``cfg.rope_scaling``."""

    # Annotated so mypy knows these are tensors; register_buffer on its own is
    # reported as Module|Tensor.
    cos_cached: torch.Tensor
    sin_cached: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        # Fall back to the full head width when no explicit rotary_dim is set.
        rotary_dim = cfg.d_head if cfg.rotary_dim is None else cfg.rotary_dim
        if rotary_dim <= 0 or rotary_dim % 2 != 0:
            raise ValueError(f"rotary_dim must be a positive even integer, got {rotary_dim!r}")
        self.rotary_dim = rotary_dim

        base, position_scale, inv_freq = _resolve_rope_scaling(cfg, rotary_dim)

        scaled_positions = torch.arange(cfg.n_ctx).float() / position_scale
        angles = torch.outer(scaled_positions, inv_freq)
        # Adjacent-pair layout: channels (2i, 2i+1) share one angle, so every
        # angle column is duplicated next to itself.
        cos_table = angles.cos().repeat_interleave(2, dim=-1)
        sin_table = angles.sin().repeat_interleave(2, dim=-1)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)
        self.effective_base = base
        self.position_scale = position_scale

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Map each adjacent channel pair ``(x0, x1)`` to ``(-x1, x0)``."""
        evens = x[..., 0::2]
        odds = x[..., 1::2]
        return torch.stack((-odds, evens), dim=-1).flatten(-2)

    def apply_rope(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        *,
        position_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate Q and K, each shaped [batch, heads, seq, d_head].

        Only the first ``rotary_dim`` channels are rotated; any remaining
        channels pass through untouched. Without ``position_ids`` the first
        ``seq`` table rows are used; otherwise the rows are gathered by
        ``position_ids``.

        The method is called ``apply_rope`` instead of ``apply`` so that it
        does not shadow ``nn.Module.apply(fn)``, PyTorch's recursive
        function-application utility used by ``bridge.apply(init_fn)``.
        """
        seq = q.shape[-2]
        rd = self.rotary_dim
        if position_ids is not None:
            # [batch, seq] -> [batch, 1, seq, rd]; the extra axis broadcasts over heads.
            cos = self.cos_cached[position_ids].to(q.dtype).unsqueeze(1)
            sin = self.sin_cached[position_ids].to(q.dtype).unsqueeze(1)
        else:
            cos = self.cos_cached[:seq].to(q.dtype)
            sin = self.sin_cached[:seq].to(q.dtype)

        def rotate(x: torch.Tensor) -> torch.Tensor:
            head, tail = x[..., :rd], x[..., rd:]
            head = head * cos + self._rotate_half(head) * sin
            if tail.shape[-1]:
                return torch.cat([head, tail], dim=-1)
            return head

        return rotate(q), rotate(k)

200 

201 

class NativeAttention(nn.Module):
    """Split-QKV self-attention with optional causal masking, GQA, RoPE and
    score soft-capping. Returns (out, pattern); AttentionBridge fires
    ``hook_pattern`` off the second element."""

    causal_mask: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        """Build the projections, the causal mask buffer and the score scale.

        Args:
            cfg: Model config; reads ``n_heads``, ``d_head``, ``d_model``,
                ``n_key_value_heads``, ``n_ctx``, ``use_attn_scale``,
                ``attn_scale``, ``attn_scores_soft_cap`` and ``attention_dir``.
            rotary: Shared RoPE module, or ``None`` to skip rotary embedding.

        Raises:
            ValueError: If ``n_heads`` is not a multiple of the KV head count,
                or if ``attn_scale`` is 1.0 while ``d_head > 1``.
        """
        super().__init__()
        self.cfg = cfg
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_head
        self.d_model = cfg.d_model
        # A falsy n_key_value_heads means plain multi-head attention.
        self.n_kv_heads = cfg.n_key_value_heads or cfg.n_heads
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_key_value_heads "
                f"({self.n_kv_heads}) for GQA."
            )
        self.kv_repeats = self.n_heads // self.n_kv_heads

        q_dim = self.n_heads * self.d_head
        kv_dim = self.n_kv_heads * self.d_head
        self.q = nn.Linear(cfg.d_model, q_dim, bias=True)
        self.k = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.v = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.o = nn.Linear(q_dim, cfg.d_model, bias=True)

        # True above the diagonal marks the future positions to hide.
        mask = torch.triu(torch.ones(cfg.n_ctx, cfg.n_ctx, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)

        # attn_scale=1.0 reads like "standard scaling" but is "divide by 1" —
        # i.e. unscaled scores, which saturate softmax for d_head>1.
        if cfg.use_attn_scale and cfg.attn_scale > 0:
            if self.d_head > 1 and math.isclose(cfg.attn_scale, 1.0, abs_tol=1e-9):
                raise ValueError(
                    f"attn_scale=1.0 with d_head={self.d_head} (>1) is unscaled "
                    f"attention; softmax will saturate. For standard scaling "
                    f"leave attn_scale at -1 (sentinel for sqrt(d_head))."
                )
            scale = cfg.attn_scale
        else:
            # NOTE(review): use_attn_scale=False also lands here, so scores are
            # still divided by sqrt(d_head) — confirm this is intended.
            scale = math.sqrt(cfg.d_head)
        self.scale = scale
        self.rotary = rotary
        self.attn_scores_soft_cap = float(cfg.attn_scores_soft_cap)
        self.causal = _uses_causal_attention(cfg)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention over ``hidden_states`` of shape [batch, seq, d_model].

        Args:
            hidden_states: Input activations.
            attention_mask: Optional external mask; see
                ``_combine_attention_mask`` for the accepted layouts.
            position_ids: Optional positions forwarded to the rotary module.
            **kwargs: Accepted and ignored.

        Returns:
            ``(out, pattern)``: the projected attention output and the
            attention probabilities of shape [batch, n_heads, seq, seq]. Query
            rows left with no visible key by ``attention_mask`` get an
            all-zero pattern rather than NaN.
        """
        batch, seq, _ = hidden_states.shape

        q = self.q(hidden_states).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = self.v(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.rotary is not None:
            q, k = self.rotary.apply_rope(q, k, position_ids=position_ids)

        # GQA: repeat_interleave matches HF Llama's repeat_kv group ordering.
        if self.kv_repeats > 1:
            k = k.repeat_interleave(self.kv_repeats, dim=1)
            v = v.repeat_interleave(self.kv_repeats, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        # Gemma2 soft-cap before the causal mask so masked positions stay -inf.
        if self.attn_scores_soft_cap > 0:
            c = self.attn_scores_soft_cap
            scores = c * torch.tanh(scores / c)

        if self.causal:
            block_mask = self.causal_mask[:seq, :seq]
        else:
            block_mask = torch.zeros(seq, seq, dtype=torch.bool, device=scores.device)
        if attention_mask is not None:
            block_mask = self._combine_attention_mask(block_mask, attention_mask, batch=batch)
        scores = scores.masked_fill(block_mask, float("-inf"))

        pattern = F.softmax(scores, dim=-1)
        if attention_mask is not None:
            # A row with every key masked (e.g. a left-padded query) is
            # softmax over all -inf, i.e. NaN. Zero those rows so the NaN
            # cannot reach the values of later layers via 0 * NaN. The causal
            # mask alone never empties a row, so this only matters when an
            # external mask is supplied.
            fully_masked = block_mask.all(dim=-1, keepdim=True)
            pattern = pattern.masked_fill(fully_masked, 0.0)

        attn = torch.matmul(pattern, v).transpose(1, 2).contiguous().view(batch, seq, -1)
        out = self.o(attn)
        return out, pattern

    @staticmethod
    def _combine_attention_mask(
        block_mask: torch.Tensor, attention_mask: torch.Tensor, *, batch: int
    ) -> torch.Tensor:
        """Combine an external attention_mask with the causal mask.

        Accepts 2D HF padding mask ``[batch, seq]`` (1=keep, 0=mask), 4D bool
        mask (True=mask), or 4D additive float mask (HF generation style; values
        below -1 treated as masked).

        Returns a bool mask (True=mask) broadcastable against the scores.
        ``batch`` is accepted for the caller's convenience and not used here.

        Raises:
            ValueError: If ``attention_mask`` is neither 2D nor 4D.
        """
        if attention_mask.dim() == 2:
            pad_mask = ~attention_mask.bool()
            # [batch, seq] -> [batch, 1, 1, seq]: hide padded keys for every query.
            return block_mask | pad_mask[:, None, None, :]
        if attention_mask.dim() == 4:
            if attention_mask.dtype is torch.bool:
                return block_mask | attention_mask
            # HF additive masks use -inf or large negatives; benign biases bounded.
            extra = attention_mask < -1.0
            return block_mask | extra
        raise ValueError(
            f"attention_mask must be 2D [batch, seq] or 4D [batch, *, seq, seq], "
            f"got shape {tuple(attention_mask.shape)}."
        )

313 

314 

class NativeMLP(nn.Module):
    """Plain two-layer MLP: ``fc_in`` -> activation -> ``fc_out``, both with bias.

    The activation is looked up by ``cfg.act_fn`` (case-insensitive, defaulting
    to ``"gelu"``).
    """

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        hidden_width: int = cfg.d_mlp
        self.fc_in = nn.Linear(cfg.d_model, hidden_width, bias=True)
        self.fc_out = nn.Linear(hidden_width, cfg.d_model, bias=True)

        act_name = (cfg.act_fn or "gelu").lower()
        activation = _ACTIVATIONS.get(act_name)
        if activation is None:
            # Unknown names fail loudly and list the supported ones.
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = activation

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply the MLP to ``hidden_states``; extra keyword arguments are ignored."""
        pre_act = self.fc_in(hidden_states)
        post_act = self.act(pre_act)
        return self.fc_out(post_act)

333 

334 

class NativeGatedMLP(nn.Module):
    """Gated MLP — SwiGLU / ReGLU / GeGLU depending on ``cfg.act_fn``.

    The submodule names ``gate`` / ``in`` / ``out`` are the slots
    GatedMLPBridge expects.
    """

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        hidden_width: int = cfg.d_mlp
        # Bias-free projections, following the Llama convention.
        self.gate = nn.Linear(cfg.d_model, hidden_width, bias=False)
        # ``in`` is a reserved word, so it cannot be a plain attribute
        # assignment; add_module + getattr(self, "in") works because the
        # bridge resolves LinearBridge(name="in") the same way.
        self.add_module("in", nn.Linear(cfg.d_model, hidden_width, bias=False))
        self.out = nn.Linear(hidden_width, cfg.d_model, bias=False)

        # SwiGLU is the default. An unknown act_fn raises, as in NativeMLP,
        # rather than silently changing the model.
        act_name = (cfg.act_fn or "silu").lower()
        activation = _ACTIVATIONS.get(act_name)
        if activation is None:
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = activation

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return ``out(act(gate(x)) * in(x))``; extra keyword arguments are ignored."""
        gated = self.act(self.gate(hidden_states))
        up_proj = cast(nn.Linear, getattr(self, "in"))
        return self.out(gated * up_proj(hidden_states))

365 

366 

class NativeBlock(nn.Module):
    """Pre-norm transformer block: attention, then (unless ``cfg.attn_only``)
    an MLP that is gated when ``cfg.gated_mlp`` is set."""

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        super().__init__()
        self.cfg = cfg
        self.ln1 = _make_norm(cfg)
        self.attn = NativeAttention(cfg, rotary=rotary)
        if cfg.attn_only:
            # Attention-only blocks carry no second norm and no MLP.
            return
        self.ln2 = _make_norm(cfg)
        if cfg.gated_mlp:
            self.mlp = NativeGatedMLP(cfg)
        else:
            self.mlp = NativeMLP(cfg)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """Apply the block and return the new residual stream as a 1-tuple.

        The attention pattern returned by the attention module is discarded
        here; extra keyword arguments are ignored.
        """
        normed = self.ln1(hidden_states)
        attn_out, _ = self.attn(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        residual = hidden_states + attn_out
        if self.cfg.attn_only:
            return (residual,)
        # The 1-tuple mirrors the HF block convention that BlockBridge's parser expects.
        return (residual + self.mlp(self.ln2(residual)),)

397 

398 

class NativeModel(nn.Module):
    """TL-native transformer. See module docstring for the supported feature set."""

    pos: Optional[nn.Embedding]
    rotary: Optional[NativeRotary]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        # A missing/falsy d_mlp defaults to 4 * d_model and is stored back on
        # cfg so later readers see the resolved value. This mutates cfg;
        # callers that need isolation should deep-copy it first.
        if not getattr(cfg, "d_mlp", None):
            cfg.d_mlp = 4 * cfg.d_model
        self.cfg = cfg

        self.tok_embed = nn.Embedding(cfg.d_vocab, cfg.d_model)

        kind = _positional_kind(cfg)
        if kind not in ("standard", "rotary"):
            raise ValueError(
                f"Unsupported positional_embedding_type={kind!r}. "
                f"NativeModel supports 'standard' and 'rotary'."
            )
        # Exactly one of the learned position table and the rotary module exists.
        uses_rotary = kind == "rotary"
        self.pos = None if uses_rotary else nn.Embedding(cfg.n_ctx, cfg.d_model)
        self.rotary = NativeRotary(cfg) if uses_rotary else None

        # Every block receives the same rotary module (or None).
        blocks = [NativeBlock(cfg, rotary=self.rotary) for _ in range(cfg.n_layers)]
        self.layers = TypedModuleList(blocks)
        # final_rms makes the last norm RMS whatever the blocks use — the TL
        # config semantic Llama relies on.
        self.ln_out = _make_norm(cfg, force_rms=cfg.final_rms)
        # A non-positive d_vocab_out means "same as d_vocab".
        if cfg.d_vocab_out > 0:
            d_vocab_out = cfg.d_vocab_out
        else:
            d_vocab_out = cfg.d_vocab
        self.head = nn.Linear(cfg.d_model, d_vocab_out, bias=False)
        self.output_logits_soft_cap = float(cfg.output_logits_soft_cap)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Map ``input_ids`` of shape [batch, seq] to logits.

        Raises:
            ValueError: If the sequence is longer than ``cfg.n_ctx``.
        """
        # Check the length first: the learned positions and the rotary tables
        # are both sized to n_ctx, and this gives a clearer error than the
        # IndexError / shape mismatch that would follow.
        seq_len = input_ids.shape[-1]
        if seq_len > self.cfg.n_ctx:
            raise ValueError(
                f"input length {seq_len} exceeds n_ctx={self.cfg.n_ctx}; "
                f"position embeddings and rotary tables are pre-baked at n_ctx."
            )

        # Positions are settled before the layer loop so the rotary path uses
        # caller-supplied positions when given, else dense 0..seq-1.
        batch, seq = input_ids.shape
        if position_ids is None:
            dense = torch.arange(seq, device=input_ids.device)
            position_ids = dense.unsqueeze(0).expand(batch, -1)

        hidden_states = self.tok_embed(input_ids)
        if self.pos is not None:
            hidden_states = hidden_states + self.pos(position_ids)

        for block in self.layers:
            block_out = block(
                hidden_states, attention_mask=attention_mask, position_ids=position_ids
            )
            (hidden_states,) = block_out

        logits = self.head(self.ln_out(hidden_states))
        cap = self.output_logits_soft_cap
        if cap > 0:
            logits = cap * torch.tanh(logits / cap)
        return logits