Coverage for transformer_lens/model_bridge/sources/_bridge_builder.py: 53%
69 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model."""
2from __future__ import annotations
4import copy
5from typing import Any, Callable, Optional
7import torch
8from torch import nn
10from transformer_lens.config import TransformerBridgeConfig
11from transformer_lens.factories.architecture_adapter_factory import (
12 ArchitectureAdapterFactory,
13)
14from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
15from transformer_lens.model_bridge.bridge import TransformerBridge
# Raw HF config attributes that build_bridge_config_from_hf copies verbatim
# onto the bridge config whenever the HF config defines them with a non-None
# value. The copy is architecture-agnostic (no per-architecture branching);
# do not turn it into per-architecture special-casing. The group comments only
# record which model family motivated each entry.
_HF_PASSTHROUGH_ATTRS = [
    # OPT
    "is_gated_act", "word_embed_proj_dim", "do_layer_norm_before",
    # Granite
    "position_embedding_type",
    # Falcon
    "parallel_attn", "multi_query", "new_decoder_architecture", "alibi",
    "num_ln_in_parallel_attn",
    # Mamba: state-space model hyperparameters
    "state_size", "conv_kernel", "expand", "time_step_rank", "intermediate_size",
    # Mamba-2: additional SSM hyperparameters
    "n_groups", "chunk_size",
    # Multimodal models
    "vision_config",
    # Cohere
    "logit_scale", "rope_parameters",
]
def build_bridge_config_from_hf(
    hf_config: Any,
    architecture: str,
    model_name: str,
    dtype: torch.dtype,
) -> TransformerBridgeConfig:
    """Convert a HuggingFace-style config into a :class:`TransformerBridgeConfig`.

    The HF config is first mapped through ``map_default_transformer_lens_config``.
    The result is stamped with ``architecture``, ``model_name`` and ``dtype``.
    Selected raw HF attributes (``_HF_PASSTHROUGH_ATTRS``) and the two
    soft-capping fields are then copied across.

    Args:
        hf_config: HF-style config object; only attribute access is used on it here.
        architecture: Stored on the result's ``architecture``.
        model_name: Stored on the result's ``model_name``.
        dtype: Stored on the result's ``dtype``.

    Returns:
        A freshly built :class:`TransformerBridgeConfig`.
    """
    # Function-scope import, presumably to avoid an import cycle with the
    # transformers source module; TODO confirm.
    from transformer_lens.model_bridge.sources.transformers import (
        map_default_transformer_lens_config,
    )

    mapped = dict(map_default_transformer_lens_config(hf_config).__dict__)
    # HF's attribute_map exposes num_experts as num_local_experts; restore the
    # TL name unless it is already present.
    if "num_local_experts" in mapped:
        mapped.setdefault("num_experts", mapped["num_local_experts"])

    cfg = TransformerBridgeConfig.from_dict(mapped)
    cfg.architecture = architecture
    cfg.model_name = model_name
    cfg.dtype = dtype

    # Copy extra HF fields verbatim; attributes that are absent or None are skipped.
    for name in _HF_PASSTHROUGH_ATTRS:
        value = getattr(hf_config, name, None)
        if value is not None:
            setattr(cfg, name, value)

    # Gemma2-style soft-capping: HF and TL use different field names, and the
    # TL side stores the values as floats.
    softcap_renames = (
        ("final_logit_softcapping", "output_logits_soft_cap"),
        ("attn_logit_softcapping", "attn_scores_soft_cap"),
    )
    for hf_name, tl_name in softcap_renames:
        value = getattr(hf_config, hf_name, None)
        if value is not None:
            setattr(cfg, tl_name, float(value))

    return cfg
def detect_tokenizer_bos_eos(tokenizer: Any) -> tuple[bool, bool]:
    """Report whether ``tokenizer.encode`` adds a leading BOS and/or a trailing EOS token.

    The probe uses the one-character string ``"a"`` rather than ``""``, because
    an empty input is unreliable when special tokens alias ordinary ones.

    Args:
        tokenizer: Object exposing ``encode(str)`` plus ``bos_token_id`` and
            ``eos_token_id`` attributes (either may be ``None``).

    Returns:
        ``(prepends_bos, appends_eos)``. Both are ``False`` when the probe
        encodes to at most one token, since a lone token cannot be both a
        special marker and the encoded text.
    """
    ids = tokenizer.encode("a")
    if len(ids) <= 1:
        return False, False

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    prepends_bos = bos_id is not None and ids[0] == bos_id
    appends_eos = eos_id is not None and ids[-1] == eos_id
    return prepends_bos, appends_eos
def _first_param_attr(model: nn.Module, attr: str, default: Any) -> Any:
    """Return ``attr`` of ``model``'s first parameter, or ``default`` if it has none."""
    try:
        first_param = next(model.parameters())
    except StopIteration:
        return default
    return getattr(first_param, attr)


