Coverage for transformer_lens/model_bridge/sources/native/init.py: 95%
72 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
"""Weight init for NativeModel.

Supported modes: ``"gpt2"`` (Normal(0, std) with 1/sqrt(2*n_layers) residual
scaling on output projections), ``"xavier_uniform"`` / ``"xavier_normal"``,
``"kaiming_uniform"`` / ``"kaiming_normal"`` (relu nonlinearity). Norm weights
go to 1, all biases to 0.

Determinism uses a scoped ``torch.Generator``, not ``torch.manual_seed``, so
seeded init does not perturb the caller's global RNG.
"""
11from __future__ import annotations
13import math
14from typing import Callable, Optional, cast
16import torch
17import torch.nn as nn
19from transformer_lens.config import TransformerBridgeConfig
21from .model import (
22 NativeAttention,
23 NativeBlock,
24 NativeGatedMLP,
25 NativeMLP,
26 NativeModel,
27 NativeRMSNorm,
28)
30# Residual-scaled output is gpt2-specific; other modes treat every weight the
31# same. Each entry takes ``(tensor, generator)`` to thread the scoped Generator.
32_NonResidualInit = Callable[[torch.Tensor, Optional[torch.Generator]], torch.Tensor]
33_NON_RESIDUAL_MODES: dict[str, _NonResidualInit] = {
34 "xavier_uniform": lambda t, g: nn.init.xavier_uniform_(t, generator=g),
35 "xavier_normal": lambda t, g: nn.init.xavier_normal_(t, generator=g),
36 "kaiming_uniform": lambda t, g: nn.init.kaiming_uniform_(t, nonlinearity="relu", generator=g),
37 "kaiming_normal": lambda t, g: nn.init.kaiming_normal_(t, nonlinearity="relu", generator=g),
38}
40_SUPPORTED_MODES = frozenset({"gpt2", *_NON_RESIDUAL_MODES})
def initialize_native_model(
    model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None
) -> None:
    """Initialize ``model`` weights in-place. Honors ``cfg.init_mode`` and ``cfg.seed``.

    Args:
        model: Model whose token/position embeddings, blocks, final norm and
            head are re-initialized.
        cfg: Supplies ``init_mode`` (falls back to ``"gpt2"`` when falsy; matched
            case-insensitively), ``seed``, ``initializer_range`` and ``n_layers``.
        seed: Overrides ``cfg.seed`` when not None. If both are None, draws come
            from the global torch RNG.

    Raises:
        NotImplementedError: If the resolved init mode is not supported.
    """
    init_mode = (cfg.init_mode or "gpt2").lower()
    if init_mode not in _SUPPORTED_MODES:
        raise NotImplementedError(
            f"init_mode={init_mode!r} is not supported for NativeModel. "
            f"Supported modes: {sorted(_SUPPORTED_MODES)}."
        )

    effective_seed = seed if seed is not None else cfg.seed
    generator = _make_generator(model, effective_seed)

    weight_init: Callable[[torch.Tensor], torch.Tensor]
    output_init: Callable[[torch.Tensor], torch.Tensor]
    if init_mode == "gpt2":
        std = cfg.initializer_range if cfg.initializer_range > 0 else 0.02
        # Output projections are scaled by 1/sqrt(2*n_layers). A model with no
        # layers has no output projections, so avoid the division by zero
        # instead of failing on a value that would never be used.
        residual_scale = 1.0 / math.sqrt(2 * cfg.n_layers) if cfg.n_layers > 0 else 1.0
        out_std = std * residual_scale
        weight_init = lambda t: nn.init.normal_(  # noqa: E731
            t, mean=0.0, std=std, generator=generator
        )
        output_init = lambda t: nn.init.normal_(  # noqa: E731
            t, mean=0.0, std=out_std, generator=generator
        )
    else:
        mode_init = _NON_RESIDUAL_MODES[init_mode]
        weight_init = lambda t: mode_init(t, generator)  # noqa: E731
        output_init = weight_init

    weight_init(model.tok_embed.weight)
    if model.pos is not None:
        weight_init(model.pos.weight)
    # Rotary has only registered buffers (cos/sin), no parameters to init.

    for block in model.layers:
        _init_block(block, weight_init=weight_init, output_init=output_init)

    _init_norm(model.ln_out)
    weight_init(model.head.weight)
    # Zero the head bias when one exists. A parameter skipped here would keep
    # its construction-time default init (drawn from the global RNG), which
    # breaks both the "all biases to 0" contract and seeded determinism.
    head_bias = getattr(model.head, "bias", None)
    if head_bias is not None:
        nn.init.zeros_(head_bias)


