Coverage for transformer_lens/model_bridge/sources/native/model.py: 93%

240 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""TL-native transformer for TransformerBridge — minimal, no HF/HT dependency. 

2 

3Cfg-driven features: ``normalization_type`` (LN / RMS / RMSPre), ``final_rms``, 

4``gated_mlp``, ``attn_only``, ``n_key_value_heads`` (GQA), ``attn_scores_soft_cap``, 

5``output_logits_soft_cap``, ``positional_embedding_type`` (standard / rotary), 

6``rotary_dim`` / ``rotary_base`` / ``rope_scaling`` (linear PI, dynamic/NTK, 

7llama3 by-parts). 

8""" 

9from __future__ import annotations 

10 

11import math 

12from typing import Callable, Optional, cast 

13 

14import torch 

15import torch.nn as nn 

16import torch.nn.functional as F 

17 

18from transformer_lens.config import TransformerBridgeConfig 

19 

# Registry of supported activations, keyed by the lowercase ``cfg.act_fn`` name.
# "gelu_new" is the tanh-approximation HF GPT-2 / HT use;
# F.gelu(approximate="tanh") is the exact same formula. "swish" aliases "silu".
_Activation = Callable[[torch.Tensor], torch.Tensor]
_ACTIVATIONS: dict[str, _Activation] = dict(
    gelu=F.gelu,
    gelu_new=lambda x: F.gelu(x, approximate="tanh"),
    relu=F.relu,
    silu=F.silu,
    swish=F.silu,
)

30 

31 

def _uses_rms_norm(cfg: TransformerBridgeConfig) -> bool:
    """Return True when ``cfg.normalization_type`` names an RMS variant.

    The comparison is case-insensitive; a missing/empty value is treated as "LN".
    """
    norm_name = cfg.normalization_type or "LN"
    return norm_name.upper() in {"RMS", "RMSPRE"}

34 

35 

def _positional_kind(cfg: TransformerBridgeConfig) -> str:
    """Return the lowercase ``positional_embedding_type``, defaulting to "standard".

    The default applies both when the attribute is absent and when it is falsy.
    """
    kind = getattr(cfg, "positional_embedding_type", None)
    if not kind:
        kind = "standard"
    return kind.lower()

38 

39 

class NativeRMSNorm(nn.Module):
    """Llama-style RMSNorm (matches HF LlamaRMSNorm).

    The mean square is computed in fp32 regardless of the input dtype; the
    normalized activations are cast back to the input dtype before the
    learnable per-channel ``weight`` is applied.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upcast = x.to(torch.float32)
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        scaled = upcast * torch.rsqrt(mean_square + self.eps)
        # Cast back first, then scale, so the weight multiplies input-dtype values.
        return self.weight * scaled.to(x.dtype)

55 

56 

def _make_norm(cfg: TransformerBridgeConfig, *, force_rms: bool = False) -> nn.Module:
    """Build a normalization layer of width ``cfg.d_model`` with ``eps=cfg.eps``.

    Returns a ``NativeRMSNorm`` when ``force_rms`` is set or the config selects
    an RMS variant; otherwise an ``nn.LayerNorm``.
    """
    wants_rms = force_rms or _uses_rms_norm(cfg)
    norm_cls = NativeRMSNorm if wants_rms else nn.LayerNorm
    return norm_cls(cfg.d_model, eps=cfg.eps)

61 

62 

