Coverage for transformer_lens/model_bridge/sources/native/init.py: 95%
74 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Weight init for NativeModel.
3Supported modes: ``"gpt2"`` (Normal(0, std) with 1/sqrt(2*n_layers) residual
4scaling on output projections), ``"xavier_uniform"`` / ``"xavier_normal"``,
5``"kaiming_uniform"`` / ``"kaiming_normal"`` (relu nonlinearity). Norm weights
6go to 1, all biases to 0.
8Determinism uses a scoped ``torch.Generator``, not ``torch.manual_seed``, so
9seeded init does not perturb the caller's global RNG.
10"""
12from __future__ import annotations
14import math
15from typing import Callable, Optional, cast
17import torch
18import torch.nn as nn
20from transformer_lens.config import TransformerBridgeConfig
22from .model import (
23 NativeAttention,
24 NativeBlock,
25 NativeGatedMLP,
26 NativeMLP,
27 NativeModel,
28 NativeRMSNorm,
29)
31# Residual-scaled output is gpt2-specific; other modes treat every weight the
32# same. Each entry takes ``(tensor, generator)`` to thread the scoped Generator.
33_NonResidualInit = Callable[[torch.Tensor, Optional[torch.Generator]], torch.Tensor]
34_NON_RESIDUAL_MODES: dict[str, _NonResidualInit] = {
35 "xavier_uniform": lambda t, g: nn.init.xavier_uniform_(t, generator=g),
36 "xavier_normal": lambda t, g: nn.init.xavier_normal_(t, generator=g),
37 "kaiming_uniform": lambda t, g: nn.init.kaiming_uniform_(t, nonlinearity="relu", generator=g),
38 "kaiming_normal": lambda t, g: nn.init.kaiming_normal_(t, nonlinearity="relu", generator=g),
39}
41_SUPPORTED_MODES = frozenset({"gpt2", *_NON_RESIDUAL_MODES})
def initialize_native_model(
    model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None
) -> None:
    """Initialize ``model`` weights in-place. Honors ``cfg.init_mode`` and ``cfg.seed``.

    Args:
        model: Model whose embeddings, blocks, final norm and head are
            overwritten in place.
        cfg: Supplies ``init_mode`` (``"gpt2"`` when falsy), ``seed``,
            ``initializer_range`` and ``n_layers``.
        seed: Overrides ``cfg.seed`` when not None. With no seed at all the
            global torch RNG is used.

    Raises:
        NotImplementedError: If the (lower-cased) ``init_mode`` is not supported.
    """
    effective_seed = seed if seed is not None else cfg.seed
    generator = _make_generator(model, effective_seed)

    init_mode = (cfg.init_mode or "gpt2").lower()
    if init_mode not in _SUPPORTED_MODES:
        raise NotImplementedError(
            f"init_mode={init_mode!r} is not supported for NativeModel. "
            f"Supported modes: {sorted(_SUPPORTED_MODES)}."
        )

    weight_init, output_init = _build_initializers(init_mode, cfg, generator)

    weight_init(model.tok_embed.weight)
    if model.pos is not None:
        weight_init(model.pos.weight)
    # Rotary has only registered buffers (cos/sin), no parameters to init.

    for block in model.layers:
        _init_block(block, weight_init=weight_init, output_init=output_init)

    _init_norm(model.ln_out)
    weight_init(model.head.weight)
    # Module contract is "all biases to 0"; a bias-free head exposes bias=None.
    head_bias = getattr(model.head, "bias", None)
    if head_bias is not None:
        nn.init.zeros_(head_bias)


def _make_generator(model: nn.Module, seed: int | None) -> Optional[torch.Generator]:
    """Return a Generator seeded with ``seed`` on the model's device, or None if unseeded.

    A scoped Generator (rather than ``torch.manual_seed``) keeps the caller's
    global RNG untouched; None makes ``nn.init`` fall back to the global RNG.
    """
    if seed is None:
        return None
    try:
        device = next(model.parameters()).device
    except StopIteration:
        # Parameter-less model: any device works, CPU is the safe default.
        device = torch.device("cpu")
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    return generator


def _build_initializers(
    init_mode: str,
    cfg: TransformerBridgeConfig,
    generator: Optional[torch.Generator],
) -> tuple[
    Callable[[torch.Tensor], torch.Tensor], Callable[[torch.Tensor], torch.Tensor]
]:
    """Return ``(weight_init, output_init)`` for an already-validated ``init_mode``."""
    if init_mode != "gpt2":
        mode_fn = _NON_RESIDUAL_MODES[init_mode]

        def uniform_init(t: torch.Tensor) -> torch.Tensor:
            return mode_fn(t, generator)

        # Non-gpt2 modes treat output projections like every other weight.
        return uniform_init, uniform_init

    std = cfg.initializer_range if cfg.initializer_range > 0 else 0.02
    # GPT-2 scales output projections by 1/sqrt(2 * n_layers). Clamp to one
    # layer so a zero-layer model (no output projections) doesn't divide by 0.
    residual_scale = 1.0 / math.sqrt(2 * max(cfg.n_layers, 1))

    def weight_init(t: torch.Tensor) -> torch.Tensor:
        return nn.init.normal_(t, mean=0.0, std=std, generator=generator)

    def output_init(t: torch.Tensor) -> torch.Tensor:
        return nn.init.normal_(t, mean=0.0, std=std * residual_scale, generator=generator)

    return weight_init, output_init
def _init_norm(norm: nn.Module) -> None:
    """Reset a normalization module in place: weight to 1, bias (if any) to 0.

    Supports ``NativeRMSNorm``, ``nn.LayerNorm`` and ``nn.Identity`` (no-op).

    Raises:
        TypeError: If ``norm`` is any other module type.
    """
    if isinstance(norm, NativeRMSNorm):
        nn.init.ones_(norm.weight)
    elif isinstance(norm, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) has weight=None and bias=None, and
        # LayerNorm(bias=False) has bias=None; only touch parameters that exist.
        if norm.weight is not None:
            nn.init.ones_(norm.weight)
        if norm.bias is not None:
            nn.init.zeros_(norm.bias)
    elif isinstance(norm, nn.Identity):
        pass  # no parameters to initialize
    else:
        raise TypeError(f"Unknown normalization type: {type(norm).__name__}")
def _init_block(
    block: NativeBlock,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize one block in place: ``ln1`` and attention, then ``ln2`` and
    the MLP unless ``block.cfg.attn_only`` is set."""
    initializers = dict(weight_init=weight_init, output_init=output_init)

    _init_norm(block.ln1)
    _init_attention(block.attn, **initializers)

    # Attention-only blocks have no second norm or MLP.
    if block.cfg.attn_only:
        return

    _init_norm(block.ln2)
    mlp_initializer = (
        _init_gated_mlp if isinstance(block.mlp, NativeGatedMLP) else _init_mlp
    )
    mlp_initializer(block.mlp, **initializers)
