Coverage for transformer_lens/utilities/architectures.py: 75%
49 statements
« prev ^ index » next — coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Centralized architecture classification for TransformerLens.

Single source of truth for architecture type detection. Used by the bridge
loading pipeline, benchmarks, and verification tools.
"""

from typing import Optional

# Encoder-decoder models (T5, BART, etc.)
SEQ2SEQ_ARCHITECTURES: set[str] = {
    "T5ForConditionalGeneration",
    "MT5ForConditionalGeneration",
    "BartForConditionalGeneration",
    "MBartForConditionalGeneration",
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "BlenderbotForConditionalGeneration",
    "BlenderbotSmallForConditionalGeneration",
}

# Masked language models (BERT-style, no text generation)
MASKED_LM_ARCHITECTURES: set[str] = {
    "BertForMaskedLM",
    "RobertaForMaskedLM",
    "AlbertForMaskedLM",
    "DistilBertForMaskedLM",
    "ElectraForMaskedLM",
}

# Vision-language multimodal models
MULTIMODAL_ARCHITECTURES: set[str] = {
    "LlavaForConditionalGeneration",
    "LlavaNextForConditionalGeneration",
    "LlavaOnevisionForConditionalGeneration",
    "Gemma3ForConditionalGeneration",
}

# Audio encoder models (HuBERT, wav2vec2, etc.)
AUDIO_ARCHITECTURES: set[str] = {
    "HubertForCTC",
    "HubertModel",
    "HubertForSequenceClassification",
}

# Bridge uses different hook shapes than HookedTransformer by design.
# Phase 2/3 HT comparisons are skipped; Phase 1 (HF comparison) is the gold standard.
NO_HT_COMPARISON_ARCHITECTURES: set[str] = MULTIMODAL_ARCHITECTURES.union(
    AUDIO_ARCHITECTURES,
    {"Gemma3ForCausalLM"},
)


def classify_architecture(architecture: str) -> str:
    """Classify an architecture string into a model type.

    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    # Ordered dispatch table: first matching category wins.
    categories = (
        (SEQ2SEQ_ARCHITECTURES, "seq2seq"),
        (MASKED_LM_ARCHITECTURES, "masked_lm"),
        (MULTIMODAL_ARCHITECTURES, "multimodal"),
        (AUDIO_ARCHITECTURES, "audio"),
    )
    for known_architectures, label in categories:
        if architecture in known_architectures:
            return label
    # Anything unrecognized is assumed to be a decoder-only causal LM.
    return "causal_lm"
def get_architectures_for_config(config) -> list[str]:
    """Extract architecture strings from an HF config object.

    Collects ``config.original_architecture`` (if the attribute exists) followed
    by every entry of ``config.architectures`` (if present and non-empty).
    """
    missing = object()  # sentinel: distinguishes "absent" from a falsy value
    found: list[str] = []
    original = getattr(config, "original_architecture", missing)
    if original is not missing:
        found.append(original)
    listed = getattr(config, "architectures", None)
    if listed:
        found.extend(listed)
    return found
def classify_model_config(config) -> str:
    """Classify a model by its HF config.

    Checks config.is_encoder_decoder first, then falls back to architecture list.
    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    if getattr(config, "is_encoder_decoder", False):
        return "seq2seq"
    # First architecture that classifies as something other than a causal LM
    # determines the model type; otherwise default to "causal_lm".
    non_causal = (
        kind
        for kind in map(classify_architecture, get_architectures_for_config(config))
        if kind != "causal_lm"
    )
    return next(non_causal, "causal_lm")
def classify_model_name(
    model_name: str,
    trust_remote_code: bool = False,
    token: Optional[str] = None,
) -> str:
    """Classify a model by its HuggingFace model name.

    Loads the config once, classifies from it. If token is None, reads
    HF_TOKEN from the environment automatically.
    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    # Best-effort: any failure (missing transformers, bad name, network error)
    # falls back to the default classification rather than raising.
    try:
        from transformers import AutoConfig

        if token is None:
            from transformer_lens.utilities.hf_utils import get_hf_token

            token = get_hf_token()

        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token
        )
        return classify_model_config(hf_config)
    except Exception:
        return "causal_lm"
def is_masked_lm_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is a masked language model (BERT-style)."""
    kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return kind == "masked_lm"
def is_encoder_decoder_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is an encoder-decoder architecture (T5, BART, etc.)."""
    kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return kind == "seq2seq"
def is_multimodal_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is a multimodal vision-language model (LLaVA, Gemma3)."""
    kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return kind == "multimodal"
def is_audio_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Check if a model is an audio encoder model (HuBERT, wav2vec2)."""
    kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return kind == "audio"