def _resolve_rope_scaling(
    cfg: TransformerBridgeConfig, rotary_dim: int
) -> tuple[float, float, torch.Tensor]:
    """Resolve ``cfg.rope_scaling`` into the quantities used to build RoPE tables.

    Args:
        cfg: Config providing ``rotary_base`` and, optionally, a ``rope_scaling``
            dict. Anything that is not a dict means "no scaling".
        rotary_dim: Number of rotated channels per head.

    Returns:
        ``(effective_base, position_scale, inv_freq)``. ``position_scale`` is the
        divisor the caller applies to positions (only "linear" sets it above 1);
        ``inv_freq`` holds ``rotary_dim // 2`` inverse frequencies.

    Raises:
        ValueError: "dynamic"/"ntk" scaling requested with ``rotary_dim <= 2``,
            where the exponent ``rotary_dim / (rotary_dim - 2)`` is undefined.
        NotImplementedError: Unknown scaling type with ``factor > 1``.
    """

    def _inv_freq_for(theta: float) -> torch.Tensor:
        # Standard RoPE inverse frequencies for a given base.
        exponents = torch.arange(0, rotary_dim, 2).float() / rotary_dim
        return 1.0 / (theta**exponents)

    base = float(cfg.rotary_base)
    inv_freq = _inv_freq_for(base)

    rope_scaling = getattr(cfg, "rope_scaling", None)
    if not isinstance(rope_scaling, dict):
        return base, 1.0, inv_freq

    # Newer HF configs key on "rope_type"; older ones on "type".
    scale_type = str(rope_scaling.get("rope_type") or rope_scaling.get("type") or "").lower()
    # An explicit ``"factor": None`` means "unset"; treat it like a missing key
    # instead of crashing in float(None).
    raw_factor = rope_scaling.get("factor")
    factor = 1.0 if raw_factor is None else float(raw_factor)

    # No scaling requested (or a factor that would not stretch the context).
    if scale_type in ("", "default") or factor <= 1.0:
        return base, 1.0, inv_freq

    if scale_type == "linear":
        # Position interpolation: frequencies untouched, positions divided by factor.
        return base, factor, inv_freq

    if scale_type in ("dynamic", "ntk"):
        if rotary_dim <= 2:
            raise ValueError(
                f"rope_scaling type {scale_type!r} requires rotary_dim > 2, "
                f"got rotary_dim={rotary_dim!r}."
            )
        scaled_base = base * (factor ** (rotary_dim / (rotary_dim - 2)))
        return scaled_base, 1.0, _inv_freq_for(scaled_base)

    if scale_type == "llama3":
        low_freq_factor = float(rope_scaling.get("low_freq_factor", 1.0))
        high_freq_factor = float(rope_scaling.get("high_freq_factor", 4.0))
        original_ctx = float(
            rope_scaling.get("original_max_position_embeddings")
            or rope_scaling.get("original_context_length")
            or 8192
        )
        wavelens = 2 * math.pi / inv_freq
        # Three regimes: low-freq → divide by factor; high-freq → unchanged;
        # in-between → smooth linear interpolation between the two.
        smooth = (original_ctx / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor)
        blended = (1 - smooth) * inv_freq / factor + smooth * inv_freq
        new_inv_freq = torch.where(
            wavelens > original_ctx / low_freq_factor,
            inv_freq / factor,
            torch.where(wavelens < original_ctx / high_freq_factor, inv_freq, blended),
        )
        return base, 1.0, new_inv_freq

    raise NotImplementedError(
        f"rope_scaling type {scale_type!r} is not supported. "
        f"Supported: 'linear', 'dynamic'/'ntk', 'llama3'."
    )

118 

119 

