Coverage for transformer_lens/utilities/architectures.py: 32%

49 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Centralized architecture classification for TransformerLens. 

2 

3Single source of truth for architecture type detection. Used by the bridge 

4loading pipeline, benchmarks, and verification tools. 

5""" 

6 

7from typing import Optional 

8 

# Encoder-decoder (seq2seq) architectures: T5, BART, and related families.
SEQ2SEQ_ARCHITECTURES: set[str] = {
    # T5 family
    "T5ForConditionalGeneration",
    "MT5ForConditionalGeneration",
    "T5GemmaForConditionalGeneration",
    # BART family
    "BartForConditionalGeneration",
    "MBartForConditionalGeneration",
    # Other encoder-decoder models
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "BlenderbotForConditionalGeneration",
    "BlenderbotSmallForConditionalGeneration",
}

# Masked language models (BERT-style; no text generation).
MASKED_LM_ARCHITECTURES: set[str] = {
    "BertForMaskedLM",
    "RobertaForMaskedLM",
    "AlbertForMaskedLM",
    "DistilBertForMaskedLM",
    "ElectraForMaskedLM",
}

# Vision-language (multimodal) architectures.
MULTIMODAL_ARCHITECTURES: set[str] = {
    # LLaVA family
    "LlavaForConditionalGeneration",
    "LlavaNextForConditionalGeneration",
    "LlavaOnevisionForConditionalGeneration",
    # Gemma
    "Gemma3ForConditionalGeneration",
    "Gemma4ForConditionalGeneration",
    # Qwen
    "Qwen3_5ForConditionalGeneration",
}

# Audio encoder architectures. Only HuBERT variants are listed today; any audio
# architecture not in this set (e.g. wav2vec2) falls through to "causal_lm" in
# classify_architecture.
AUDIO_ARCHITECTURES: set[str] = {
    "HubertForCTC",
    "HubertModel",
    "HubertForSequenceClassification",
}

# Architectures for which the bridge deliberately uses different hook shapes than
# HookedTransformer. Phase 2/3 HookedTransformer comparisons are skipped for them;
# Phase 1 (comparison against HuggingFace) is the gold standard.
# Contents: every multimodal and audio architecture, plus Gemma3ForCausalLM.
NO_HT_COMPARISON_ARCHITECTURES: set[str] = {
    *MULTIMODAL_ARCHITECTURES,
    *AUDIO_ARCHITECTURES,
    "Gemma3ForCausalLM",
}

57 

58 

def classify_architecture(architecture: str) -> str:
    """Map one HuggingFace architecture class name to a model-type label.

    The module-level architecture sets are consulted in a fixed order
    (seq2seq, masked LM, multimodal, audio); the first set that contains
    ``architecture`` decides the label.

    Args:
        architecture: An architecture class name, e.g. ``"BertForMaskedLM"``.

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm".
        "causal_lm" is the default when no set matches.
    """
    lookup_order = (
        (SEQ2SEQ_ARCHITECTURES, "seq2seq"),
        (MASKED_LM_ARCHITECTURES, "masked_lm"),
        (MULTIMODAL_ARCHITECTURES, "multimodal"),
        (AUDIO_ARCHITECTURES, "audio"),
    )
    for known_names, label in lookup_order:
        if architecture in known_names:
            return label
    return "causal_lm"

73 

74 

def get_architectures_for_config(config) -> list[str]:
    """Collect the architecture class names recorded on an HF config object.

    ``config.original_architecture`` (when set) comes first, followed by the
    entries of ``config.architectures`` in their original order. Missing or
    empty values (``None``, ``""``) are skipped, so the result only contains
    real architecture names.

    Args:
        config: A HuggingFace-style config object. Attributes are read with
            ``getattr``, so objects lacking either attribute are fine.

    Returns:
        A new list of architecture name strings (possibly empty).
    """
    architectures: list[str] = []
    # The attribute may exist but be None/empty; appending it would put a
    # non-name into a list[str] result.
    original = getattr(config, "original_architecture", None)
    if original:
        architectures.append(original)
    declared = getattr(config, "architectures", None)
    if declared:
        architectures.extend(arch for arch in declared if arch)
    return architectures

83 

84 

def classify_model_config(config) -> str:
    """Classify a model from its HuggingFace config object.

    A truthy ``config.is_encoder_decoder`` short-circuits to "seq2seq".
    Otherwise the names returned by ``get_architectures_for_config`` are
    classified one at a time, in order. The first label other than
    "causal_lm" is returned.

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm".
    """
    if getattr(config, "is_encoder_decoder", False):
        return "seq2seq"
    # Lazily classify so we stop at the first specific (non-causal) match.
    labels = (
        classify_architecture(name) for name in get_architectures_for_config(config)
    )
    return next((label for label in labels if label != "causal_lm"), "causal_lm")

98 

99 

def classify_model_name(
    model_name: str,
    trust_remote_code: bool = False,
    token: Optional[str] = None,
) -> str:
    """Classify a model by its HuggingFace model name.

    Loads the config once with ``AutoConfig.from_pretrained`` and classifies it
    with ``classify_model_config``. If ``token`` is None, it is obtained from
    ``get_hf_token()`` (reads HF_TOKEN from the environment).

    This is best-effort. If anything fails (transformers unavailable, the repo
    cannot be fetched, auth errors, ...), a warning naming the error is logged
    and "causal_lm" is returned instead of raising.

    Args:
        model_name: HuggingFace model id or local path.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        token: HuggingFace access token; None means "look it up".

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm".
    """
    import logging

    try:
        from transformers import AutoConfig

        if token is None:
            from transformer_lens.utilities.hf_utils import get_hf_token

            token = get_hf_token()

        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token
        )
        return classify_model_config(config)
    except Exception as err:  # deliberate best-effort fallback; but don't hide it
        logging.getLogger(__name__).warning(
            "Could not classify model %r (%s: %s); defaulting to 'causal_lm'.",
            model_name,
            type(err).__name__,
            err,
        )
        return "causal_lm"

125 

126 

def is_masked_lm_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` classifies as a masked LM (BERT-style).

    Thin wrapper over ``classify_model_name``. Each call loads the config
    again, and a failed lookup classifies as "causal_lm", so it yields False.
    """
    detected = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected == "masked_lm"

135 

136 

def is_encoder_decoder_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` classifies as encoder-decoder (T5, BART, ...).

    Thin wrapper over ``classify_model_name``. Each call loads the config
    again, and a failed lookup classifies as "causal_lm", so it yields False.
    """
    detected = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected == "seq2seq"

145 

146 

def is_multimodal_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` classifies as a vision-language model.

    Examples include the LLaVA and Gemma3 conditional-generation architectures.
    Thin wrapper over ``classify_model_name``. Each call loads the config
    again, and a failed lookup classifies as "causal_lm", so it yields False.
    """
    detected = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected == "multimodal"

155 

156 

def is_audio_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``model_name`` classifies as an audio encoder model.

    Only architectures listed in ``AUDIO_ARCHITECTURES`` (currently HuBERT
    variants) count as audio. Thin wrapper over ``classify_model_name``: each
    call loads the config again, and a failed lookup yields False.
    """
    detected = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected == "audio"