Coverage for transformer_lens/factories/architecture_adapter_factory.py: 94%

13 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Architecture adapter factory. 

2 

3This module provides a factory for creating architecture adapters. 

4""" 

5 

6from transformer_lens.config import TransformerBridgeConfig 

7from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

8from transformer_lens.model_bridge.supported_architectures import ( 

9 ApertusArchitectureAdapter, 

10 BaichuanArchitectureAdapter, 

11 BertArchitectureAdapter, 

12 BloomArchitectureAdapter, 

13 CodeGenArchitectureAdapter, 

14 CohereArchitectureAdapter, 

15 DeepSeekV3ArchitectureAdapter, 

16 FalconArchitectureAdapter, 

17 Gemma1ArchitectureAdapter, 

18 Gemma2ArchitectureAdapter, 

19 Gemma3ArchitectureAdapter, 

20 Gemma3MultimodalArchitectureAdapter, 

21 GPT2ArchitectureAdapter, 

22 Gpt2LmHeadCustomArchitectureAdapter, 

23 GPTBigCodeArchitectureAdapter, 

24 GptjArchitectureAdapter, 

25 GPTOSSArchitectureAdapter, 

26 GraniteArchitectureAdapter, 

27 GraniteMoeArchitectureAdapter, 

28 GraniteMoeHybridArchitectureAdapter, 

29 HubertArchitectureAdapter, 

30 InternLM2ArchitectureAdapter, 

31 LlamaArchitectureAdapter, 

32 LlavaArchitectureAdapter, 

33 LlavaNextArchitectureAdapter, 

34 LlavaOnevisionArchitectureAdapter, 

35 Mamba2ArchitectureAdapter, 

36 MambaArchitectureAdapter, 

37 MingptArchitectureAdapter, 

38 MistralArchitectureAdapter, 

39 MixtralArchitectureAdapter, 

40 MPTArchitectureAdapter, 

41 NanogptArchitectureAdapter, 

42 NeelSoluOldArchitectureAdapter, 

43 NeoArchitectureAdapter, 

44 NeoxArchitectureAdapter, 

45 Olmo2ArchitectureAdapter, 

46 Olmo3ArchitectureAdapter, 

47 OlmoArchitectureAdapter, 

48 OlmoeArchitectureAdapter, 

49 OpenElmArchitectureAdapter, 

50 OptArchitectureAdapter, 

51 Phi3ArchitectureAdapter, 

52 PhiArchitectureAdapter, 

53 Qwen2ArchitectureAdapter, 

54 Qwen3_5ArchitectureAdapter, 

55 Qwen3ArchitectureAdapter, 

56 Qwen3MoeArchitectureAdapter, 

57 Qwen3NextArchitectureAdapter, 

58 QwenArchitectureAdapter, 

59 StableLmArchitectureAdapter, 

60 T5ArchitectureAdapter, 

61 XGLMArchitectureAdapter, 

62) 

63 

# Export supported architectures
#
# Registry mapping an architecture-name string to the adapter class that
# handles it. ArchitectureAdapterFactory uses this mapping as its lookup table,
# keyed by ``cfg.architecture``.
# NOTE(review): the keys look like Hugging Face model class names -- confirm
# against whatever populates ``TransformerBridgeConfig.architecture``.
# Several keys intentionally share one adapter class (aliases); see the
# comments on the individual groups below.
SUPPORTED_ARCHITECTURES = {
    "ApertusForCausalLM": ApertusArchitectureAdapter,
    # Both capitalizations of the Baichuan name resolve to the same adapter.
    "BaiChuanForCausalLM": BaichuanArchitectureAdapter,
    "BaichuanForCausalLM": BaichuanArchitectureAdapter,
    "BertForMaskedLM": BertArchitectureAdapter,
    "BloomForCausalLM": BloomArchitectureAdapter,
    "CodeGenForCausalLM": CodeGenArchitectureAdapter,
    "CohereForCausalLM": CohereArchitectureAdapter,
    "DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
    "FalconForCausalLM": FalconArchitectureAdapter,
    "GemmaForCausalLM": Gemma1ArchitectureAdapter,  # Default to Gemma1 as it's the original version
    "Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
    "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
    "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
    "GraniteForCausalLM": GraniteArchitectureAdapter,
    "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
    "GPT2LMHeadModel": GPT2ArchitectureAdapter,
    "GPTBigCodeForCausalLM": GPTBigCodeArchitectureAdapter,
    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
    "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
    "GPTJForCausalLM": GptjArchitectureAdapter,
    # Both Hubert entries share a single adapter class.
    "HubertForCTC": HubertArchitectureAdapter,
    "HubertModel": HubertArchitectureAdapter,
    "InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
    "LlamaForCausalLM": LlamaArchitectureAdapter,
    "LlavaForConditionalGeneration": LlavaArchitectureAdapter,
    "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
    "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
    "Mamba2ForCausalLM": Mamba2ArchitectureAdapter,
    "MambaForCausalLM": MambaArchitectureAdapter,
    "MixtralForCausalLM": MixtralArchitectureAdapter,
    "MistralForCausalLM": MistralArchitectureAdapter,
    "MPTForCausalLM": MPTArchitectureAdapter,
    # "NeoForCausalLM" / "NeoXForCausalLM" map to the same adapters as the
    # "GPTNeoForCausalLM" / "GPTNeoXForCausalLM" keys at the end of the table.
    "NeoForCausalLM": NeoArchitectureAdapter,
    "NeoXForCausalLM": NeoxArchitectureAdapter,
    "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter,
    "OlmoForCausalLM": OlmoArchitectureAdapter,
    "Olmo2ForCausalLM": Olmo2ArchitectureAdapter,
    "Olmo3ForCausalLM": Olmo3ArchitectureAdapter,
    "OlmoeForCausalLM": OlmoeArchitectureAdapter,
    "OpenELMForCausalLM": OpenElmArchitectureAdapter,
    "OPTForCausalLM": OptArchitectureAdapter,
    "PhiForCausalLM": PhiArchitectureAdapter,
    "Phi3ForCausalLM": Phi3ArchitectureAdapter,
    "QwenForCausalLM": QwenArchitectureAdapter,
    "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
    "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
    "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
    "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
    "Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
    "StableLmForCausalLM": StableLmArchitectureAdapter,
    "T5ForConditionalGeneration": T5ArchitectureAdapter,
    "XGLMForCausalLM": XGLMArchitectureAdapter,
    "NanoGPTForCausalLM": NanogptArchitectureAdapter,
    "MinGPTForCausalLM": MingptArchitectureAdapter,
    # Alternate names for the Neo / NeoX adapters registered above.
    "GPTNeoForCausalLM": NeoArchitectureAdapter,
    "GPTNeoXForCausalLM": NeoxArchitectureAdapter,
}

125 

126 

class ArchitectureAdapterFactory:
    """Build the architecture adapter that matches a bridge configuration.

    The class-level ``_adapters`` registry is the module's
    ``SUPPORTED_ARCHITECTURES`` mapping (architecture name -> adapter class).
    """

    _adapters = SUPPORTED_ARCHITECTURES

    @classmethod
    def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
        """Instantiate the adapter registered for ``cfg.architecture``.

        Args:
            cfg: Configuration whose ``architecture`` attribute names the
                adapter to build. The same object is handed to the adapter's
                constructor.

        Returns:
            A new instance of the registered adapter class, built from ``cfg``.

        Raises:
            ValueError: If ``cfg.architecture`` is ``None`` or is not a key of
                the adapter registry.
        """
        architecture = cfg.architecture

        # Guard: a config without an architecture cannot be dispatched at all.
        if architecture is None:
            raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")

        # Guard: reject names that have no registered adapter.
        if architecture not in cls._adapters:
            raise ValueError(f"Unsupported architecture: {architecture}")

        # Look up first, construct second, so the two steps stay distinct.
        adapter_cls = cls._adapters[architecture]
        return adapter_cls(cfg)