Coverage for transformer_lens/model_bridge/supported_architectures/native.py: 98%
50 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Architecture adapter for TL-native models built via ``boot_native``.
3Component mapping adapts to cfg: gated MLP → ``GatedMLPBridge``, RMS norm →
4``RMSNormalizationBridge``, rotary drops ``pos_embed``, ``attn_only`` drops MLP.
5"""
7from typing import Any
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 BlockBridge,
13 EmbeddingBridge,
14 GatedMLPBridge,
15 LinearBridge,
16 MLPBridge,
17 NormalizationBridge,
18 PosEmbedBridge,
19 RMSNormalizationBridge,
20 UnembeddingBridge,
21)
22from transformer_lens.model_bridge.generalized_components.base import (
23 GeneralizedComponent,
24)
def _uses_rms(cfg: Any) -> bool:
    """Report whether ``cfg`` selects an RMS-style normalization.

    A missing or falsy ``normalization_type`` is treated as ``"LN"``. The
    comparison is case-insensitive and matches ``"RMS"`` and ``"RMSPRE"``.
    """
    norm_type = getattr(cfg, "normalization_type", None)
    if not norm_type:
        # Absent / None / empty all fall back to the layer-norm default.
        norm_type = "LN"
    return norm_type.upper() in ("RMS", "RMSPRE")
def _uses_no_norm(cfg: Any) -> bool:
    """Report whether ``cfg`` has no normalization type at all.

    True only when ``normalization_type`` is missing or is exactly ``None``;
    other falsy values (e.g. an empty string) do not count.
    """
    norm_type = getattr(cfg, "normalization_type", None)
    return norm_type is None
def _is_rotary(cfg: Any) -> bool:
    """Report whether ``cfg`` selects rotary positional embeddings.

    A missing or falsy ``positional_embedding_type`` is treated as
    ``"standard"``. The comparison against ``"rotary"`` is case-insensitive.
    """
    pos_type = getattr(cfg, "positional_embedding_type", None)
    if not pos_type:
        # Absent / None / empty all fall back to the standard scheme.
        pos_type = "standard"
    return pos_type.lower() == "rotary"
def _make_norm_bridge(name: str, cfg: Any, *, force_rms: bool = False):
    """Build the normalization bridge that matches ``cfg``.

    Args:
        name: Passed through as the bridge's ``name``.
        cfg: Model config; also passed to the bridge as ``config``.
        force_rms: When true, return an ``RMSNormalizationBridge`` without
            consulting ``cfg.normalization_type``.

    Returns:
        ``RMSNormalizationBridge`` when RMS is forced or configured, a bare
        ``GeneralizedComponent`` when ``cfg`` has no normalization type, and
        ``NormalizationBridge`` otherwise.
    """
    if not force_rms and not _uses_rms(cfg):
        # Not RMS: pick between the no-norm case and regular layer norm.
        bridge_cls = GeneralizedComponent if _uses_no_norm(cfg) else NormalizationBridge
        return bridge_cls(name=name, config=cfg)
    return RMSNormalizationBridge(name=name, config=cfg)
def _make_mlp_bridge(cfg: Any):
    """Build the MLP bridge selected by ``cfg.gated_mlp``.

    Returns:
        A ``GatedMLPBridge`` with ``gate`` / ``in`` / ``out`` linear
        submodules when ``cfg.gated_mlp`` is truthy, otherwise an
        ``MLPBridge`` whose ``in`` / ``out`` slots point at ``fc_in`` /
        ``fc_out``.
    """
    if not cfg.gated_mlp:
        # NOTE(review): unlike the gated branch, no ``config`` is passed to
        # MLPBridge here — confirm this is intentional.
        plain_parts = {
            "in": LinearBridge(name="fc_in"),
            "out": LinearBridge(name="fc_out"),
        }
        return MLPBridge(name="mlp", submodules=plain_parts)

    gated_parts = {
        "gate": LinearBridge(name="gate"),
        "in": LinearBridge(name="in"),
        "out": LinearBridge(name="out"),
    }
    return GatedMLPBridge(name="mlp", config=cfg, submodules=gated_parts)
def _make_block_submodules(cfg: Any) -> dict:
    """Assemble the per-block submodule mapping for a ``BlockBridge``.

    Always contains ``ln1`` and ``attn`` (with ``q`` / ``k`` / ``v`` / ``o``
    linear submodules). ``ln2`` and ``mlp`` are added only when
    ``cfg.attn_only`` is falsy.
    """
    first_norm = _make_norm_bridge("ln1", cfg)
    projections = {letter: LinearBridge(name=letter) for letter in ("q", "k", "v", "o")}
    attention = AttentionBridge(name="attn", config=cfg, submodules=projections)

    parts: dict = {"ln1": first_norm, "attn": attention}
    if cfg.attn_only:
        # Attention-only blocks have no second norm and no MLP.
        return parts

    parts["ln2"] = _make_norm_bridge("ln2", cfg)
    parts["mlp"] = _make_mlp_bridge(cfg)
    return parts
class NativeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for ``NativeModel``.

    The model is TL-native with split Q/K/V and pre-LN; which components get
    mapped is driven by cfg (gated MLP, RMS norm, rotary, GQA, soft-cap,
    attn_only).
    """

    def __init__(self, cfg: Any) -> None:
        """Configure weight-processing flags and build ``component_mapping``.

        Args:
            cfg: Model config, forwarded to the base adapter and read for
                ``attn_only``, ``final_rms`` and the positional-embedding and
                normalization settings.
        """
        super().__init__(cfg)

        # The native layout keeps Q/K/V already split, so no rearranges are
        # required. fold_ln / center_writing_weights compatibility paths are
        # not wired up; switch the matching ProcessWeights paths off, because
        # folding without the state-dict conversions would mis-place or drop
        # weights.
        self.supports_fold_ln = False
        self.supports_center_writing_weights = False
        self.weight_processing_conversions = {}

        # The ``name=`` values deliberately differ from the bridge slot names
        # ("embed", "blocks", "ln_final", "unembed"): the bridge's __getattr__
        # forwards to original_model and would otherwise shadow add_module.
        components: dict = {}
        components["embed"] = EmbeddingBridge(name="tok_embed")
        if not _is_rotary(cfg):
            # Rotary models carry no learned positional embedding to map.
            components["pos_embed"] = PosEmbedBridge(name="pos")

        blocks = BlockBridge(
            name="layers",
            config=self.cfg,
            submodules=_make_block_submodules(self.cfg),
        )
        if self.cfg.attn_only:
            # With attn_only there is no ln2 / mlp for these aliases to point
            # at; removing them avoids warnings during _register_aliases.
            if blocks.hook_aliases is BlockBridge.hook_aliases:
                # Copy first so the class-level dict is not mutated.
                blocks.hook_aliases = dict(blocks.hook_aliases)
            for alias in ("hook_resid_mid", "hook_mlp_out"):
                blocks.hook_aliases.pop(alias, None)
        components["blocks"] = blocks

        # final_rms forces RMS on the final norm regardless of the block
        # norm, matching Llama's TL config semantic.
        want_final_rms = bool(getattr(self.cfg, "final_rms", False))
        components["ln_final"] = _make_norm_bridge(
            "ln_out", self.cfg, force_rms=want_final_rms
        )
        components["unembed"] = UnembeddingBridge(name="head")
        self.component_mapping = components

    def prepare_model(self, model: Any) -> None:
        """Reject models whose attribute names collide with bridge slots.

        The bridge's ``__getattr__`` falls back to
        ``getattr(original_model, name)`` for unknown attributes, so any name
        match — submodule, buffer, plain tensor, or property — makes
        ``add_module`` fail mid-setup with an opaque message. Checking here
        reports the real cause. The reserved names come from
        ``component_mapping`` so adapter variants stay in sync.

        Args:
            model: The model about to be wrapped.

        Raises:
            ValueError: If ``model`` exposes any attribute named like a
                ``component_mapping`` key.
        """
        mapping = self.component_mapping
        slot_names = set(mapping.keys()) if mapping else set()
        clashing = sorted(slot for slot in slot_names if hasattr(model, slot))
        if not clashing:
            return
        raise ValueError(
            f"{type(model).__name__} cannot be wrapped by NativeArchitectureAdapter: "
            f"attribute names {clashing} collide with bridge component slots "
            f"({sorted(slot_names)}). Rename these attributes to non-colliding names "
            f"(e.g. tok_embed, layers, ln_out, head) and update the adapter's "
            f"component_mapping ``name=`` fields to match."
        )