Coverage for transformer_lens/model_bridge/supported_architectures/native.py: 98%
45 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
"""Architecture adapter for TL-native models built via ``boot_native``.

Component mapping adapts to cfg: gated MLP → ``GatedMLPBridge``, RMS norm →
``RMSNormalizationBridge``, rotary drops ``pos_embed``, ``attn_only`` drops MLP.
"""
6from typing import Any
8from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
9from transformer_lens.model_bridge.generalized_components import (
10 AttentionBridge,
11 BlockBridge,
12 EmbeddingBridge,
13 GatedMLPBridge,
14 LinearBridge,
15 MLPBridge,
16 NormalizationBridge,
17 PosEmbedBridge,
18 RMSNormalizationBridge,
19 UnembeddingBridge,
20)
23def _uses_rms(cfg: Any) -> bool:
24 return (getattr(cfg, "normalization_type", None) or "LN").upper() in ("RMS", "RMSPRE")
27def _is_rotary(cfg: Any) -> bool:
28 return (getattr(cfg, "positional_embedding_type", None) or "standard").lower() == "rotary"
def _make_norm_bridge(name: str, cfg: Any, *, force_rms: bool = False):
    """Build the normalization bridge for the module attribute ``name``.

    Args:
        name: Passed through as the bridge's ``name``.
        cfg: Model config; consulted via ``_uses_rms`` and handed to the bridge.
        force_rms: When True, pick the RMS bridge regardless of what ``cfg``
            says about the normalization type.

    Returns:
        An ``RMSNormalizationBridge`` if ``force_rms`` is set or ``cfg``
        selects an RMS norm, otherwise a ``NormalizationBridge``.
    """
    if force_rms or _uses_rms(cfg):
        return RMSNormalizationBridge(name=name, config=cfg)
    return NormalizationBridge(name=name, config=cfg)
def _make_mlp_bridge(cfg: Any):
    """Build the MLP bridge for one block, gated or plain depending on cfg.

    When ``cfg.gated_mlp`` is truthy a ``GatedMLPBridge`` is returned whose
    submodules point at ``gate`` / ``in`` / ``out``. Otherwise an ``MLPBridge``
    is returned whose ``in`` / ``out`` slots point at ``fc_in`` / ``fc_out``.
    Both bridges are named ``"mlp"``.
    """
    if cfg.gated_mlp:
        return GatedMLPBridge(
            name="mlp",
            config=cfg,
            submodules={
                "gate": LinearBridge(name="gate"),
                "in": LinearBridge(name="in"),
                "out": LinearBridge(name="out"),
            },
        )
    return MLPBridge(
        name="mlp",
        submodules={
            "in": LinearBridge(name="fc_in"),
            "out": LinearBridge(name="fc_out"),
        },
    )
def _make_block_submodules(cfg: Any) -> dict:
    """Build the submodule mapping for a single transformer block.

    The mapping always contains ``ln1`` and ``attn`` (with split ``q`` / ``k``
    / ``v`` / ``o`` linear bridges). ``ln2`` and ``mlp`` are added only when
    ``cfg.attn_only`` is falsy.

    Returns:
        Dict from bridge slot name to the bridge component for that slot.
    """
    submods: dict = {
        "ln1": _make_norm_bridge("ln1", cfg),
        "attn": AttentionBridge(
            name="attn",
            config=cfg,
            submodules={
                "q": LinearBridge(name="q"),
                "k": LinearBridge(name="k"),
                "v": LinearBridge(name="v"),
                "o": LinearBridge(name="o"),
            },
        ),
    }
    if not cfg.attn_only:
        submods["ln2"] = _make_norm_bridge("ln2", cfg)
        submods["mlp"] = _make_mlp_bridge(cfg)
    return submods
class NativeArchitectureAdapter(ArchitectureAdapter):
    """Adapter for ``NativeModel`` — TL-native, split-QKV, pre-LN; feature set
    driven by cfg (gated MLP, RMS norm, rotary, GQA, soft-cap, attn_only)."""

    def __init__(self, cfg: Any) -> None:
        """Initialise the adapter and derive ``component_mapping`` from ``cfg``.

        The mapping always has ``embed``, ``blocks``, ``ln_final`` and
        ``unembed``; ``pos_embed`` is included only for non-rotary configs.
        """
        super().__init__(cfg)

        # Native layout already stores Q/K/V split; no rearranges needed.
        # Compatibility-mode fold_ln / center_writing_weights aren't wired up,
        # so gate the corresponding ProcessWeights paths off — folding without
        # the state-dict conversions would mis-place or drop weights.
        self.supports_fold_ln = False
        self.supports_center_writing_weights = False
        self.weight_processing_conversions = {}

        # Internal attribute names avoid collisions with bridge slot names
        # ("embed", "blocks", "ln_final", "unembed") — the bridge's __getattr__
        # forwards to original_model and would shadow add_module otherwise.
        mapping: dict = {
            "embed": EmbeddingBridge(name="tok_embed"),
        }
        if not _is_rotary(cfg):
            mapping["pos_embed"] = PosEmbedBridge(name="pos")

        block_bridge = BlockBridge(
            name="layers",
            config=self.cfg,
            submodules=_make_block_submodules(self.cfg),
        )
        # Under attn_only there's no ln2 / mlp to point at; drop the aliases
        # that would otherwise warn during _register_aliases.
        if self.cfg.attn_only:
            # Copy before mutating when the instance still shares the
            # class-level dict, so other BlockBridge instances keep their
            # aliases.
            if block_bridge.hook_aliases is BlockBridge.hook_aliases:
                block_bridge.hook_aliases = dict(block_bridge.hook_aliases)
            block_bridge.hook_aliases.pop("hook_resid_mid", None)
            block_bridge.hook_aliases.pop("hook_mlp_out", None)
        mapping["blocks"] = block_bridge

        # final_rms forces RMS on the final norm independent of block norm —
        # matches Llama's TL config semantic.
        mapping["ln_final"] = _make_norm_bridge(
            "ln_out", self.cfg, force_rms=bool(getattr(self.cfg, "final_rms", False))
        )
        mapping["unembed"] = UnembeddingBridge(name="head")
        self.component_mapping = mapping

    def prepare_model(self, model: Any) -> None:
        """Reject modules whose attribute names collide with bridge slots.

        Bridge's ``__getattr__`` falls back to ``getattr(original_model, name)``
        for unknown attrs, so a name match — submodule, buffer, plain tensor,
        or property — makes ``add_module`` raise mid-setup with an opaque
        message. Failing here points at the real cause. Reserved set is derived
        from ``component_mapping.keys()`` so adapter variants stay in sync.

        Raises:
            ValueError: If ``model`` has an attribute named like any key of
                ``component_mapping``.
        """
        reserved = set(self.component_mapping.keys()) if self.component_mapping else set()
        collisions = sorted(name for name in reserved if hasattr(model, name))
        if collisions:
            raise ValueError(
                f"{type(model).__name__} cannot be wrapped by NativeArchitectureAdapter: "
                f"attribute names {collisions} collide with bridge component slots "
                f"({sorted(reserved)}). Rename these attributes to non-colliding names "
                f"(e.g. tok_embed, layers, ln_out, head) and update the adapter's "
                f"component_mapping ``name=`` fields to match."
            )