Coverage for transformer_lens/model_bridge/supported_architectures/native.py: 98%

45 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Architecture adapter for TL-native models built via ``boot_native``. 

2 

3Component mapping adapts to cfg: gated MLP → ``GatedMLPBridge``, RMS norm → 

4``RMSNormalizationBridge``, rotary drops ``pos_embed``, ``attn_only`` drops ``ln2`` and the MLP. 

5""" 

6from typing import Any 

7 

8from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

9from transformer_lens.model_bridge.generalized_components import ( 

10 AttentionBridge, 

11 BlockBridge, 

12 EmbeddingBridge, 

13 GatedMLPBridge, 

14 LinearBridge, 

15 MLPBridge, 

16 NormalizationBridge, 

17 PosEmbedBridge, 

18 RMSNormalizationBridge, 

19 UnembeddingBridge, 

20) 

21 

22 

23def _uses_rms(cfg: Any) -> bool: 

24 return (getattr(cfg, "normalization_type", None) or "LN").upper() in ("RMS", "RMSPRE") 

25 

26 

27def _is_rotary(cfg: Any) -> bool: 

28 return (getattr(cfg, "positional_embedding_type", None) or "standard").lower() == "rotary" 

29 

30 

def _make_norm_bridge(name: str, cfg: Any, *, force_rms: bool = False):
    """Build the normalization bridge wrapping the native module ``name``.

    Args:
        name: Attribute name of the norm module on the native model.
        cfg: Model config, passed to the bridge as ``config``.
        force_rms: If true, use ``RMSNormalizationBridge`` whatever
            ``cfg.normalization_type`` says.

    Returns:
        An ``RMSNormalizationBridge`` when ``force_rms`` is set or ``cfg``
        selects RMS normalization, otherwise a ``NormalizationBridge``.
    """
    # ``_uses_rms`` is only consulted when force_rms is falsy (short-circuit).
    bridge_cls = RMSNormalizationBridge if (force_rms or _uses_rms(cfg)) else NormalizationBridge
    return bridge_cls(name=name, config=cfg)

35 

36 

def _make_mlp_bridge(cfg: Any):
    """Build the MLP bridge for one native block.

    When ``cfg.gated_mlp`` is truthy, returns a ``GatedMLPBridge`` (given
    ``config=cfg``) wrapping linear submodules ``gate``, ``in`` and ``out``.
    Otherwise returns a plain ``MLPBridge`` whose ``in``/``out`` slots wrap the
    native modules ``fc_in``/``fc_out``.
    """
    if not cfg.gated_mlp:
        # NOTE(review): unlike the gated branch, no ``config`` is passed here;
        # presumably MLPBridge does not need it — confirm.
        return MLPBridge(
            name="mlp",
            submodules={
                "in": LinearBridge(name="fc_in"),
                "out": LinearBridge(name="fc_out"),
            },
        )

    # In the gated layout each slot name matches the native module name.
    gated_submodules = {slot: LinearBridge(name=slot) for slot in ("gate", "in", "out")}
    return GatedMLPBridge(name="mlp", config=cfg, submodules=gated_submodules)

55 

56 

def _make_block_submodules(cfg: Any) -> dict:
    """Assemble the submodule bridges for one native transformer block.

    Always contains ``ln1`` and ``attn``. The attention bridge wraps separate
    ``q``/``k``/``v``/``o`` linear projections. ``ln2`` and ``mlp`` are
    added only when ``cfg.attn_only`` is falsy.
    """
    first_norm = _make_norm_bridge("ln1", cfg)
    attention = AttentionBridge(
        name="attn",
        config=cfg,
        submodules={proj: LinearBridge(name=proj) for proj in ("q", "k", "v", "o")},
    )
    block_parts: dict = {"ln1": first_norm, "attn": attention}
    if cfg.attn_only:
        return block_parts

    block_parts["ln2"] = _make_norm_bridge("ln2", cfg)
    block_parts["mlp"] = _make_mlp_bridge(cfg)
    return block_parts

75 

76 

class NativeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for ``NativeModel``, TL's native transformer layout.

    The native model stores Q/K/V as separate projections and is pre-LN. The
    bridges in the component mapping depend on cfg:
    - gated vs. plain MLP;
    - RMS vs. standard norm;
    - rotary, which means no ``pos_embed`` slot;
    - ``attn_only``, which means no ``ln2``/``mlp``.

    Other attention options such as GQA or soft-capping are not branched on
    here. Presumably ``AttentionBridge`` reads them from cfg — confirm.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        # Q/K/V are already stored split, so no rearrange conversions are
        # registered. The compatibility-mode fold_ln and center_writing_weights
        # paths need state-dict conversions this adapter does not provide.
        # Running them anyway could misplace or drop weights, so both are off.
        self.supports_fold_ln = False
        self.supports_center_writing_weights = False
        self.weight_processing_conversions = {}

        # Each ``name=`` is the native model's own attribute name. These names
        # deliberately differ from the bridge slot names ("embed", "blocks",
        # "ln_final", "unembed"). The bridge's __getattr__ falls back to
        # original_model, so a shared name would shadow add_module.
        component_mapping: dict = {"embed": EmbeddingBridge(name="tok_embed")}
        if not _is_rotary(cfg):
            # Rotary configs get no positional-embedding slot.
            component_mapping["pos_embed"] = PosEmbedBridge(name="pos")

        layer_submodules = _make_block_submodules(self.cfg)
        blocks = BlockBridge(name="layers", config=self.cfg, submodules=layer_submodules)
        if self.cfg.attn_only:
            # attn_only blocks have no ln2 / mlp. These aliases would point at
            # nothing and warn during _register_aliases, so drop them.
            if blocks.hook_aliases is BlockBridge.hook_aliases:
                # The instance still shares the class-level dict. Copy it before
                # editing so other BlockBridge instances keep their aliases.
                blocks.hook_aliases = dict(blocks.hook_aliases)
            for dead_alias in ("hook_resid_mid", "hook_mlp_out"):
                blocks.hook_aliases.pop(dead_alias, None)
        component_mapping["blocks"] = blocks

        # ``final_rms`` forces RMS on the final norm only, independent of the
        # block norm type. This matches Llama's TL config semantics.
        use_rms_final = bool(getattr(self.cfg, "final_rms", False))
        component_mapping["ln_final"] = _make_norm_bridge(
            "ln_out", self.cfg, force_rms=use_rms_final
        )
        component_mapping["unembed"] = UnembeddingBridge(name="head")
        self.component_mapping = component_mapping

    def prepare_model(self, model: Any) -> None:
        """Fail fast when ``model`` has an attribute named like a bridge slot.

        The bridge's ``__getattr__`` falls back to
        ``getattr(original_model, name)`` for unknown attributes. Any attribute
        of ``model`` that shares a slot name (submodule, buffer, plain tensor
        or property) therefore makes ``add_module`` raise part-way through
        setup with an opaque message; checking here names the real cause. The
        reserved names come from ``component_mapping``, so adapter variants
        stay in sync automatically.

        Raises:
            ValueError: If any reserved slot name is also an attribute of
                ``model``.
        """
        reserved: set = set()
        if self.component_mapping:
            reserved.update(self.component_mapping.keys())

        collisions = sorted([slot for slot in reserved if hasattr(model, slot)])
        if not collisions:
            return

        raise ValueError(
            f"{type(model).__name__} cannot be wrapped by NativeArchitectureAdapter: "
            f"attribute names {collisions} collide with bridge component slots "
            f"({sorted(reserved)}). Rename these attributes to non-colliding names "
            f"(e.g. tok_embed, layers, ln_out, head) and update the adapter's "
            f"component_mapping ``name=`` fields to match."
        )