transformer_lens.model_bridge.sources.native.init module

Weight init for NativeModel.

Supported modes: "gpt2" (Normal(0, std) with 1/sqrt(2*n_layers) residual scaling on output projections), "xavier_uniform" / "xavier_normal", "kaiming_uniform" / "kaiming_normal" (relu nonlinearity). Norm weights go to 1, all biases to 0.
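The "gpt2" mode described above can be illustrated with a minimal sketch on a toy block. This is not the module's actual implementation: the helper name gpt2_style_init_, the base std of 0.02, the "out_proj.weight" suffix used to pick out output projections, and the reading of the residual scaling as a multiplier on std are all assumptions made for illustration.

```python
import math

import torch
from torch import nn


def gpt2_style_init_(
    module: nn.Module,
    n_layers: int,
    std: float = 0.02,  # assumed base std, not taken from the library
    output_proj_suffixes: tuple = ("out_proj.weight",),  # hypothetical naming
) -> None:
    """Hypothetical GPT-2 style init: Normal(0, std) for weight matrices,
    std / sqrt(2 * n_layers) for output projections, 1 for norm weights,
    0 for every bias."""
    resid_std = std / math.sqrt(2 * n_layers)
    with torch.no_grad():
        for name, param in module.named_parameters():
            if name.endswith("bias"):
                param.zero_()
            elif param.dim() < 2:
                # 1-D non-bias parameters are treated as norm weights here.
                param.fill_(1.0)
            elif name.endswith(output_proj_suffixes):
                param.normal_(0.0, resid_std)
            else:
                param.normal_(0.0, std)


class ToyBlock(nn.Module):
    """Stand-in for one transformer block; not the real NativeModel."""

    def __init__(self, d_model: int = 8) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, 4 * d_model)
        self.out_proj = nn.Linear(4 * d_model, d_model)


block = ToyBlock()
gpt2_style_init_(block, n_layers=12)
```

With n_layers=12 the output projection is drawn with a standard deviation roughly five times smaller than the input projection, which is the effect the residual scaling is meant to have.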

Seeded initialization draws from a scoped torch.Generator rather than calling torch.manual_seed, so it does not perturb the caller’s global RNG.
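The scoped-generator approach can be sketched as follows. The helper seeded_normal_ is hypothetical and exists only to show the pattern; it relies on the generator keyword that in-place samplers such as Tensor.normal_ accept.

```python
from typing import Optional

import torch


def seeded_normal_(
    tensor: torch.Tensor, std: float, seed: Optional[int] = None
) -> torch.Tensor:
    """Hypothetical helper: fill `tensor` from Normal(0, std) using a
    private generator when a seed is given, else the global RNG."""
    gen = None
    if seed is not None:
        # The generator lives only inside this call, so seeding it
        # leaves torch's global RNG state alone.
        gen = torch.Generator(device=tensor.device)
        gen.manual_seed(seed)
    with torch.no_grad():
        return tensor.normal_(0.0, std, generator=gen)


state_before = torch.get_rng_state()
a = seeded_normal_(torch.empty(4, 4), std=0.02, seed=0)
b = seeded_normal_(torch.empty(4, 4), std=0.02, seed=0)
state_after = torch.get_rng_state()
```

Here a and b are identical because both calls use the same seed, while the global RNG state captured before and after is unchanged. Calling torch.manual_seed(0) instead would reset the global state and change every later random draw in the caller's program.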

transformer_lens.model_bridge.sources.native.init.initialize_native_model(model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None) → None

Initialize model weights in place. Honors cfg.init_mode and cfg.seed.