Coverage for transformer_lens/utilities/architectures.py: 75%

49 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

"""Centralized architecture classification for TransformerLens.

Single source of truth for architecture type detection. Used by the bridge
loading pipeline, benchmarks, and verification tools.
"""

from typing import Optional

# Encoder-decoder models (T5, BART, etc.)
SEQ2SEQ_ARCHITECTURES: set[str] = {
    "BartForConditionalGeneration",
    "BlenderbotForConditionalGeneration",
    "BlenderbotSmallForConditionalGeneration",
    "MBartForConditionalGeneration",
    "MT5ForConditionalGeneration",
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "T5ForConditionalGeneration",
}

# Masked language models (BERT-style, no text generation)
MASKED_LM_ARCHITECTURES: set[str] = {
    "AlbertForMaskedLM",
    "BertForMaskedLM",
    "DistilBertForMaskedLM",
    "ElectraForMaskedLM",
    "RobertaForMaskedLM",
}

# Vision-language multimodal models
MULTIMODAL_ARCHITECTURES: set[str] = {
    "Gemma3ForConditionalGeneration",
    "LlavaForConditionalGeneration",
    "LlavaNextForConditionalGeneration",
    "LlavaOnevisionForConditionalGeneration",
}

# Audio encoder models (HuBERT, wav2vec2, etc.)
AUDIO_ARCHITECTURES: set[str] = {
    "HubertForCTC",
    "HubertForSequenceClassification",
    "HubertModel",
}

# Bridge uses different hook shapes than HookedTransformer by design.
# Phase 2/3 HT comparisons are skipped; Phase 1 (HF comparison) is the gold standard.
NO_HT_COMPARISON_ARCHITECTURES: set[str] = MULTIMODAL_ARCHITECTURES.union(
    AUDIO_ARCHITECTURES,
    {"Gemma3ForCausalLM"},
)

54 

55 

def classify_architecture(architecture: str) -> str:
    """Classify an architecture string into a model type.

    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    # Walk the known families in priority order; the first set containing
    # the architecture determines the label. Anything unrecognized is
    # treated as a plain causal language model.
    families = (
        (SEQ2SEQ_ARCHITECTURES, "seq2seq"),
        (MASKED_LM_ARCHITECTURES, "masked_lm"),
        (MULTIMODAL_ARCHITECTURES, "multimodal"),
        (AUDIO_ARCHITECTURES, "audio"),
    )
    for family, label in families:
        if architecture in family:
            return label
    return "causal_lm"

70 

71 

def get_architectures_for_config(config) -> list[str]:
    """Extract architecture strings from an HF config object."""
    # ``original_architecture`` (if present) is listed first so it takes
    # priority over the HF-provided ``architectures`` list downstream.
    _missing = object()
    found: list[str] = []
    original = getattr(config, "original_architecture", _missing)
    if original is not _missing:
        found.append(original)
    listed = getattr(config, "architectures", None)
    if listed:
        found.extend(listed)
    return found

80 

81 

def classify_model_config(config) -> str:
    """Classify a model by its HF config.

    Checks config.is_encoder_decoder first, then falls back to architecture list.
    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    # The encoder-decoder flag on the config is authoritative when set.
    if getattr(config, "is_encoder_decoder", False):
        return "seq2seq"
    # Otherwise the first architecture that classifies to a non-default
    # category wins; "causal_lm" is the fallback.
    return next(
        (
            kind
            for kind in map(classify_architecture, get_architectures_for_config(config))
            if kind != "causal_lm"
        ),
        "causal_lm",
    )

95 

96 

def classify_model_name(
    model_name: str,
    trust_remote_code: bool = False,
    token: Optional[str] = None,
) -> str:
    """Classify a model by its HuggingFace model name.

    Loads the config once, classifies from it. If token is None, reads
    HF_TOKEN from the environment automatically.
    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    try:
        from transformers import AutoConfig

        resolved_token = token
        if resolved_token is None:
            from transformer_lens.utilities.hf_utils import get_hf_token

            resolved_token = get_hf_token()

        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=resolved_token
        )
        return classify_model_config(hf_config)
    except Exception:
        # Best-effort by design: any failure to import transformers, resolve
        # a token, or fetch the config falls back to the default category.
        return "causal_lm"

122 

123 

def is_masked_lm_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is a masked language model (BERT-style)."""
    kind = classify_model_name(model_name, trust_remote_code=trust_remote_code, token=token)
    return kind == "masked_lm"

132 

133 

def is_encoder_decoder_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is an encoder-decoder architecture (T5, BART, etc.)."""
    kind = classify_model_name(model_name, trust_remote_code=trust_remote_code, token=token)
    return kind == "seq2seq"

142 

143 

def is_multimodal_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is a multimodal vision-language model (LLaVA, Gemma3)."""
    kind = classify_model_name(model_name, trust_remote_code=trust_remote_code, token=token)
    return kind == "multimodal"

152 

153 

def is_audio_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is an audio encoder model (HuBERT, wav2vec2)."""
    kind = classify_model_name(model_name, trust_remote_code=trust_remote_code, token=token)
    return kind == "audio"