def _make_generator(model: nn.Module, seed: Optional[int]) -> Optional[torch.Generator]:
    """Return a Generator seeded with ``seed`` on ``model``'s device, or None if unseeded."""
    if seed is None:
        # None makes the nn.init calls fall back to the global RNG.
        return None
    # The generator must live on the same device as the tensors it fills; a
    # model without parameters falls back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    return generator
def _init_norm(norm: nn.Module) -> None:
    """Reset a normalization layer's affine parameters: weight to 1, bias to 0.

    For ``nn.LayerNorm``, parameters the layer does not have (``None``, as with
    ``elementwise_affine=False`` or ``bias=False``) are skipped instead of
    crashing, matching how attention biases are handled.

    Raises:
        TypeError: If ``norm`` is neither a NativeRMSNorm nor an nn.LayerNorm.
    """
    if isinstance(norm, NativeRMSNorm):
        nn.init.ones_(norm.weight)
        return
    if not isinstance(norm, nn.LayerNorm):
        raise TypeError(f"Unknown normalization type: {type(norm).__name__}")
    if norm.weight is not None:
        nn.init.ones_(norm.weight)
    if norm.bias is not None:
        nn.init.zeros_(norm.bias)
def _init_block(
    block: NativeBlock,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize one block: ``ln1`` and attention always; ``ln2`` and the MLP
    (gated or dense, chosen by type) only when ``block.cfg.attn_only`` is false.

    ``weight_init`` fills input-side weights and ``output_init`` fills output
    projections; both are forwarded unchanged to the sub-initializers.
    """
    _init_norm(block.ln1)
    _init_attention(block.attn, weight_init=weight_init, output_init=output_init)
    if block.cfg.attn_only:
        # Nothing further is initialized for attention-only blocks.
        return

    _init_norm(block.ln2)
    init_mlp = _init_gated_mlp if isinstance(block.mlp, NativeGatedMLP) else _init_mlp
    init_mlp(block.mlp, weight_init=weight_init, output_init=output_init)
def _init_attention(
    attn: NativeAttention,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize the q/k/v/o projections of an attention module.

    q, k and v weights use ``weight_init``; the o weight uses ``output_init``.
    Every projection bias that exists (is not None) is set to zero.
    """
    # Processed in q, k, v, o order so that seeded draws happen in a fixed sequence.
    schedule = (
        (attn.q, weight_init),
        (attn.k, weight_init),
        (attn.v, weight_init),
        (attn.o, output_init),
    )
    for proj, init_fn in schedule:
        init_fn(proj.weight)
        bias = proj.bias
        if bias is not None:
            nn.init.zeros_(bias)
def _init_mlp(
    mlp: NativeMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a dense MLP: ``fc_in`` weight with ``weight_init``, ``fc_out``
    weight with ``output_init``, and any existing biases to zero.

    Biases are zeroed only when present (not None), consistent with attention
    projections, so a bias-less linear layer no longer crashes ``zeros_``.
    """
    for layer, init_fn in ((mlp.fc_in, weight_init), (mlp.fc_out, output_init)):
        init_fn(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
def _init_gated_mlp(
    mlp: NativeGatedMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a gated MLP: ``gate`` and ``in`` weights with ``weight_init``,
    ``out`` weight with ``output_init``, and any existing biases to zero.

    Previously biases were left untouched, which contradicted the "all biases
    to 0" contract and left them with their construction-time init (global RNG,
    so not covered by seeding). Weight draws keep their original order.
    """
    # ``in`` is a Python keyword, so the submodule is registered via add_module
    # and has to be fetched with getattr.
    in_proj = cast(nn.Linear, getattr(mlp, "in"))
    for layer, init_fn in (
        (mlp.gate, weight_init),
        (in_proj, weight_init),
        (mlp.out, output_init),
    ):
        init_fn(layer.weight)
        bias = getattr(layer, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)