class NativeRotary(nn.Module):
    """Precomputed RoPE cos/sin tables, one row per position up to ``cfg.n_ctx``.

    Honors ``cfg.rope_scaling`` via ``_resolve_rope_scaling``.
    """

    # Declared so mypy sees the buffer dtype; register_buffer alone reports Module|Tensor.
    cos_cached: torch.Tensor
    sin_cached: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        rotary_dim = cfg.d_head if cfg.rotary_dim is None else cfg.rotary_dim
        if rotary_dim <= 0 or rotary_dim % 2 != 0:
            raise ValueError(f"rotary_dim must be a positive even integer, got {rotary_dim!r}")
        self.rotary_dim = rotary_dim

        base, position_scale, inv_freq = _resolve_rope_scaling(cfg, rotary_dim)

        scaled_positions = torch.arange(cfg.n_ctx).float() / position_scale
        angles = torch.outer(scaled_positions, inv_freq)
        # Adjacent-pair layout: each (2i, 2i+1) channel pair shares one angle,
        # so every frequency column is duplicated in place.
        tables = (("cos_cached", angles.cos()), ("sin_cached", angles.sin()))
        for buffer_name, table in tables:
            self.register_buffer(
                buffer_name, table.repeat_interleave(2, dim=-1), persistent=False
            )
        self.effective_base = base
        self.position_scale = position_scale

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Adjacent-pair rotation: (x0, x1) -> (-x1, x0).
        evens, odds = x[..., 0::2], x[..., 1::2]
        return torch.stack((-odds, evens), dim=-1).flatten(-2)

    def apply_rope(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        *,
        position_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate the first ``rotary_dim`` channels of Q and K.

        Q/K are shaped ``[batch, heads, seq, d_head]``. Without ``position_ids``
        the first ``seq`` table rows are used; with ``position_ids``
        (``[batch, seq]``) the rows are gathered per token. Channels past
        ``rotary_dim`` pass through unchanged.

        Named ``apply_rope`` rather than ``apply`` so ``nn.Module.apply(fn)``
        — PyTorch's recursive function-application utility used by
        ``bridge.apply(init_fn)`` — isn't shadowed.
        """
        rd = self.rotary_dim
        if position_ids is not None:
            # [batch, seq] -> [batch, 1, seq, rd] (head dim for broadcast).
            cos = self.cos_cached[position_ids].to(q.dtype).unsqueeze(1)
            sin = self.sin_cached[position_ids].to(q.dtype).unsqueeze(1)
        else:
            seq = q.shape[-2]
            cos = self.cos_cached[:seq].to(q.dtype)
            sin = self.sin_cached[:seq].to(q.dtype)

        rotated = []
        for states in (q, k):
            head, tail = states[..., :rd], states[..., rd:]
            head = head * cos + self._rotate_half(head) * sin
            rotated.append(torch.cat([head, tail], dim=-1) if tail.shape[-1] else head)
        q_rot, k_rot = rotated
        return q_rot, k_rot

183 

184 

class NativeAttention(nn.Module):
    """Split-QKV causal self-attention with optional GQA, RoPE and score soft-cap.

    ``forward`` returns ``(out, pattern)``; AttentionBridge fires
    ``hook_pattern`` off the second element.
    """

    causal_mask: torch.Tensor

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        """Build the projections, the causal mask buffer and the score scale.

        Raises:
            ValueError: ``n_heads`` is not divisible by the KV head count, or
                ``attn_scale`` is (numerically) 1.0 while ``d_head > 1``.
        """
        super().__init__()
        self.cfg = cfg
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_head
        self.d_model = cfg.d_model
        # A falsy n_key_value_heads means plain multi-head attention.
        self.n_kv_heads = cfg.n_key_value_heads or cfg.n_heads
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_key_value_heads "
                f"({self.n_kv_heads}) for GQA."
            )
        self.kv_repeats = self.n_heads // self.n_kv_heads

        q_dim = self.n_heads * self.d_head
        kv_dim = self.n_kv_heads * self.d_head
        self.q = nn.Linear(cfg.d_model, q_dim, bias=True)
        self.k = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.v = nn.Linear(cfg.d_model, kv_dim, bias=True)
        self.o = nn.Linear(q_dim, cfg.d_model, bias=True)

        # True above the diagonal = "future key, must be hidden".
        mask = torch.triu(torch.ones(cfg.n_ctx, cfg.n_ctx, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)

        # attn_scale=1.0 reads like "standard scaling" but is "divide by 1" —
        # i.e. unscaled scores, which saturate softmax for d_head>1.
        # NOTE(review): with use_attn_scale=False this still divides by
        # sqrt(d_head) — confirm that is the intended meaning of the flag.
        if cfg.use_attn_scale and cfg.attn_scale > 0:
            if self.d_head > 1 and math.isclose(cfg.attn_scale, 1.0, abs_tol=1e-9):
                raise ValueError(
                    f"attn_scale=1.0 with d_head={self.d_head} (>1) is unscaled "
                    f"attention; softmax will saturate. For standard scaling "
                    f"leave attn_scale at -1 (sentinel for sqrt(d_head))."
                )
            scale = cfg.attn_scale
        else:
            scale = math.sqrt(cfg.d_head)
        self.scale = scale
        self.rotary = rotary
        self.attn_scores_soft_cap = float(cfg.attn_scores_soft_cap)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention over ``hidden_states`` of shape ``[batch, seq, d_model]``.

        Args:
            hidden_states: Block input (already normalized by the caller).
            attention_mask: Optional extra mask; see ``_combine_attention_mask``
                for the accepted layouts.
            position_ids: Optional positions forwarded to the rotary module.
            **kwargs: Accepted and ignored.

        Returns:
            ``(out, pattern)`` with ``out`` of shape ``[batch, seq, d_model]`` and
            ``pattern`` of shape ``[batch, n_heads, seq, seq]``. Query rows whose
            keys are all masked get an all-zero pattern row instead of NaN.
        """
        batch, seq, _ = hidden_states.shape

        q = self.q(hidden_states).view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = self.v(hidden_states).view(batch, seq, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.rotary is not None:
            q, k = self.rotary.apply_rope(q, k, position_ids=position_ids)

        # GQA: repeat_interleave matches HF Llama's repeat_kv group ordering.
        if self.kv_repeats > 1:
            k = k.repeat_interleave(self.kv_repeats, dim=1)
            v = v.repeat_interleave(self.kv_repeats, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        # Gemma2 soft-cap before the causal mask so masked positions stay -inf.
        if self.attn_scores_soft_cap > 0:
            c = self.attn_scores_soft_cap
            scores = c * torch.tanh(scores / c)

        block_mask = self.causal_mask[:seq, :seq]
        dead_rows: Optional[torch.Tensor] = None
        if attention_mask is not None:
            block_mask = self._combine_attention_mask(block_mask, attention_mask, batch=batch)
            # Queries with every key hidden (e.g. left-padding positions). The
            # causal mask alone never produces these: the diagonal stays visible.
            dead_rows = block_mask.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(block_mask, float("-inf"))
        if dead_rows is not None:
            # softmax over an all -inf row is NaN, and NaN would leak into every
            # other position downstream (0 * NaN == NaN in pattern @ v). Give
            # such rows finite scores so forward and backward both stay finite.
            scores = scores.masked_fill(dead_rows, 0.0)

        pattern = F.softmax(scores, dim=-1)
        if dead_rows is not None:
            # A fully masked query attends to nothing.
            pattern = pattern.masked_fill(dead_rows, 0.0)

        attn = torch.matmul(pattern, v).transpose(1, 2).contiguous().view(batch, seq, -1)
        out = self.o(attn)
        return out, pattern

    @staticmethod
    def _combine_attention_mask(
        block_mask: torch.Tensor, attention_mask: torch.Tensor, *, batch: int
    ) -> torch.Tensor:
        """Combine an external attention_mask with the causal mask.

        Accepts 2D HF padding mask ``[batch, seq]`` (1=keep, 0=mask), 4D bool
        mask (True=mask), or 4D additive float mask (HF generation style; values
        below -1 treated as masked). The result is a bool mask where True means
        "hide this key". ``batch`` is currently unused.

        Raises:
            ValueError: ``attention_mask`` is neither 2D nor 4D.
        """
        if attention_mask.dim() == 2:
            pad_mask = ~attention_mask.bool()
            # [batch, seq] -> [batch, 1, 1, seq]: hide padded keys for every query.
            return block_mask | pad_mask[:, None, None, :]
        if attention_mask.dim() == 4:
            if attention_mask.dtype is torch.bool:
                return block_mask | attention_mask
            # HF additive masks use -inf or large negatives; benign biases bounded.
            extra = attention_mask < -1.0
            return block_mask | extra
        raise ValueError(
            f"attention_mask must be 2D [batch, seq] or 4D [batch, *, seq, seq], "
            f"got shape {tuple(attention_mask.shape)}."
        )

292 

293 

class NativeMLP(nn.Module):
    """Biased two-layer feed-forward: ``fc_out(act(fc_in(x)))``.

    The activation is looked up in ``_ACTIVATIONS`` by the lowercase
    ``cfg.act_fn`` (default "gelu").
    """

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        hidden_width: int = cfg.d_mlp
        self.fc_in = nn.Linear(cfg.d_model, hidden_width, bias=True)
        self.fc_out = nn.Linear(hidden_width, cfg.d_model, bias=True)

        act_name = (cfg.act_fn or "gelu").lower()
        activation = _ACTIVATIONS.get(act_name)
        if activation is None:
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = activation

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra keyword arguments are accepted and ignored.
        expanded = self.fc_in(hidden_states)
        return self.fc_out(self.act(expanded))

312 

313 

class NativeGatedMLP(nn.Module):
    """Gated MLP: ``out(act(gate(x)) * in(x))``, i.e. SwiGLU / ReGLU / GeGLU
    depending on ``cfg.act_fn``.

    Submodules ``gate`` / ``in`` / ``out`` match GatedMLPBridge's expected slots.
    """

    act: Callable[[torch.Tensor], torch.Tensor]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        assert cfg.d_mlp is not None, "NativeModel resolves d_mlp before instantiating MLPs"
        hidden_width: int = cfg.d_mlp
        # Llama convention: no biases on gated MLP projections.
        self.gate = nn.Linear(cfg.d_model, hidden_width, bias=False)
        # ``in`` is a Python keyword, so it cannot be assigned as ``self.in``;
        # add_module + getattr(self, "in") works because the bridge resolves
        # LinearBridge(name="in") the same way.
        self.add_module("in", nn.Linear(cfg.d_model, hidden_width, bias=False))
        self.out = nn.Linear(hidden_width, cfg.d_model, bias=False)

        # Default to SwiGLU; an unknown act_fn raises (same dispatch as
        # NativeMLP) instead of silently changing the model.
        act_name = (cfg.act_fn or "silu").lower()
        activation = _ACTIVATIONS.get(act_name)
        if activation is None:
            raise ValueError(f"Unsupported act_fn={act_name!r}. Supported: {sorted(_ACTIVATIONS)}")
        self.act = activation

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra keyword arguments are accepted and ignored.
        up_proj = cast(nn.Linear, getattr(self, "in"))
        # Gate branch is evaluated before the up-projection (left operand first).
        gated = self.act(self.gate(hidden_states)) * up_proj(hidden_states)
        return self.out(gated)

344 

345 

class NativeBlock(nn.Module):
    """Pre-LN transformer block: attention residual, then (unless
    ``cfg.attn_only``) an MLP residual whose flavor follows ``cfg.gated_mlp``."""

    def __init__(self, cfg: TransformerBridgeConfig, rotary: Optional[NativeRotary] = None):
        super().__init__()
        self.cfg = cfg
        self.ln1 = _make_norm(cfg)
        self.attn = NativeAttention(cfg, rotary=rotary)
        if cfg.attn_only:
            # Attention-only blocks have no ln2 / mlp submodules at all.
            return
        self.ln2 = _make_norm(cfg)
        mlp_cls = NativeGatedMLP if cfg.gated_mlp else NativeMLP
        self.mlp = mlp_cls(cfg)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """Apply the block and return the new residual stream as a 1-tuple.

        The attention pattern returned by ``self.attn`` is discarded here;
        extra keyword arguments are accepted and ignored.
        """
        attn_out, _unused_pattern = self.attn(
            self.ln1(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        residual = hidden_states + attn_out
        # Tuple return matches HF block convention; BlockBridge's parser expects it.
        if self.cfg.attn_only:
            return (residual,)
        return (residual + self.mlp(self.ln2(residual)),)

376 

377 

class NativeModel(nn.Module):
    """TL-native transformer: token embedding, optional learned positions or
    shared rotary tables, ``cfg.n_layers`` blocks, a final norm and an unbiased
    output head. See module docstring for the supported feature set."""

    pos: Optional[nn.Embedding]
    rotary: Optional[NativeRotary]

    def __init__(self, cfg: TransformerBridgeConfig):
        super().__init__()
        # Write the resolved d_mlp back so downstream consumers see the real
        # value, not None. Mutates cfg; isolating callers should deep-copy first.
        if not getattr(cfg, "d_mlp", None):
            cfg.d_mlp = 4 * cfg.d_model
        self.cfg = cfg

        self.tok_embed = nn.Embedding(cfg.d_vocab, cfg.d_model)

        kind = _positional_kind(cfg)
        if kind not in ("standard", "rotary"):
            raise ValueError(
                f"Unsupported positional_embedding_type={kind!r}. "
                f"NativeModel supports 'standard' and 'rotary'."
            )
        # Exactly one of the two position mechanisms is active.
        self.pos = nn.Embedding(cfg.n_ctx, cfg.d_model) if kind == "standard" else None
        self.rotary = NativeRotary(cfg) if kind == "rotary" else None

        # Every block shares the same rotary module (or None).
        self.layers = nn.ModuleList(
            NativeBlock(cfg, rotary=self.rotary) for _ in range(cfg.n_layers)
        )
        # final_rms forces RMS on the final norm regardless of block-norm choice
        # — matches the TL config semantic Llama uses.
        self.ln_out = _make_norm(cfg, force_rms=cfg.final_rms)
        # A non-positive d_vocab_out falls back to d_vocab.
        vocab_out = cfg.d_vocab_out if cfg.d_vocab_out > 0 else cfg.d_vocab
        self.head = nn.Linear(cfg.d_model, vocab_out, bias=False)
        self.output_logits_soft_cap = float(cfg.output_logits_soft_cap)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return logits for ``input_ids`` of shape ``[batch, seq]``.

        Raises:
            ValueError: the sequence is longer than ``cfg.n_ctx``.
        """
        # Bounds check up front so both absolute and rotary paths produce a
        # self-explanatory error rather than IndexError / shape mismatch.
        seq_len = input_ids.shape[-1]
        if seq_len > self.cfg.n_ctx:
            raise ValueError(
                f"input length {seq_len} exceeds n_ctx={self.cfg.n_ctx}; "
                f"position embeddings and rotary tables are pre-baked at n_ctx."
            )

        # Resolve position_ids before the block loop so rotary sees the caller's
        # positions, not the dense default.
        batch, seq = input_ids.shape
        if position_ids is None:
            dense = torch.arange(seq, device=input_ids.device)
            position_ids = dense.unsqueeze(0).expand(batch, -1)

        residual = self.tok_embed(input_ids)
        if self.pos is not None:
            residual = residual + self.pos(position_ids)

        for block in self.layers:
            (residual,) = block(
                residual, attention_mask=attention_mask, position_ids=position_ids
            )

        logits = self.head(self.ln_out(residual))
        cap = self.output_logits_soft_cap
        if cap > 0:
            # tanh soft-cap on the logits; disabled when the cap is <= 0.
            logits = cap * torch.tanh(logits / cap)
        return logits