Coverage for transformer_lens/model_bridge/sources/native/init.py: 95%

74 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Weight init for NativeModel. 

2 

3Supported modes: ``"gpt2"`` (Normal(0, std) with 1/sqrt(2*n_layers) residual 

4scaling on output projections), ``"xavier_uniform"`` / ``"xavier_normal"``, 

5``"kaiming_uniform"`` / ``"kaiming_normal"`` (relu nonlinearity). Norm weights 

6go to 1, all biases to 0. 

7 

8Determinism uses a scoped ``torch.Generator``, not ``torch.manual_seed``, so 

9seeded init does not perturb the caller's global RNG. 

10""" 

11 

12from __future__ import annotations 

13 

14import math 

15from typing import Callable, Optional, cast 

16 

17import torch 

18import torch.nn as nn 

19 

20from transformer_lens.config import TransformerBridgeConfig 

21 

22from .model import ( 

23 NativeAttention, 

24 NativeBlock, 

25 NativeGatedMLP, 

26 NativeMLP, 

27 NativeModel, 

28 NativeRMSNorm, 

29) 

30 

31# Residual-scaled output is gpt2-specific; other modes treat every weight the 

32# same. Each entry takes ``(tensor, generator)`` to thread the scoped Generator. 

33_NonResidualInit = Callable[[torch.Tensor, Optional[torch.Generator]], torch.Tensor] 

34_NON_RESIDUAL_MODES: dict[str, _NonResidualInit] = { 

35 "xavier_uniform": lambda t, g: nn.init.xavier_uniform_(t, generator=g), 

36 "xavier_normal": lambda t, g: nn.init.xavier_normal_(t, generator=g), 

37 "kaiming_uniform": lambda t, g: nn.init.kaiming_uniform_(t, nonlinearity="relu", generator=g), 

38 "kaiming_normal": lambda t, g: nn.init.kaiming_normal_(t, nonlinearity="relu", generator=g), 

39} 

40 

41_SUPPORTED_MODES = frozenset({"gpt2", *_NON_RESIDUAL_MODES}) 

42 

43 

def initialize_native_model(
    model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None
) -> None:
    """Initialize ``model`` weights in-place.

    Token/positional embeddings, every block, the final norm and the output
    head are (re)initialized according to ``cfg.init_mode``.

    Args:
        model: Model whose parameters are overwritten in-place.
        cfg: Supplies ``init_mode`` (case-insensitive; a falsy value means
            ``"gpt2"``), ``seed``, and -- for ``"gpt2"`` mode only --
            ``initializer_range`` and ``n_layers``.
        seed: Overrides ``cfg.seed`` when not ``None``. If both are ``None``
            the global RNG is used and the result is not reproducible.

    Raises:
        NotImplementedError: If the resolved init mode is not one of the
            supported modes.
    """
    effective_seed = seed if seed is not None else cfg.seed

    # The scoped generator must live on the same device as the parameters it
    # fills; a parameter-less model falls back to CPU.
    try:
        gen_device = next(model.parameters()).device
    except StopIteration:
        gen_device = torch.device("cpu")

    # ``None`` makes the nn.init functions fall back to the global RNG.
    generator: Optional[torch.Generator] = None
    if effective_seed is not None:
        generator = torch.Generator(device=gen_device)
        generator.manual_seed(effective_seed)

    init_mode = (cfg.init_mode or "gpt2").lower()
    if init_mode not in _SUPPORTED_MODES:
        raise NotImplementedError(
            f"init_mode={init_mode!r} is not supported for NativeModel. "
            f"Supported modes: {sorted(_SUPPORTED_MODES)}."
        )

    weight_init: Callable[[torch.Tensor], torch.Tensor]
    output_init: Callable[[torch.Tensor], torch.Tensor]
    if init_mode == "gpt2":
        # A non-positive initializer_range means "unset": use the 0.02 fallback.
        std = cfg.initializer_range if cfg.initializer_range > 0 else 0.02
        # Clamp to one layer: a zero-layer model has no output projections to
        # scale, so the value is unused there and must not raise
        # ZeroDivisionError.
        residual_scale = 1.0 / math.sqrt(2 * max(cfg.n_layers, 1))
        output_std = std * residual_scale

        def _gpt2_weight(t: torch.Tensor) -> torch.Tensor:
            return nn.init.normal_(t, mean=0.0, std=std, generator=generator)

        def _gpt2_output(t: torch.Tensor) -> torch.Tensor:
            return nn.init.normal_(t, mean=0.0, std=output_std, generator=generator)

        weight_init, output_init = _gpt2_weight, _gpt2_output
    else:
        fn = _NON_RESIDUAL_MODES[init_mode]

        def _same_for_all(t: torch.Tensor) -> torch.Tensor:
            return fn(t, generator)

        # Non-gpt2 modes do not distinguish output projections.
        weight_init = output_init = _same_for_all

    weight_init(model.tok_embed.weight)
    if model.pos is not None:
        weight_init(model.pos.weight)
    # Rotary has only registered buffers (cos/sin), no parameters to init.

    for block in model.layers:
        _init_block(block, weight_init=weight_init, output_init=output_init)

    _init_norm(model.ln_out)
    weight_init(model.head.weight)

96 

97 

def _init_norm(norm: nn.Module) -> None:
    """Reset a normalization layer in-place: scale to 1, shift (if any) to 0.

    Args:
        norm: A ``NativeRMSNorm``, ``nn.LayerNorm`` or ``nn.Identity``.
            ``nn.Identity`` is accepted and left untouched.

    Raises:
        TypeError: If ``norm`` is none of the supported types.
    """
    if isinstance(norm, NativeRMSNorm):
        nn.init.ones_(norm.weight)
    elif isinstance(norm, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) has weight/bias set to None and
        # LayerNorm(bias=False) has bias set to None; skip what is absent
        # instead of crashing inside nn.init.
        if norm.weight is not None:
            nn.init.ones_(norm.weight)
        if norm.bias is not None:
            nn.init.zeros_(norm.bias)
    elif not isinstance(norm, nn.Identity):
        raise TypeError(f"Unknown normalization type: {type(norm).__name__}")

108 

109 

def _init_block(
    block: NativeBlock,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize one transformer block in-place.

    The first norm and the attention sub-layer are always initialized. The
    second norm and the MLP are skipped when ``block.cfg.attn_only`` is set.

    Args:
        block: Block whose sub-modules are initialized.
        weight_init: Initializer for input-side weight matrices.
        output_init: Initializer for output projections.
    """
    _init_norm(block.ln1)
    _init_attention(block.attn, weight_init=weight_init, output_init=output_init)

    if block.cfg.attn_only:
        return

    _init_norm(block.ln2)
    # Gated MLPs have a different parameter layout and need their own routine.
    mlp = block.mlp
    if isinstance(mlp, NativeGatedMLP):
        _init_gated_mlp(mlp, weight_init=weight_init, output_init=output_init)
        return
    _init_mlp(mlp, weight_init=weight_init, output_init=output_init)

