Coverage for transformer_lens/model_bridge/sources/native/init.py: 95%

72 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Weight init for NativeModel. 

2 

3Supported modes: ``"gpt2"`` (Normal(0, std) with 1/sqrt(2*n_layers) residual 

4scaling on output projections), ``"xavier_uniform"`` / ``"xavier_normal"``, 

5``"kaiming_uniform"`` / ``"kaiming_normal"`` (relu nonlinearity). Norm weights 

6go to 1, all biases to 0. 

7 

8Determinism uses a scoped ``torch.Generator``, not ``torch.manual_seed``, so 

9seeded init does not perturb the caller's global RNG. 

10""" 

11from __future__ import annotations 

12 

13import math 

14from typing import Callable, Optional, cast 

15 

16import torch 

17import torch.nn as nn 

18 

19from transformer_lens.config import TransformerBridgeConfig 

20 

21from .model import ( 

22 NativeAttention, 

23 NativeBlock, 

24 NativeGatedMLP, 

25 NativeMLP, 

26 NativeModel, 

27 NativeRMSNorm, 

28) 

29 

30# Residual-scaled output is gpt2-specific; other modes treat every weight the 

31# same. Each entry takes ``(tensor, generator)`` to thread the scoped Generator. 

32_NonResidualInit = Callable[[torch.Tensor, Optional[torch.Generator]], torch.Tensor] 

33_NON_RESIDUAL_MODES: dict[str, _NonResidualInit] = { 

34 "xavier_uniform": lambda t, g: nn.init.xavier_uniform_(t, generator=g), 

35 "xavier_normal": lambda t, g: nn.init.xavier_normal_(t, generator=g), 

36 "kaiming_uniform": lambda t, g: nn.init.kaiming_uniform_(t, nonlinearity="relu", generator=g), 

37 "kaiming_normal": lambda t, g: nn.init.kaiming_normal_(t, nonlinearity="relu", generator=g), 

38} 

39 

40_SUPPORTED_MODES = frozenset({"gpt2", *_NON_RESIDUAL_MODES}) 

41 

42 

def initialize_native_model(
    model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None
) -> None:
    """Initialize ``model`` weights in-place. Honors ``cfg.init_mode`` and ``cfg.seed``.

    Args:
        model: Model whose parameters are overwritten in place.
        cfg: Supplies ``init_mode`` (default ``"gpt2"``), ``seed``,
            ``initializer_range`` and ``n_layers``.
        seed: Explicit seed; takes precedence over ``cfg.seed`` when not None.

    Raises:
        NotImplementedError: if ``cfg.init_mode`` is not a supported mode.
    """
    generator = _scoped_generator(model, seed if seed is not None else cfg.seed)
    weight_init, output_init = _resolve_initializers(cfg, generator)

    weight_init(model.tok_embed.weight)
    if model.pos is not None:
        weight_init(model.pos.weight)
    # Rotary has only registered buffers (cos/sin), no parameters to init.

    for block in model.layers:
        _init_block(block, weight_init=weight_init, output_init=output_init)

    _init_norm(model.ln_out)
    weight_init(model.head.weight)
    # Keep the "all biases to 0" contract if the head Linear carries a bias.
    head_bias = getattr(model.head, "bias", None)
    if head_bias is not None:
        nn.init.zeros_(head_bias)


def _scoped_generator(model: nn.Module, seed: int | None) -> Optional[torch.Generator]:
    """Return a Generator seeded with ``seed``, or None when ``seed`` is None.

    The generator lives on the device of ``model``'s first parameter (CPU when the
    model has no parameters). Returning None makes ``nn.init`` fall back to the
    global RNG; a seeded Generator leaves the caller's global RNG untouched.
    """
    if seed is None:
        return None
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    return generator


def _resolve_initializers(
    cfg: TransformerBridgeConfig, generator: Optional[torch.Generator]
) -> tuple[Callable[[torch.Tensor], torch.Tensor], Callable[[torch.Tensor], torch.Tensor]]:
    """Return ``(weight_init, output_init)`` for ``cfg.init_mode``.

    ``output_init`` is applied to output projections; only the ``"gpt2"`` mode
    makes it differ from ``weight_init`` (residual-scaled std).

    Raises:
        NotImplementedError: if the (lower-cased) mode is not supported.
    """
    init_mode = (cfg.init_mode or "gpt2").lower()
    if init_mode not in _SUPPORTED_MODES:
        raise NotImplementedError(
            f"init_mode={init_mode!r} is not supported for NativeModel. "
            f"Supported modes: {sorted(_SUPPORTED_MODES)}."
        )

    if init_mode != "gpt2":
        non_residual = _NON_RESIDUAL_MODES[init_mode]

        def uniform_init(t: torch.Tensor) -> torch.Tensor:
            return non_residual(t, generator)

        return uniform_init, uniform_init

    std = cfg.initializer_range if cfg.initializer_range > 0 else 0.02
    # A 0-layer model has no output projections to scale, but 1/sqrt(0) would
    # still raise ZeroDivisionError; clamp so such models can be initialized.
    residual_scale = 1.0 / math.sqrt(2 * max(cfg.n_layers, 1))

    def weight_init(t: torch.Tensor) -> torch.Tensor:
        return nn.init.normal_(t, mean=0.0, std=std, generator=generator)

    def output_init(t: torch.Tensor) -> torch.Tensor:
        return nn.init.normal_(t, mean=0.0, std=std * residual_scale, generator=generator)

    return weight_init, output_init

95 

96 

def _init_norm(norm: nn.Module) -> None:
    """Reset a normalization layer in place: weight to 1, bias (if any) to 0.

    Affine parameters a ``nn.LayerNorm`` does not have are skipped instead of
    crashing: ``elementwise_affine=False`` leaves both ``weight`` and ``bias`` as
    None, and ``bias=False`` (newer PyTorch) leaves only ``bias`` as None.

    Raises:
        TypeError: if ``norm`` is neither a ``NativeRMSNorm`` nor an ``nn.LayerNorm``.
    """
    if isinstance(norm, NativeRMSNorm):
        nn.init.ones_(norm.weight)
    elif isinstance(norm, nn.LayerNorm):
        if norm.weight is not None:
            nn.init.ones_(norm.weight)
        if norm.bias is not None:
            nn.init.zeros_(norm.bias)
    else:
        raise TypeError(f"Unknown normalization type: {type(norm).__name__}")

105 

106 

def _init_block(
    block: NativeBlock,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize one block in place.

    Always resets ``ln1`` and the attention sublayer. Unless ``block.cfg.attn_only``
    is set, it also resets ``ln2`` and the MLP; a ``NativeGatedMLP`` is routed to the
    gated initializer and anything else to the dense one.
    """
    _init_norm(block.ln1)
    _init_attention(block.attn, weight_init=weight_init, output_init=output_init)

    if block.cfg.attn_only:
        # attn_only: no ln2 / MLP initialization.
        return

    _init_norm(block.ln2)
    mlp = block.mlp
    if isinstance(mlp, NativeGatedMLP):
        _init_gated_mlp(mlp, weight_init=weight_init, output_init=output_init)
        return
    _init_mlp(mlp, weight_init=weight_init, output_init=output_init)

