Coverage for transformer_lens/model_bridge/sources/_bridge_builder.py: 53%

69 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Loader-agnostic helpers for building a TransformerBridge around a pre-loaded model.""" 

2from __future__ import annotations 

3 

4import copy 

5from typing import Any, Callable, Optional 

6 

7import torch 

8from torch import nn 

9 

10from transformer_lens.config import TransformerBridgeConfig 

11from transformer_lens.factories.architecture_adapter_factory import ( 

12 ArchitectureAdapterFactory, 

13) 

14from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

15from transformer_lens.model_bridge.bridge import TransformerBridge 

16 

# HF config attribute names that ``build_bridge_config_from_hf`` copies onto the
# bridge config verbatim (same name, same value) whenever the HF config defines
# them with a non-None value. That copy never consults the architecture, so keep
# this list architecture-agnostic: only add names whose HF meaning maps one-to-one
# onto the bridge-config field of the same name. Fields that need renaming or
# conversion (e.g. the Gemma2 softcaps) are translated explicitly instead.
_HF_PASSTHROUGH_ATTRS = [
    # OPT
    "is_gated_act", "word_embed_proj_dim", "do_layer_norm_before",
    # Granite
    "position_embedding_type",
    # Falcon
    "parallel_attn", "multi_query", "new_decoder_architecture", "alibi",
    "num_ln_in_parallel_attn",
    # Mamba (SSM config)
    "state_size", "conv_kernel", "expand", "time_step_rank", "intermediate_size",
    # Mamba-2 (additional SSM config)
    "n_groups", "chunk_size",
    # Multimodal
    "vision_config",
    # Cohere
    "logit_scale", "rope_parameters",
]

46 

47 

def build_bridge_config_from_hf(
    hf_config: Any,
    architecture: str,
    model_name: str,
    dtype: torch.dtype,
) -> TransformerBridgeConfig:
    """Convert a HuggingFace-style config into a :class:`TransformerBridgeConfig`.

    The HF config is first run through ``map_default_transformer_lens_config``
    and its fields seed a new bridge config. ``architecture``, ``model_name``
    and ``dtype`` are then stamped on. Next, every name in
    ``_HF_PASSTHROUGH_ATTRS`` that ``hf_config`` defines with a non-``None``
    value is copied across unchanged. Finally, the Gemma2-style softcap fields
    are translated to their TransformerLens names and cast to ``float``.

    Args:
        hf_config: HF-style config object; optional attributes are read with
            ``getattr`` so missing ones are simply skipped.
        architecture: Architecture identifier recorded on the result.
        model_name: Model name recorded on the result.
        dtype: Torch dtype recorded on the result.

    Returns:
        A freshly built :class:`TransformerBridgeConfig`.
    """
    # Kept as a function-local import; presumably to avoid a circular import
    # with the transformers source module — confirm before hoisting.
    from transformer_lens.model_bridge.sources.transformers import (
        map_default_transformer_lens_config,
    )

    raw_fields = dict(map_default_transformer_lens_config(hf_config).__dict__)
    # HF's attribute_map exposes num_experts as num_local_experts; put the
    # TransformerLens name back unless it is already present.
    if "num_local_experts" in raw_fields:
        raw_fields.setdefault("num_experts", raw_fields["num_local_experts"])

    bridge_config = TransformerBridgeConfig.from_dict(raw_fields)
    bridge_config.architecture = architecture
    bridge_config.model_name = model_name
    bridge_config.dtype = dtype

    for attr_name in _HF_PASSTHROUGH_ATTRS:
        attr_value = getattr(hf_config, attr_name, None)
        if attr_value is None:
            continue
        setattr(bridge_config, attr_name, attr_value)

    # Gemma2: the HF softcap field names differ from the TransformerLens ones.
    softcap_renames = (
        ("final_logit_softcapping", "output_logits_soft_cap"),
        ("attn_logit_softcapping", "attn_scores_soft_cap"),
    )
    for hf_name, tl_name in softcap_renames:
        cap = getattr(hf_config, hf_name, None)
        if cap is not None:
            setattr(bridge_config, tl_name, float(cap))

    return bridge_config

83 

84 

def detect_tokenizer_bos_eos(tokenizer: Any) -> tuple[bool, bool]:
    """Report whether ``tokenizer.encode`` adds a BOS prefix and/or an EOS suffix.

    The probe string is the non-empty ``"a"``: encoding ``""`` is unreliable when
    special tokens alias one another. An encoding of one token or fewer is
    treated as "no special tokens added". Otherwise the first and last ids are
    compared against ``bos_token_id`` / ``eos_token_id``, and an unset id
    (``None``) never matches.

    Returns:
        ``(prepends_bos, appends_eos)``.
    """
    probe_ids = tokenizer.encode("a")
    if len(probe_ids) <= 1:
        return False, False

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    return (
        bos_id is not None and probe_ids[0] == bos_id,
        eos_id is not None and probe_ids[-1] == eos_id,
    )

102 

103 

def build_bridge_from_module(
    model: nn.Module,
    architecture: str,
    *,
    hf_config: Optional[Any] = None,
    tl_config: Optional[TransformerBridgeConfig] = None,
    tokenizer: Optional[Any] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Any] = None,
    model_name: str = "external",
    post_adapter_hook: Optional[Callable[[ArchitectureAdapter], None]] = None,
) -> TransformerBridge:
    """Wrap an already-loaded model in a :class:`TransformerBridge`.

    This function does not move or cast ``model`` itself. It only reads the
    model's first parameter, to infer ``dtype``/``device`` when they are not
    given, and passes the model to ``adapter.prepare_model`` before
    constructing the bridge.

    Args:
        model: ``nn.Module`` whose submodule layout matches the dot-paths the
            adapter for ``architecture`` expects.
        architecture: Name registered with ``ArchitectureAdapterFactory``
            (for example ``"LlamaForCausalLM"`` or ``"TransformerLensNative"``).
        hf_config: HF-style config, converted with
            :func:`build_bridge_config_from_hf`. Supply exactly one of
            ``hf_config`` / ``tl_config``.
        tl_config: Ready-made :class:`TransformerBridgeConfig`. A deep copy is
            used, so the caller's object is left untouched. Supply exactly one
            of ``hf_config`` / ``tl_config``.
        tokenizer: Optional tokenizer. When given, it is run through
            ``setup_tokenizer`` and its BOS/EOS behaviour is recorded on the
            config.
        dtype: Stored as ``cfg.dtype``. When ``None``, it is taken from the
            model's first parameter (``torch.float32`` if the model has none).
        device: Stored, stringified, as ``cfg.device``. When ``None``, it is
            taken from the model's first parameter (``"cpu"`` if it has none).
        model_name: Stored as ``cfg.model_name``. On the ``tl_config`` path,
            the default ``"external"`` does not overwrite a non-empty name
            already on the config.
        post_adapter_hook: Called with the selected adapter before the device
            is recorded and before ``adapter.prepare_model`` runs.
            Source-specific overlays can edit ``component_mapping`` here.

    Returns:
        The :class:`TransformerBridge` built from ``model``, the adapter and
        the (possibly set-up) tokenizer.

    Raises:
        ValueError: If neither or both of ``hf_config`` and ``tl_config`` are
            supplied.
    """
    supplied = (hf_config is not None, tl_config is not None)
    if not any(supplied):
        raise ValueError(
            "build_bridge_from_module requires exactly one of hf_config or "
            "tl_config — the bridge needs config fields (d_model, n_heads, "
            "n_layers, ...) that can't be inferred from the model alone."
        )
    if all(supplied):
        raise ValueError(
            "build_bridge_from_module got both hf_config and tl_config; supply "
            "exactly one. hf_config triggers HF→bridge translation; tl_config "
            "bypasses it."
        )

    def _first_param() -> Optional[torch.Tensor]:
        # ``None`` for a module that owns no parameters at all.
        return next(model.parameters(), None)

    # Taking dtype from the weights keeps cfg.dtype honest for e.g. bf16 models.
    if dtype is None:
        probe = _first_param()
        dtype = torch.float32 if probe is None else probe.dtype

    if tl_config is None:
        bridge_config = build_bridge_config_from_hf(hf_config, architecture, model_name, dtype)
    else:
        # Work on a private copy so adapter-init mutations (normalization_type,
        # device, ...) cannot leak between bridges sharing one config object.
        bridge_config = copy.deepcopy(tl_config)
        bridge_config.architecture = architecture
        keep_existing_name = model_name == "external" and getattr(
            bridge_config, "model_name", None
        )
        if not keep_existing_name:
            bridge_config.model_name = model_name
        bridge_config.dtype = dtype

    adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
    if post_adapter_hook is not None:
        post_adapter_hook(adapter)

    if device is None:
        probe = _first_param()
        resolved_device = "cpu" if probe is None else str(probe.device)
    else:
        resolved_device = str(device)
    adapter.cfg.device = resolved_device

    adapter.prepare_model(model)

    if tokenizer is not None:
        from transformer_lens.model_bridge.sources.transformers import setup_tokenizer

        tokenizer = setup_tokenizer(
            tokenizer,
            default_padding_side=getattr(adapter.cfg, "default_padding_side", None),
        )
        prepends_bos, appends_eos = detect_tokenizer_bos_eos(tokenizer)
        adapter.cfg.tokenizer_prepends_bos = prepends_bos
        adapter.cfg.tokenizer_appends_eos = appends_eos

    return TransformerBridge(model, adapter, tokenizer)