def _copy_tl_config(
    tl_config: TransformerBridgeConfig,
    architecture: str,
    model_name: str,
    dtype: torch.dtype,
) -> TransformerBridgeConfig:
    """Deep-copy ``tl_config`` and stamp architecture, model name and dtype onto the copy.

    Copying keeps adapter-init mutations (normalization_type, device, ...) from
    leaking between bridges built from the same config object. If the caller
    left ``model_name`` at its ``"external"`` default, any non-empty
    ``model_name`` already on the config is kept.
    """
    cfg = copy.deepcopy(tl_config)
    cfg.architecture = architecture
    keep_config_name = model_name == "external" and getattr(cfg, "model_name", None)
    if not keep_config_name:
        cfg.model_name = model_name
    cfg.dtype = dtype
    return cfg


def build_bridge_from_module(
    model: nn.Module,
    architecture: str,
    *,
    hf_config: Optional[Any] = None,
    tl_config: Optional[TransformerBridgeConfig] = None,
    tokenizer: Optional[Any] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Any] = None,
    model_name: str = "external",
    post_adapter_hook: Optional[Callable[[ArchitectureAdapter], None]] = None,
) -> TransformerBridge:
    """Wrap an already-loaded model in a :class:`TransformerBridge`.

    This function does not move or cast ``model`` itself. It does pass the model
    to ``adapter.prepare_model``; whether that modifies it depends on the
    selected adapter (NOTE(review): confirm per adapter).

    Args:
        model: Any ``nn.Module`` whose submodule tree matches the dot-paths the
            adapter for ``architecture`` expects.
        architecture: Architecture identifier known to
            ``ArchitectureAdapterFactory`` (e.g. ``"LlamaForCausalLM"``,
            ``"TransformerLensNative"``).
        hf_config: HF-style config, converted with
            :func:`build_bridge_config_from_hf`. Give this or ``tl_config``, not both.
        tl_config: Ready-made :class:`TransformerBridgeConfig`. It is
            deep-copied and used without HF translation. Give this or
            ``hf_config``, not both.
        tokenizer: Optional tokenizer. When given, it is run through
            ``setup_tokenizer`` and its BOS/EOS behavior is recorded on the config.
        dtype: Value stored in ``cfg.dtype``. When ``None``, it is taken from the
            model's first parameter (``torch.float32`` if the model has none).
        device: Value stored, as a string, in ``cfg.device``. When ``None``, it
            is taken from the model's first parameter (``"cpu"`` if the model
            has none).
        model_name: Value stored in ``cfg.model_name``. With ``tl_config``, the
            default ``"external"`` does not overwrite a non-empty name already on it.
        post_adapter_hook: Called with the adapter right after it is selected,
            before the device is recorded and before ``adapter.prepare_model``
            runs. Source-specific overlays mutate ``component_mapping`` here.

    Returns:
        The :class:`TransformerBridge` wrapping ``model``.

    Raises:
        ValueError: If neither or both of ``hf_config`` / ``tl_config`` are given.
    """
    no_hf = hf_config is None
    no_tl = tl_config is None
    if no_hf and no_tl:
        raise ValueError(
            "build_bridge_from_module requires exactly one of hf_config or "
            "tl_config — the bridge needs config fields (d_model, n_heads, "
            "n_layers, ...) that can't be inferred from the model alone."
        )
    if not (no_hf or no_tl):
        raise ValueError(
            "build_bridge_from_module got both hf_config and tl_config; supply "
            "exactly one. hf_config triggers HF→bridge translation; tl_config "
            "bypasses it."
        )

    if dtype is None:
        # Take the dtype from the weights so that, e.g., a bf16 model isn't
        # recorded as float32.
        dtype = _first_param_attr(model, "dtype", torch.float32)

    if no_tl:
        bridge_config = build_bridge_config_from_hf(hf_config, architecture, model_name, dtype)
    else:
        bridge_config = _copy_tl_config(tl_config, architecture, model_name, dtype)

    adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
    if post_adapter_hook is not None:
        post_adapter_hook(adapter)

    resolved_device = device if device is not None else _first_param_attr(model, "device", "cpu")
    adapter.cfg.device = str(resolved_device)

    adapter.prepare_model(model)

    if tokenizer is not None:
        # Function-scope import, matching how this file imports from the
        # transformers source module elsewhere.
        from transformer_lens.model_bridge.sources.transformers import setup_tokenizer

        tokenizer = setup_tokenizer(
            tokenizer,
            default_padding_side=getattr(adapter.cfg, "default_padding_side", None),
        )
        prepends_bos, appends_eos = detect_tokenizer_bos_eos(tokenizer)
        adapter.cfg.tokenizer_prepends_bos = prepends_bos
        adapter.cfg.tokenizer_appends_eos = appends_eos

    return TransformerBridge(model, adapter, tokenizer)