Coverage for transformer_lens/conversion_utils/hook_conversion_utils.py: 70%

23 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

1"""Weight conversion utilities.""" 

2 

3import torch 

4 

5from transformer_lens.loading_from_pretrained import get_pretrained_model_config 

6 

7 

def get_weight_conversion_field_set(weights: dict) -> str:
    """Create a formatted string showing how weights are mapped between frameworks.

    Args:
        weights: Dictionary containing weight mappings where:
            - keys are TransformerLens weight names
            - values can be:
                * tuple[str, "BaseTensorConversion"]: a source weight name plus the
                  conversion object applied to it
                * torch.Tensor: a literal tensor used in place of a source weight
                * anything else (normally a str): the source (HuggingFace) weight name

    Returns:
        A multi-line string with one newline-terminated line per entry of
        ``weights``, in the dict's iteration order. An empty ``weights`` yields ``""``.
    """
    # Collect the lines and join them once at the end, instead of repeatedly
    # concatenating onto a growing string (quadratic in the worst case).
    lines = []
    for transformer_lens_weight, hugging_face_weight in weights.items():
        # Case 1: nested conversion -- show the source name and the conversion's repr.
        if isinstance(hugging_face_weight, tuple):
            weight_name, conversion = hugging_face_weight
            lines.append(f'"{transformer_lens_weight}" -> "{weight_name}", {conversion!r}\n')

        # Case 2: literal tensor -- summarise it by shape (and constant fill value,
        # if any) rather than dumping its contents.
        elif isinstance(hugging_face_weight, torch.Tensor):
            shape = hugging_face_weight.shape
            # NOTE: torch.all is vacuously True for an empty tensor, so an empty
            # tensor is reported as "filled with zeros".
            if torch.all(hugging_face_weight == 0):
                description = f"Tensor filled with zeros of shape {shape}"
            elif torch.all(hugging_face_weight == 1):
                description = f"Tensor filled with ones of shape {shape}"
            else:
                description = f"Tensor of shape {shape}"
            lines.append(f'"{transformer_lens_weight}" -> "{description}",\n')

        # Case 3: anything else (normally a string naming the HuggingFace weight).
        else:
            lines.append(f'"{transformer_lens_weight}" -> "{hugging_face_weight}",\n')
    return "".join(lines)

46 

47 

def model_info_cfg(cfg):
    """
    Print the HuggingFace -> TransformerLens hook conversion details for ``cfg``.

    For now only a header naming ``cfg.original_architecture`` and a placeholder
    message are printed; nothing is returned.

    Args:
        cfg: Model configuration object; only its ``original_architecture``
            attribute is read here.
    """
    # TODO: print the real conversion report once a replacement for
    # transformer_lens.factories.weight_conversion_factory.WeightConversionFactory
    # (select_weight_conversion_config) is available again.
    architecture = cfg.original_architecture
    header = f"Hook conversion details for architecture {architecture}:"
    for message in (header, "Hook conversion factory not yet implemented"):
        print(message)

65 

66 

def model_info(model_name):
    """
    Print the HuggingFace -> TransformerLens hook conversion details for a named model.

    Resolves the pretrained configuration for ``model_name`` and hands it straight
    to ``model_info_cfg``; nothing is returned.

    Args:
        model_name (str): Name of the pretrained model to analyze
            (e.g., 'gpt2', 'bert-base-uncased', etc.)
    """
    model_info_cfg(get_pretrained_model_config(model_name))