Coverage for transformer_lens/model_bridge/supported_architectures/native.py: 98%

50 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Architecture adapter for TL-native models built via ``boot_native``. 

2 

3Component mapping adapts to cfg: gated MLP → ``GatedMLPBridge``, RMS norm → 

4``RMSNormalizationBridge``, rotary drops ``pos_embed``, ``attn_only`` drops MLP. 

5""" 

6 

7from typing import Any 

8 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 AttentionBridge, 

12 BlockBridge, 

13 EmbeddingBridge, 

14 GatedMLPBridge, 

15 LinearBridge, 

16 MLPBridge, 

17 NormalizationBridge, 

18 PosEmbedBridge, 

19 RMSNormalizationBridge, 

20 UnembeddingBridge, 

21) 

22from transformer_lens.model_bridge.generalized_components.base import ( 

23 GeneralizedComponent, 

24) 

25 

26 

27def _uses_rms(cfg: Any) -> bool: 

28 return (getattr(cfg, "normalization_type", None) or "LN").upper() in ("RMS", "RMSPRE") 

29 

30 

31def _uses_no_norm(cfg: Any) -> bool: 

32 return getattr(cfg, "normalization_type", None) is None 

33 

34 

35def _is_rotary(cfg: Any) -> bool: 

36 return (getattr(cfg, "positional_embedding_type", None) or "standard").lower() == "rotary" 

37 

38 

def _make_norm_bridge(name: str, cfg: Any, *, force_rms: bool = False):
    """Build the normalization bridge component selected by ``cfg``.

    Args:
        name: Passed through as the bridge's ``name``.
        cfg: Model config; passed through as the bridge's ``config`` and
            inspected via ``_uses_rms`` / ``_uses_no_norm``.
        force_rms: When True, pick ``RMSNormalizationBridge`` regardless of
            what ``cfg`` says.

    Returns:
        ``RMSNormalizationBridge`` if RMS is forced or configured, a bare
        ``GeneralizedComponent`` if the config has no normalization type,
        otherwise ``NormalizationBridge``.
    """
    # RMS takes priority; the "no norm" check only runs when RMS is ruled out.
    if force_rms or _uses_rms(cfg):
        bridge_cls = RMSNormalizationBridge
    elif _uses_no_norm(cfg):
        bridge_cls = GeneralizedComponent
    else:
        bridge_cls = NormalizationBridge
    return bridge_cls(name=name, config=cfg)

45 

46 

def _make_mlp_bridge(cfg: Any):
    """Build the MLP bridge for a block, gated or plain depending on ``cfg``.

    Args:
        cfg: Model config; ``cfg.gated_mlp`` selects the variant.

    Returns:
        A ``GatedMLPBridge`` with ``gate`` / ``in`` / ``out`` linear
        submodules when ``cfg.gated_mlp`` is truthy, otherwise an
        ``MLPBridge`` whose ``in`` / ``out`` submodules are named ``fc_in`` /
        ``fc_out``.
    """
    if not cfg.gated_mlp:
        plain_parts = {
            "in": LinearBridge(name="fc_in"),
            "out": LinearBridge(name="fc_out"),
        }
        return MLPBridge(name="mlp", submodules=plain_parts)

    # In the gated variant each slot key doubles as the submodule name.
    gated_parts = {slot: LinearBridge(name=slot) for slot in ("gate", "in", "out")}
    return GatedMLPBridge(name="mlp", config=cfg, submodules=gated_parts)

65 

66 

def _make_block_submodules(cfg: Any) -> dict:
    """Assemble the per-block submodule bridges for ``cfg``.

    Args:
        cfg: Model config; ``cfg.attn_only`` controls whether the MLP half of
            the block is included.

    Returns:
        A dict with ``ln1`` and ``attn`` entries, plus ``ln2`` and ``mlp``
        unless ``cfg.attn_only`` is truthy.
    """
    first_norm = _make_norm_bridge("ln1", cfg)
    # Q/K/V/O projections: slot key and submodule name are the same string.
    projections = {proj: LinearBridge(name=proj) for proj in ("q", "k", "v", "o")}
    attention = AttentionBridge(name="attn", config=cfg, submodules=projections)

    submods: dict = {"ln1": first_norm, "attn": attention}
    if cfg.attn_only:
        return submods

    submods["ln2"] = _make_norm_bridge("ln2", cfg)
    submods["mlp"] = _make_mlp_bridge(cfg)
    return submods

85 

86 

class NativeArchitectureAdapter(ArchitectureAdapter):
    """Adapter for ``NativeModel`` — TL-native, split-QKV, pre-LN.

    The feature set is driven by cfg (gated MLP, RMS norm, rotary, GQA,
    soft-cap, attn_only).
    """

    def __init__(self, cfg: Any) -> None:
        """Configure weight-processing flags and build ``component_mapping``.

        Args:
            cfg: Model config forwarded to the base adapter and used to pick
                which bridge components are mapped.
        """
        super().__init__(cfg)

        # The native layout already stores Q/K/V split, so no rearranges are
        # needed. Compatibility-mode fold_ln / center_writing_weights aren't
        # wired up, so the corresponding ProcessWeights paths are gated off:
        # folding without the state-dict conversions would mis-place or drop
        # weights.
        self.supports_fold_ln = False
        self.supports_center_writing_weights = False
        self.weight_processing_conversions = {}

        # The underlying attribute names (tok_embed, pos, layers, ln_out,
        # head) deliberately differ from the bridge slot names ("embed",
        # "blocks", "ln_final", "unembed"): the bridge's __getattr__ forwards
        # to original_model and would otherwise shadow add_module.
        components: dict = {"embed": EmbeddingBridge(name="tok_embed")}

        # Rotary models have no learned positional embedding to map.
        if not _is_rotary(cfg):
            components["pos_embed"] = PosEmbedBridge(name="pos")

        blocks = BlockBridge(
            name="layers",
            config=self.cfg,
            submodules=_make_block_submodules(self.cfg),
        )
        if self.cfg.attn_only:
            # With attn_only there is no ln2 / mlp to point at; drop the
            # aliases that would otherwise warn during _register_aliases.
            # Copy first when the instance still shares the class-level dict,
            # so BlockBridge.hook_aliases itself is left untouched.
            if blocks.hook_aliases is BlockBridge.hook_aliases:
                blocks.hook_aliases = dict(blocks.hook_aliases)
            for alias in ("hook_resid_mid", "hook_mlp_out"):
                blocks.hook_aliases.pop(alias, None)
        components["blocks"] = blocks

        # final_rms forces RMS on the final norm independent of the block
        # norm, matching Llama's TL config semantic.
        force_final_rms = bool(getattr(self.cfg, "final_rms", False))
        components["ln_final"] = _make_norm_bridge(
            "ln_out", self.cfg, force_rms=force_final_rms
        )
        components["unembed"] = UnembeddingBridge(name="head")

        self.component_mapping = components

    def prepare_model(self, model: Any) -> None:
        """Reject modules whose attribute names collide with bridge slots.

        Bridge's ``__getattr__`` falls back to ``getattr(original_model, name)``
        for unknown attrs, so a name match — submodule, buffer, plain tensor,
        or property — makes ``add_module`` raise mid-setup with an opaque
        message. Failing here points at the real cause. The reserved set is
        derived from ``component_mapping.keys()`` so adapter variants stay in
        sync.

        Args:
            model: The module about to be wrapped.

        Raises:
            ValueError: If ``model`` exposes any attribute named like a
                component slot.
        """
        slots = self.component_mapping
        reserved = set(slots.keys()) if slots else set()
        collisions = sorted([slot for slot in reserved if hasattr(model, slot)])
        if not collisions:
            return

        raise ValueError(
            f"{type(model).__name__} cannot be wrapped by NativeArchitectureAdapter: "
            f"attribute names {collisions} collide with bridge component slots "
            f"({sorted(reserved)}). Rename these attributes to non-colliding names "
            f"(e.g. tok_embed, layers, ln_out, head) and update the adapter's "
            f"component_mapping ``name=`` fields to match."
        )