Coverage for transformer_lens/model_bridge/sources/_bridge_builder.py: 81%

69 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model.""" 

2from __future__ import annotations 

3 

4import copy 

5from typing import Any, Callable, Optional 

6 

7import torch 

8from torch import nn 

9 

10from transformer_lens.config import TransformerBridgeConfig 

11from transformer_lens.factories.architecture_adapter_factory import ( 

12 ArchitectureAdapterFactory, 

13) 

14from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

15from transformer_lens.model_bridge.bridge import TransformerBridge 

16 

17# Architecture-agnostic; do not extend per-architecture. 

18_HF_PASSTHROUGH_ATTRS = [ 

19 # OPT 

20 "is_gated_act", 

21 "word_embed_proj_dim", 

22 "do_layer_norm_before", 

23 # BART 

24 "encoder_layers", 

25 "decoder_layers", 

26 "encoder_attention_heads", 

27 "decoder_attention_heads", 

28 "encoder_ffn_dim", 

29 "decoder_ffn_dim", 

30 # Granite 

31 "position_embedding_type", 

32 # Falcon 

33 "parallel_attn", 

34 "multi_query", 

35 "new_decoder_architecture", 

36 "alibi", 

37 "num_ln_in_parallel_attn", 

38 # Mamba (SSM config) 

39 "state_size", 

40 "conv_kernel", 

41 "expand", 

42 "time_step_rank", 

43 "intermediate_size", 

44 # Mamba-2 (additional SSM config) 

45 "n_groups", 

46 "chunk_size", 

47 # Multimodal 

48 "vision_config", 

49 # Cohere 

50 "logit_scale", 

51 "rope_parameters", 

52 # Hybrid/MoE architectures 

53 "layer_types", 

54 "moe_intermediate_size", 

55 "norm_eps", 

56 "attention_bias", 

57 "lm_head_bias", 

58 "router_jitter_noise", 

59 "input_jitter_noise", 

60 "eos_token_id", 

61] 

62 

63 

def build_bridge_config_from_hf(
    hf_config: Any,
    architecture: str,
    model_name: str,
    dtype: torch.dtype,
) -> TransformerBridgeConfig:
    """Convert a HuggingFace-style config into a :class:`TransformerBridgeConfig`.

    The HF config is first mapped to TransformerLens field names by
    ``map_default_transformer_lens_config``; the resulting attribute dict seeds
    the bridge config. Afterwards:

    * ``architecture``, ``model_name`` and ``dtype`` are stamped on verbatim;
    * every name in ``_HF_PASSTHROUGH_ATTRS`` that is set (non-None) on
      ``hf_config`` is copied across unchanged;
    * Gemma2-style soft-capping fields are renamed to their TL equivalents and
      coerced to ``float``.

    Args:
        hf_config: Source config object; attributes are read with ``getattr``.
        architecture: Architecture identifier recorded on the result.
        model_name: Model name recorded on the result.
        dtype: Torch dtype recorded on the result.

    Returns:
        The populated :class:`TransformerBridgeConfig`.
    """
    # Imported lazily, presumably to avoid an import cycle with the
    # transformers source module -- NOTE(review): confirm.
    from transformer_lens.model_bridge.sources.transformers import (
        map_default_transformer_lens_config,
    )

    fields = dict(map_default_transformer_lens_config(hf_config).__dict__)

    # HF's attribute_map stores num_experts under num_local_experts; surface it
    # again under the TL name, but never clobber an explicit num_experts.
    if "num_experts" not in fields and "num_local_experts" in fields:
        fields["num_experts"] = fields["num_local_experts"]

    cfg = TransformerBridgeConfig.from_dict(fields)
    cfg.architecture, cfg.model_name, cfg.dtype = architecture, model_name, dtype

    # Lazily pair each passthrough name with its HF value; unset (None) values
    # are skipped so TL defaults survive.
    passthrough = ((name, getattr(hf_config, name, None)) for name in _HF_PASSTHROUGH_ATTRS)
    for name, value in passthrough:
        if value is not None:
            setattr(cfg, name, value)

    # Gemma2 names its soft-cap fields differently from TransformerLens.
    softcap_renames = (
        ("final_logit_softcapping", "output_logits_soft_cap"),
        ("attn_logit_softcapping", "attn_scores_soft_cap"),
    )
    for hf_name, tl_name in softcap_renames:
        cap = getattr(hf_config, hf_name, None)
        if cap is not None:
            setattr(cfg, tl_name, float(cap))

    return cfg

99 

100 

def detect_tokenizer_bos_eos(tokenizer: Any) -> tuple[bool, bool]:
    """Detect whether the tokenizer prepends BOS and/or appends EOS.

    The probe string is the non-empty ``"a"``; an empty string is unreliable
    because some tokenizers alias special tokens. The first and last ids of the
    encoding are compared with the tokenizer's BOS/EOS ids.

    Args:
        tokenizer: Any object with an ``encode(str)`` method returning a
            sequence of token ids. ``bos_token_id`` / ``eos_token_id`` are read
            if present; a missing or ``None`` id is reported as "not added"
            instead of raising ``AttributeError``.

    Returns:
        ``(prepends_bos, appends_eos)``. Both are ``False`` when ``"a"``
        encodes to at most one id, because that id must be the content token.
    """
    encoded = tokenizer.encode("a")
    # With a single id there is no room for a special token around the content.
    if len(encoded) <= 1:
        return False, False

    # getattr: the parameter is duck-typed, so do not require these attributes.
    bos_id = getattr(tokenizer, "bos_token_id", None)
    eos_id = getattr(tokenizer, "eos_token_id", None)

    prepends_bos = bos_id is not None and encoded[0] == bos_id
    appends_eos = eos_id is not None and encoded[-1] == eos_id
    return prepends_bos, appends_eos

118 

119 

def build_bridge_from_module(
    model: nn.Module,
    architecture: str,
    *,
    hf_config: Optional[Any] = None,
    tl_config: Optional[TransformerBridgeConfig] = None,
    tokenizer: Optional[Any] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Any] = None,
    model_name: str = "external",
    post_adapter_hook: Optional[Callable[[ArchitectureAdapter], None]] = None,
) -> TransformerBridge:
    """Wrap an already-loaded ``nn.Module`` in a :class:`TransformerBridge`.

    This function issues no ``.to()`` / dtype casts on ``model`` itself. It
    does pass the model to ``adapter.prepare_model``, whose effect depends on
    the selected adapter. NOTE(review): the intended contract is that the model
    is never moved, cast, or mutated; confirm that adapters honour it.

    Order of operations: validate config arguments -> resolve ``dtype`` ->
    build the bridge config -> select the adapter -> run ``post_adapter_hook``
    -> record the device -> ``adapter.prepare_model(model)`` -> optional
    tokenizer setup -> construct the bridge.

    Args:
        model: Module whose submodule tree matches the dot-paths the adapter
            for ``architecture`` expects.
        architecture: Identifier registered with ``ArchitectureAdapterFactory``
            (e.g. ``"LlamaForCausalLM"``, ``"TransformerLensNative"``).
        hf_config: HF-style config, converted via
            :func:`build_bridge_config_from_hf`. Exactly one of ``hf_config`` /
            ``tl_config`` must be given.
        tl_config: Ready-made :class:`TransformerBridgeConfig`; deep-copied and
            used without HF translation. Exactly one of ``hf_config`` /
            ``tl_config`` must be given.
        tokenizer: Optional tokenizer. When given it is run through
            ``setup_tokenizer`` and its BOS/EOS behaviour is recorded on the cfg.
        dtype: Stored on ``cfg.dtype``. ``None`` means "use the dtype of the
            model's first parameter", falling back to ``torch.float32`` for a
            parameterless model.
        device: Stored (as ``str``) on ``cfg.device``. ``None`` means "use the
            first parameter's device", falling back to ``"cpu"``.
        model_name: Stored on ``cfg.model_name``. With ``tl_config``, the
            default ``"external"`` does not overwrite a non-empty
            ``model_name`` already on that config; any other value does.
        post_adapter_hook: Called with the adapter right after selection and
            before ``prepare_model``; source-specific overlays use it to edit
            ``component_mapping``.

    Returns:
        A :class:`TransformerBridge` around ``model``.

    Raises:
        ValueError: If neither or both of ``hf_config`` and ``tl_config`` are
            supplied.
    """
    configs_supplied = (hf_config is not None) + (tl_config is not None)
    if configs_supplied == 0:
        raise ValueError(
            "build_bridge_from_module requires exactly one of hf_config or "
            "tl_config — the bridge needs config fields (d_model, n_heads, "
            "n_layers, ...) that can't be inferred from the model alone."
        )
    if configs_supplied == 2:
        raise ValueError(
            "build_bridge_from_module got both hf_config and tl_config; supply "
            "exactly one. hf_config triggers HF→bridge translation; tl_config "
            "bypasses it."
        )

    # Taking the dtype from the weights keeps cfg.dtype truthful (e.g. bf16).
    if dtype is None:
        first_param = next(model.parameters(), None)
        dtype = torch.float32 if first_param is None else first_param.dtype

    if tl_config is None:
        bridge_config = build_bridge_config_from_hf(hf_config, architecture, model_name, dtype)
    else:
        # Private copy: adapter construction may write to the config
        # (normalization_type, device, ...), and that must not leak into other
        # bridges built from the same tl_config.
        bridge_config = copy.deepcopy(tl_config)
        bridge_config.architecture = architecture
        # The default name only fills a missing/empty model_name.
        keeps_own_name = model_name == "external" and getattr(bridge_config, "model_name", None)
        if not keeps_own_name:
            bridge_config.model_name = model_name
        bridge_config.dtype = dtype

    adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
    if post_adapter_hook is not None:
        post_adapter_hook(adapter)

    if device is None:
        first_param = next(model.parameters(), None)
        resolved_device = "cpu" if first_param is None else str(first_param.device)
    else:
        resolved_device = str(device)
    adapter.cfg.device = resolved_device

    adapter.prepare_model(model)

    if tokenizer is not None:
        from transformer_lens.model_bridge.sources.transformers import setup_tokenizer

        padding_side = getattr(adapter.cfg, "default_padding_side", None)
        tokenizer = setup_tokenizer(tokenizer, default_padding_side=padding_side)
        prepends_bos, appends_eos = detect_tokenizer_bos_eos(tokenizer)
        adapter.cfg.tokenizer_prepends_bos = prepends_bos
        adapter.cfg.tokenizer_appends_eos = appends_eos

    return TransformerBridge(model, adapter, tokenizer)