Coverage for transformer_lens/utilities/architectures.py: 32%
49 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Centralized architecture classification for TransformerLens.
3Single source of truth for architecture type detection. Used by the bridge
4loading pipeline, benchmarks, and verification tools.
5"""
7from typing import Optional
# Encoder-decoder (sequence-to-sequence) model classes such as T5 and BART.
SEQ2SEQ_ARCHITECTURES: set[str] = {
    "BartForConditionalGeneration",
    "BlenderbotForConditionalGeneration",
    "BlenderbotSmallForConditionalGeneration",
    "MarianMTModel",
    "MBartForConditionalGeneration",
    "MT5ForConditionalGeneration",
    "PegasusForConditionalGeneration",
    "T5ForConditionalGeneration",
    "T5GemmaForConditionalGeneration",
}

# BERT-style masked language models; these do not generate text.
MASKED_LM_ARCHITECTURES: set[str] = {
    "AlbertForMaskedLM",
    "BertForMaskedLM",
    "DistilBertForMaskedLM",
    "ElectraForMaskedLM",
    "RobertaForMaskedLM",
}

# Vision-language (multimodal) model classes.
MULTIMODAL_ARCHITECTURES: set[str] = {
    "Gemma3ForConditionalGeneration",
    "Gemma4ForConditionalGeneration",
    "LlavaForConditionalGeneration",
    "LlavaNextForConditionalGeneration",
    "LlavaOnevisionForConditionalGeneration",
    "Qwen3_5ForConditionalGeneration",
}

# Audio encoder model classes; only HuBERT variants are listed at present.
AUDIO_ARCHITECTURES: set[str] = {
    "HubertForCTC",
    "HubertForSequenceClassification",
    "HubertModel",
}

# Architectures whose bridge hook shapes intentionally differ from
# HookedTransformer. Phase 2/3 HT comparisons are skipped for them; the
# Phase 1 comparison against HF is the gold standard.
NO_HT_COMPARISON_ARCHITECTURES: set[str] = MULTIMODAL_ARCHITECTURES.union(
    AUDIO_ARCHITECTURES,
    {"Gemma3ForCausalLM"},
)
def classify_architecture(architecture: str) -> str:
    """Map a single architecture class name to its model-type label.

    The name is looked up in the module's architecture tables in a fixed
    order (seq2seq, masked LM, multimodal, audio); the first table that
    contains it decides the label. Names found in none of the tables are
    treated as causal language models.

    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    # Order matters only if a name were ever listed in more than one table.
    lookup_order = (
        ("seq2seq", SEQ2SEQ_ARCHITECTURES),
        ("masked_lm", MASKED_LM_ARCHITECTURES),
        ("multimodal", MULTIMODAL_ARCHITECTURES),
        ("audio", AUDIO_ARCHITECTURES),
    )
    for label, known_names in lookup_order:
        if architecture in known_names:
            return label
    return "causal_lm"
def get_architectures_for_config(config) -> list[str]:
    """Collect the architecture names declared on a config object.

    Reads two optional attributes from ``config``:

    * ``original_architecture`` - a single name, placed first when set.
    * ``architectures`` - an iterable of names, appended afterwards.

    An attribute that is missing, ``None`` or empty is skipped, so the result
    never contains ``None`` for a config that defines
    ``original_architecture`` without assigning it a value.

    Returns:
        A new list of architecture names; empty when the config declares none.
    """
    names: list[str] = []

    # getattr with a default covers both "attribute absent" and "attribute
    # present but None"; a bare hasattr check would let None through.
    original = getattr(config, "original_architecture", None)
    if original:
        names.append(original)

    declared = getattr(config, "architectures", None)
    if declared:
        names.extend(declared)

    return names
def classify_model_config(config) -> str:
    """Derive a model-type label from a config object.

    A truthy ``config.is_encoder_decoder`` short-circuits to "seq2seq".
    Otherwise each architecture name found on the config is classified in
    turn and the first label other than "causal_lm" is returned.

    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    if getattr(config, "is_encoder_decoder", False):
        return "seq2seq"

    # Lazy pipeline: classification stops at the first specialised label.
    specialised_labels = (
        label
        for label in map(classify_architecture, get_architectures_for_config(config))
        if label != "causal_lm"
    )
    return next(specialised_labels, "causal_lm")
def classify_model_name(
    model_name: str,
    trust_remote_code: bool = False,
    token: Optional[str] = None,
) -> str:
    """Classify a model by its HuggingFace model name.

    Loads the model's config once with ``AutoConfig.from_pretrained`` and
    classifies from it. When ``token`` is None, a token is obtained from
    ``get_hf_token()`` before loading.

    This is a best-effort lookup: if anything fails (an import, fetching the
    token, loading the config, or classifying it) the function returns
    "causal_lm" instead of raising. The swallowed exception is logged at
    DEBUG level so a misclassification can be diagnosed.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        token: Access token forwarded to ``AutoConfig.from_pretrained``.

    Returns one of: "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm"
    """
    try:
        # Imported lazily so that a missing or broken dependency degrades to
        # the fallback below rather than failing at module import time.
        from transformers import AutoConfig

        if token is None:
            from transformer_lens.utilities.hf_utils import get_hf_token

            token = get_hf_token()

        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token
        )
        return classify_model_config(config)
    except Exception:
        # Deliberate catch-all: callers rely on always getting a label back.
        # Record the cause instead of discarding it silently.
        import logging

        logging.getLogger(__name__).debug(
            "Could not classify model %r from its config; assuming 'causal_lm'",
            model_name,
            exc_info=True,
        )
        return "causal_lm"
def is_masked_lm_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True when ``classify_model_name`` labels the model "masked_lm"."""
    detected_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected_type == "masked_lm"
def is_encoder_decoder_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True when ``classify_model_name`` labels the model "seq2seq"."""
    detected_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected_type == "seq2seq"
def is_multimodal_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True when ``classify_model_name`` labels the model "multimodal"."""
    detected_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected_type == "multimodal"
def is_audio_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True when ``classify_model_name`` labels the model "audio"."""
    detected_type = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return detected_type == "audio"