Coverage for transformer_lens/utilities/architectures.py: 32%
49 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Centralized architecture classification for TransformerLens.
3Single source of truth for architecture type detection. Used by the bridge
4loading pipeline, benchmarks, and verification tools.
5"""
7from typing import Optional
9# Encoder-decoder models (T5, BART, etc.)
10SEQ2SEQ_ARCHITECTURES: set[str] = {
11 "T5ForConditionalGeneration",
12 "MT5ForConditionalGeneration",
13 "BartForConditionalGeneration",
14 "MBartForConditionalGeneration",
15 "MarianMTModel",
16 "PegasusForConditionalGeneration",
17 "BlenderbotForConditionalGeneration",
18 "BlenderbotSmallForConditionalGeneration",
19}
21# Masked language models (BERT-style, no text generation)
22MASKED_LM_ARCHITECTURES: set[str] = {
23 "BertForMaskedLM",
24 "RobertaForMaskedLM",
25 "AlbertForMaskedLM",
26 "DistilBertForMaskedLM",
27 "ElectraForMaskedLM",
28}
30# Vision-language multimodal models
31MULTIMODAL_ARCHITECTURES: set[str] = {
32 "LlavaForConditionalGeneration",
33 "LlavaNextForConditionalGeneration",
34 "LlavaOnevisionForConditionalGeneration",
35 "Gemma3ForConditionalGeneration",
36 "Qwen3_5ForConditionalGeneration",
37}
39# Audio encoder models (HuBERT, wav2vec2, etc.)
40AUDIO_ARCHITECTURES: set[str] = {
41 "HubertForCTC",
42 "HubertModel",
43 "HubertForSequenceClassification",
44}
46# Bridge uses different hook shapes than HookedTransformer by design.
47# Phase 2/3 HT comparisons are skipped; Phase 1 (HF comparison) is the gold standard.
48NO_HT_COMPARISON_ARCHITECTURES: set[str] = (
49 MULTIMODAL_ARCHITECTURES
50 | AUDIO_ARCHITECTURES
51 | {
52 "Gemma3ForCausalLM",
53 }
54)
def classify_architecture(architecture: str) -> str:
    """Map a single HF architecture class name to a coarse model type.

    Args:
        architecture: An architecture string such as ``"T5ForConditionalGeneration"``.

    Returns:
        ``"seq2seq"``, ``"masked_lm"``, ``"multimodal"`` or ``"audio"`` when the
        name is in the matching module-level set, otherwise ``"causal_lm"``
        (the default for any unrecognised name).
    """
    # Checked in this exact order; the first set containing the name wins.
    ordered_lookup = (
        (SEQ2SEQ_ARCHITECTURES, "seq2seq"),
        (MASKED_LM_ARCHITECTURES, "masked_lm"),
        (MULTIMODAL_ARCHITECTURES, "multimodal"),
        (AUDIO_ARCHITECTURES, "audio"),
    )
    for known_names, label in ordered_lookup:
        if architecture in known_names:
            return label
    return "causal_lm"
def get_architectures_for_config(config) -> list[str]:
    """Collect the architecture class names recorded on a config object.

    ``config.original_architecture`` comes first (if present and set), followed
    by the entries of ``config.architectures`` (if present and non-empty).

    Args:
        config: An HF-style config object. Presumably a ``transformers``
            ``PretrainedConfig`` or a bridge config that also exposes
            ``original_architecture`` (verify against callers).

    Returns:
        Architecture strings in lookup order; empty if none are recorded.
    """
    architectures: list[str] = []
    original = getattr(config, "original_architecture", None)
    # The attribute can exist but be unset (None/empty); skip it so the result
    # stays a genuine list[str] instead of containing None.
    if original:
        architectures.append(original)
    declared = getattr(config, "architectures", None)
    if declared:
        architectures.extend(declared)
    return architectures
def classify_model_config(config) -> str:
    """Derive a coarse model type from an HF config object.

    An explicit ``is_encoder_decoder`` flag on the config short-circuits to
    ``"seq2seq"``. Otherwise, each recorded architecture name is classified in
    order, and the first non-causal result is returned.

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm".
    """
    if getattr(config, "is_encoder_decoder", False):
        return "seq2seq"
    # Lazily classify so we stop at the first architecture that isn't causal.
    non_causal_kinds = (
        kind
        for kind in map(classify_architecture, get_architectures_for_config(config))
        if kind != "causal_lm"
    )
    return next(non_causal_kinds, "causal_lm")
def classify_model_name(
    model_name: str,
    trust_remote_code: bool = False,
    token: Optional[str] = None,
) -> str:
    """Classify a HuggingFace model by name, loading only its config.

    Args:
        model_name: Hub id or local path passed to ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        token: HF access token. If None, ``get_hf_token()`` is used to obtain one
            (which reads ``HF_TOKEN`` from the environment).

    Returns:
        One of "seq2seq", "masked_lm", "multimodal", "audio", "causal_lm".
        If the config cannot be loaded for any reason (``transformers`` missing,
        network/auth failure, unknown model, ...), falls back to "causal_lm".
        The failure is logged as a warning instead of being silently ignored.
    """
    try:
        from transformers import AutoConfig

        if token is None:
            from transformer_lens.utilities.hf_utils import get_hf_token

            token = get_hf_token()

        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token
        )
    except Exception as err:  # deliberate best-effort: keep the causal_lm fallback
        import logging

        # A silent fallback can route a seq2seq/MLM model down the causal path,
        # so make the reason visible.
        logging.getLogger(__name__).warning(
            "Could not load config for %r (%s: %s); classifying as 'causal_lm'.",
            model_name,
            type(err).__name__,
            err,
        )
        return "causal_lm"
    return classify_model_config(config)
def is_masked_lm_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``classify_model_name`` labels the model "masked_lm" (BERT-style).

    All arguments are forwarded unchanged to ``classify_model_name``.
    """
    model_kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_kind == "masked_lm"
def is_encoder_decoder_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``classify_model_name`` labels the model "seq2seq" (T5, BART, ...).

    All arguments are forwarded unchanged to ``classify_model_name``.
    """
    model_kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_kind == "seq2seq"
def is_multimodal_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``classify_model_name`` labels the model "multimodal".

    This covers the vision-language architectures in ``MULTIMODAL_ARCHITECTURES``
    (LLaVA family, Gemma 3, Qwen3.5). All arguments are forwarded unchanged to
    ``classify_model_name``.
    """
    model_kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_kind == "multimodal"
def is_audio_model(
    model_name: str, trust_remote_code: bool = False, token: Optional[str] = None
) -> bool:
    """Return True if ``classify_model_name`` labels the model "audio".

    Only architectures listed in ``AUDIO_ARCHITECTURES`` count as audio; that set
    holds HuBERT variants only. Other audio encoders (e.g. wav2vec2) are not
    detected by this check. All arguments are forwarded unchanged to
    ``classify_model_name``.
    """
    model_kind = classify_model_name(
        model_name, trust_remote_code=trust_remote_code, token=token
    )
    return model_kind == "audio"