Coverage for transformer_lens/conversion_utils/hook_conversion_utils.py: 70%
23 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Weight conversion utilities."""
3import torch
5from transformer_lens.loading_from_pretrained import get_pretrained_model_config
def get_weight_conversion_field_set(weights: dict) -> str:
    """Create a formatted string showing how weights are mapped between frameworks.

    Args:
        weights: Dictionary containing weight mappings where:
            - keys are TransformerLens weight names
            - values can be:
                * tuple[str, "BaseTensorConversion"]
                * torch.Tensor
                * strings

    Returns:
        A formatted multi-line string with one newline-terminated line per
        entry of ``weights``, in the dictionary's iteration order. An empty
        dictionary yields an empty string.
    """
    lines = []
    for transformer_lens_weight, hugging_face_weight in weights.items():
        # Case 1: Nested conversion, show the source name and the repr of the conversion
        if isinstance(hugging_face_weight, tuple):
            weight_name, conversion = hugging_face_weight
            lines.append(f'"{transformer_lens_weight}" -> "{weight_name}", {conversion!r}\n')

        # Case 2: Tensor, summarise its content and display its shape
        elif isinstance(hugging_face_weight, torch.Tensor):
            # The zeros check runs first, so a tensor with no elements (for which
            # torch.all is vacuously true) is reported as "filled with zeros".
            if torch.all(hugging_face_weight == 0):
                description = "Tensor filled with zeros"
            elif torch.all(hugging_face_weight == 1):
                description = "Tensor filled with ones"
            else:
                description = "Tensor"
            lines.append(
                f'"{transformer_lens_weight}" -> '
                f'"{description} of shape {hugging_face_weight.shape}",\n'
            )

        # Case 3: Anything else (expected: the HuggingFace weight name as a string)
        else:
            lines.append(f'"{transformer_lens_weight}" -> "{hugging_face_weight}",\n')

    # Join once at the end instead of growing a string inside the loop.
    return "".join(lines)
def model_info_cfg(cfg) -> None:
    """Print placeholder hook conversion details for a given model configuration.

    The conversion-factory lookup is currently disabled (see the TODO below), so
    this only prints the architecture name followed by a notice that the hook
    conversion factory is not yet implemented.

    Args:
        cfg: Model configuration object; only its ``original_architecture``
            attribute is read here.
    """
    # TODO: WeightConversionFactory import needs to be updated or removed
    # from transformer_lens.factories.weight_conversion_factory import (
    #     WeightConversionFactory,
    # )

    # weight_conversion = WeightConversionFactory.select_weight_conversion_config(cfg)
    print(f"Hook conversion details for architecture {cfg.original_architecture}:")
    # print(weight_conversion.__repr__())
    print("Hook conversion factory not yet implemented")
def model_info(model_name: str) -> None:
    """Print hook conversion details for a pretrained model, looked up by name.

    Fetches the configuration via ``get_pretrained_model_config`` and passes it
    to ``model_info_cfg``, which does the printing.

    Args:
        model_name (str): Name of the pretrained model to analyze
            (e.g., 'gpt2', 'bert-base-uncased', etc.)
    """
    cfg = get_pretrained_model_config(model_name)
    model_info_cfg(cfg)