Coverage for transformer_lens/utilities/defaults_utils.py: 100%

34 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""defaults_utils.

2 

3This module contains utility functions related to defaults 

4""" 

5 

6from __future__ import annotations 

7 

8from copy import deepcopy 

9from typing import Any, Optional 

10 

11from .attribute_utils import get_nested_attr, set_nested_attr 

12 

13USE_DEFAULT_VALUE = None  # Sentinel override: passing this value means "keep the currently stored default".

14 

15 

def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """
    Choose between a default flag and an optional overriding flag.

    Args:
        default_flag: The value to fall back on when no override is given.
        override: The value that takes precedence whenever it is not None.
            Falsy values other than None (e.g. ``False``) still count as an override.

    Returns:
        ``default_flag`` when ``override`` is None, otherwise ``override``.
    """
    if override is None:
        return default_flag
    return override

26 

27 

class LocallyOverridenDefaults:
    """
    Context manager that allows temporary overriding of default values within a model.
    Once the context is exited, the default values are restored.

    WARNING: This context manager must be used for any function/method that directly accesses
    default values which may be overridden by the user using the function/method's arguments,
    e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
    overriden by `prepend_bos` and `padding_side` arguments, respectively, in the `to_tokens`.
    """

    def __init__(self, model, **overrides):
        """
        Initializes the context manager.

        Args:
            model (HookedTransformer): The model whose default values will be overridden.
            overrides (dict): Key-value pairs of properties to override and their new values.

        Raises:
            AssertionError: If an override name is not one of the supported properties.
        """
        self.model = model
        self.overrides = overrides

        # Dictionary defining valid defaults, valid values, and locations to find and store them.
        # Locations are dotted attribute paths handed to get_nested_attr/set_nested_attr with
        # this context manager as the root object, hence the leading "model." component.
        self.values_with_defaults = {
            "prepend_bos": {
                "default_location": "model.cfg.default_prepend_bos",
                "valid_values": [USE_DEFAULT_VALUE, True, False],
                "skip_overriding": False,
                "default_value_to_restore": None,  # Will be set in __enter__
            },
            "padding_side": {
                "default_location": "model.tokenizer.padding_side",
                "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
                "skip_overriding": model.tokenizer is None,  # Do not override if tokenizer is None
                "default_value_to_restore": None,  # Will be set in __enter__
            },
        }

        # Ensure provided overrides are defined in the dictionary above
        for override in overrides:
            assert override in self.values_with_defaults, (
                f"{override} is not a valid parameter to override. "
                f"Valid parameters are {self.values_with_defaults.keys()}."
            )

    def _active_overrides(self):
        """
        Yield ``(name, override, info)`` for every requested override that is not skipped.
        """
        for name, override in self.overrides.items():
            info = self.values_with_defaults[name]
            if not info["skip_overriding"]:
                yield name, override, info

    def __enter__(self):
        """
        Override default values upon entering the context.

        All overrides are validated before any of them is applied. `__exit__` is not invoked
        when `__enter__` raises, so validating up front guarantees that an invalid override
        cannot leave an earlier, already-applied override stuck on the model.

        Raises:
            AssertionError: If an override value is not among the property's valid values.
        """
        # Pass 1: validate every override without touching the model.
        for name, override, info in self._active_overrides():
            valid_values = info["valid_values"]
            assert (
                override in valid_values  # type: ignore
            ), f"{name} must be one of {valid_values}, but got {override}."

        # Pass 2: remember the current defaults and apply the overrides.
        for name, override, info in self._active_overrides():
            # Fetch current default and store it to restore later
            default_location = info["default_location"]
            default_value = get_nested_attr(self, default_location)
            info["default_value_to_restore"] = deepcopy(default_value)

            # Override the default value (a None override keeps the current default)
            locally_overriden_value = override_or_use_default_value(default_value, override)
            set_nested_attr(self, default_location, locally_overriden_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore default values upon exiting the context.

        Returns None, so an exception raised inside the `with` body is not suppressed.
        """
        for _name, _override, info in self._active_overrides():
            # Restore the default value from before the context was entered
            default_location = info["default_location"]
            default_value = info["default_value_to_restore"]
            set_nested_attr(self, default_location, default_value)