121 

122 

def _init_attention(
    attn: NativeAttention,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize the q/k/v/o projections of an attention layer in place.

    ``q``, ``k`` and ``v`` weights get ``weight_init``; the ``o`` (output) weight
    gets ``output_init``. Each bias that exists is zeroed right after its weight.
    """
    schedule = (
        (attn.q, weight_init),
        (attn.k, weight_init),
        (attn.v, weight_init),
        (attn.o, output_init),
    )
    for projection, init_fn in schedule:
        init_fn(projection.weight)
        if projection.bias is None:
            continue
        nn.init.zeros_(projection.bias)

136 

137 

def _init_mlp(
    mlp: NativeMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a dense (non-gated) MLP in place.

    ``fc_in`` weight gets ``weight_init``; the ``fc_out`` output projection gets
    ``output_init``. Each bias is zeroed when present. A bias-free ``nn.Linear`` has
    ``bias is None``, and ``nn.init.zeros_(None)`` would raise, so it is skipped;
    this matches the bias checks used for the attention projections.
    """
    for linear, init_fn in ((mlp.fc_in, weight_init), (mlp.fc_out, output_init)):
        init_fn(linear.weight)
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)

148 

149 

def _init_gated_mlp(
    mlp: NativeGatedMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a gated MLP's ``gate`` / ``in`` / ``out`` projections in place.

    ``gate`` and ``in`` weights get ``weight_init``; the ``out`` output projection
    gets ``output_init`` (in that order, so seeded RNG draws are unchanged).
    Biases that exist are zeroed to honor the module's "all biases to 0" rule;
    previously they were left at their construction-time values.
    """
    # ``in`` is a Python keyword, so it is registered via add_module and resolved
    # from _modules with getattr.
    in_proj = cast(nn.Linear, getattr(mlp, "in"))
    schedule = ((mlp.gate, weight_init), (in_proj, weight_init), (mlp.out, output_init))
    for linear, init_fn in schedule:
        init_fn(linear.weight)
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)