Coverage for transformer_lens/utilities/architectures.py: 32%

49 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Centralized architecture classification for TransformerLens. 

2 

3Single source of truth for architecture type detection. Used by the bridge 

4loading pipeline, benchmarks, and verification tools. 

5""" 

6 

7from typing import Optional 

8 

# Architecture tables used by classify_architecture(). The four "type" sets
# (seq2seq, masked LM, multimodal, audio) are expected to be disjoint:
# classify_architecture() checks them in that fixed order and the first
# match wins. Anything not listed in any of them is treated as "causal_lm".

# Encoder-decoder models (T5, BART, Marian, Pegasus, Blenderbot, ...)
SEQ2SEQ_ARCHITECTURES: set[str] = {
    "T5ForConditionalGeneration",
    "MT5ForConditionalGeneration",
    "BartForConditionalGeneration",
    "MBartForConditionalGeneration",
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "BlenderbotForConditionalGeneration",
    "BlenderbotSmallForConditionalGeneration",
}

# Masked language models (BERT-style encoders, no text generation)
MASKED_LM_ARCHITECTURES: set[str] = {
    "BertForMaskedLM",
    "RobertaForMaskedLM",
    "AlbertForMaskedLM",
    "DistilBertForMaskedLM",
    "ElectraForMaskedLM",
}

# Vision-language multimodal models (LLaVA family, Gemma3, Qwen3_5)
MULTIMODAL_ARCHITECTURES: set[str] = {
    "LlavaForConditionalGeneration",
    "LlavaNextForConditionalGeneration",
    "LlavaOnevisionForConditionalGeneration",
    "Gemma3ForConditionalGeneration",
    "Qwen3_5ForConditionalGeneration",
}

# Audio encoder models. Only HuBERT variants are registered today; other speech
# encoders (e.g. wav2vec2) are NOT listed and therefore classify as "causal_lm"
# until their architecture names are added here.
AUDIO_ARCHITECTURES: set[str] = {
    "HubertForCTC",
    "HubertModel",
    "HubertForSequenceClassification",
}

# The bridge intentionally uses different hook shapes than HookedTransformer for
# these architectures. Phase 2/3 HT comparisons are skipped for them; Phase 1
# (comparison against the HF model) is the gold standard.
NO_HT_COMPARISON_ARCHITECTURES: set[str] = (
    MULTIMODAL_ARCHITECTURES
    | AUDIO_ARCHITECTURES
    | {
        "Gemma3ForCausalLM",
    }
)

55 

56 

def classify_architecture(architecture: str) -> str:
    """Map a single HF architecture class name to a coarse model type.

    The known-architecture tables are consulted in a fixed order (seq2seq,
    masked LM, multimodal, audio); the first table that contains
    ``architecture`` decides the result. Names found in none of them are
    treated as decoder-only language models.

    Args:
        architecture: An architecture class name, e.g. ``"BertForMaskedLM"``.

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm".
    """
    lookup_order = (
        (SEQ2SEQ_ARCHITECTURES, "seq2seq"),
        (MASKED_LM_ARCHITECTURES, "masked_lm"),
        (MULTIMODAL_ARCHITECTURES, "multimodal"),
        (AUDIO_ARCHITECTURES, "audio"),
    )
    for known_names, label in lookup_order:
        if architecture in known_names:
            return label
    # Fallback: anything unrecognised is assumed to be a causal LM.
    return "causal_lm"

71 

72 

def get_architectures_for_config(config) -> list[str]:
    """Collect architecture class names from an HF config object.

    ``config.original_architecture`` comes first (when present and non-empty),
    followed by every entry of ``config.architectures`` (when present and
    non-empty). Attributes that are missing or set to ``None``/empty are
    skipped, so no placeholder values leak into the result.

    Args:
        config: An HF-style config object, e.g. the result of
            ``AutoConfig.from_pretrained``. Only attribute access is used.

    Returns:
        A new list of architecture names, possibly empty.
    """
    architectures: list[str] = []
    # The attribute can exist but be unset (None); appending it would put a
    # non-string into a list[str].
    original = getattr(config, "original_architecture", None)
    if original:
        architectures.append(original)
    declared = getattr(config, "architectures", None)
    if declared:
        architectures.extend(declared)
    return architectures

81 

82 

def classify_model_config(config) -> str:
    """Determine the coarse model type for an already-loaded HF config.

    A truthy ``config.is_encoder_decoder`` short-circuits to "seq2seq".
    Otherwise every architecture name returned by
    ``get_architectures_for_config`` is classified in order, and the first
    result other than "causal_lm" wins.

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
        ("causal_lm" when nothing more specific matches).
    """
    if getattr(config, "is_encoder_decoder", False):
        return "seq2seq"
    # Lazily classify each name so evaluation stops at the first specific match.
    candidates = (
        classify_architecture(arch_name)
        for arch_name in get_architectures_for_config(config)
    )
    return next(
        (kind for kind in candidates if kind != "causal_lm"),
        "causal_lm",
    )

96 

97 

def classify_model_name(
    model_name: str,
    trust_remote_code: bool = False,
    token: Optional[str] = None,
) -> str:
    """Classify a model by its HuggingFace model name.

    Loads the config once with ``AutoConfig.from_pretrained`` and classifies
    it with ``classify_model_config``. If ``token`` is None, the token is
    obtained from ``get_hf_token()`` (HF_TOKEN from the environment).

    This is best-effort: if ``transformers`` cannot be imported or the config
    cannot be loaded for any reason, a warning is logged and "causal_lm" is
    returned. Errors raised while classifying an already-loaded config are
    not swallowed.

    Args:
        model_name: Model identifier as accepted by ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        token: HF access token; ``None`` means read it from the environment.

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm".
    """
    try:
        from transformers import AutoConfig

        if token is None:
            from transformer_lens.utilities.hf_utils import get_hf_token

            token = get_hf_token()

        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token
        )
    except Exception as exc:
        # Deliberate best-effort fallback (missing dependency, unknown repo,
        # no network, gated model, ...), but make the failure visible.
        import logging

        logging.getLogger(__name__).warning(
            "Could not load config for %r (%s: %s); assuming 'causal_lm'.",
            model_name,
            type(exc).__name__,
            exc,
        )
        return "causal_lm"
    # Classification runs outside the try so that a bug here surfaces instead
    # of being silently reported as "causal_lm".
    return classify_model_config(config)

123 

124 

def is_masked_lm_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` is classified as a masked LM (BERT-style).

    Thin wrapper over ``classify_model_name``; a config that cannot be loaded
    is classified there as "causal_lm", so this returns False in that case.
    """
    model_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_type == "masked_lm"

133 

134 

def is_encoder_decoder_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` is an encoder-decoder model (T5, BART, ...).

    Thin wrapper over ``classify_model_name``; True when that reports "seq2seq".
    """
    model_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_type == "seq2seq"

143 

144 

def is_multimodal_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` is a vision-language multimodal model.

    Covers the architectures registered as multimodal (LLaVA family, Gemma3,
    Qwen3_5). Thin wrapper over ``classify_model_name``; True when that
    reports "multimodal".
    """
    model_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_type == "multimodal"

153 

154 

def is_audio_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` is a registered audio encoder model.

    Only HuBERT architectures are currently registered as audio; other speech
    encoders (e.g. wav2vec2) are classified as "causal_lm" and return False
    here until they are added to the audio architecture table.
    Thin wrapper over ``classify_model_name``; True when that reports "audio".
    """
    model_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_type == "audio"