124 

125 

def _init_attention(
    attn: NativeAttention,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize the q/k/v/o projections of an attention module in-place.

    Args:
        attn: Module exposing ``q``, ``k``, ``v`` and ``o`` linear layers.
        weight_init: Initializer applied to the q, k and v weights.
        output_init: Initializer applied to the o (output projection) weight.
    """
    # Order is q, k, v, o so a seeded generator is consumed in a fixed sequence.
    plan = (
        (attn.q, weight_init),
        (attn.k, weight_init),
        (attn.v, weight_init),
        (attn.o, output_init),
    )
    for projection, init_fn in plan:
        init_fn(projection.weight)
        bias = projection.bias
        if bias is None:
            continue
        nn.init.zeros_(bias)

139 

140 

def _init_mlp(
    mlp: NativeMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a two-layer MLP in-place.

    Args:
        mlp: Module exposing ``fc_in`` and ``fc_out`` linear layers.
        weight_init: Initializer applied to the ``fc_in`` weight.
        output_init: Initializer applied to the ``fc_out`` weight.
    """
    # fc_in first, then fc_out, so a seeded generator is consumed in a fixed
    # sequence.
    for linear, init_fn in ((mlp.fc_in, weight_init), (mlp.fc_out, output_init)):
        init_fn(linear.weight)
        # A bias-free linear has ``bias`` set to None; skip it rather than
        # crash, as the attention initializer does.
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)

151 

152 

def _init_gated_mlp(
    mlp: NativeGatedMLP,
    *,
    weight_init: Callable[[torch.Tensor], torch.Tensor],
    output_init: Callable[[torch.Tensor], torch.Tensor],
) -> None:
    """Initialize a gated MLP (gate / in / out projections) in-place.

    Args:
        mlp: Module exposing ``gate``, ``in`` and ``out`` linear layers.
        weight_init: Initializer applied to the ``gate`` and ``in`` weights.
        output_init: Initializer applied to the ``out`` weight.
    """
    # ``in`` is a Python keyword, so it cannot be written as ``mlp.in``. It is
    # registered via add_module; getattr resolves it from _modules.
    in_proj = cast(nn.Linear, getattr(mlp, "in"))

    # gate, in, out: a seeded generator is consumed in this fixed sequence.
    plan = ((mlp.gate, weight_init), (in_proj, weight_init), (mlp.out, output_init))
    for linear, init_fn in plan:
        init_fn(linear.weight)
        # Zero any bias so this path matches the other initializers; this is
        # a no-op for bias-free projections.
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)