def _init_attention(
    attn: NativeAttention,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize attention projections in place.

    ``q``, ``k`` and ``v`` weights go through ``weight_init``; the ``o`` weight
    goes through ``output_init``. Each projection's bias is zeroed when present.
    """
    schedule = (
        (attn.q, weight_init),
        (attn.k, weight_init),
        (attn.v, weight_init),
        (attn.o, output_init),
    )
    for projection, init_fn in schedule:
        init_fn(projection.weight)
        if projection.bias is not None:
            nn.init.zeros_(projection.bias)
def _init_mlp(
    mlp: NativeMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a non-gated MLP in place.

    ``fc_in.weight`` uses ``weight_init`` and ``fc_out.weight`` uses
    ``output_init``; each bias is zeroed when present.
    """
    weight_init(mlp.fc_in.weight)
    # Guard the biases (as the attention init does): a Linear built with
    # bias=False has bias=None and nn.init.zeros_(None) would raise.
    if mlp.fc_in.bias is not None:
        nn.init.zeros_(mlp.fc_in.bias)
    output_init(mlp.fc_out.weight)
    if mlp.fc_out.bias is not None:
        nn.init.zeros_(mlp.fc_out.bias)
def _init_gated_mlp(
    mlp: NativeGatedMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a gated MLP in place.

    ``gate`` and ``in`` weights use ``weight_init``, the ``out`` weight uses
    ``output_init``, and any bias present on those projections is zeroed
    (previously biases kept PyTorch's default non-zero init).
    """
    # ``in`` is a Python keyword, so it is registered via add_module; getattr
    # resolves it from _modules.
    in_proj = cast(nn.Linear, getattr(mlp, "in"))
    # Weight order gate -> in -> out matches the original, so a seeded
    # Generator produces the same weights as before.
    for projection, init_fn in (
        (mlp.gate, weight_init),
        (in_proj, weight_init),
        (mlp.out, output_init),
    ):
        init_fn(projection.weight)
        bias = getattr(projection, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)