Coverage for transformer_lens/utilities/defaults_utils.py: 100%
34 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""defaults_utils.

This module contains utility functions related to defaults.
"""
6from __future__ import annotations
8from copy import deepcopy
9from typing import Any, Optional
11from .attribute_utils import get_nested_attr, set_nested_attr
# Sentinel override value meaning "keep the model's current default".
# It is accepted for every overridable property, and
# `override_or_use_default_value` falls back to the default when it sees it.
USE_DEFAULT_VALUE = None
def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """
    Pick the effective value of a flag.

    Args:
        default_flag: Value to use when no override is supplied.
        override: Caller-supplied value. ``None`` means "no override".

    Returns:
        ``override`` if it is not ``None``, otherwise ``default_flag``.
        Falsy overrides such as ``False``, ``0`` or ``""`` still take
        precedence; only ``None`` triggers the fallback.
    """
    if override is None:
        return default_flag
    return override
28class LocallyOverridenDefaults:
29 """
30 Context manager that allows temporary overriding of default values within a model.
31 Once the context is exited, the default values are restored.
33 WARNING: This context manager must be used for any function/method that directly accesses
34 default values which may be overridden by the user using the function/method's arguments,
35 e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
36 overriden by `prepend_bos` and `padding_side` arguments, respectively, in the `to_tokens`.
37 """
39 def __init__(self, model, **overrides):
40 """
41 Initializes the context manager.
43 Args:
44 model (HookedTransformer): The model whose default values will be overridden.
45 overrides (dict): Key-value pairs of properties to override and their new values.
46 """
47 self.model = model
48 self.overrides = overrides
50 # Dictionary defining valid defaults, valid values, and locations to find and store them
51 self.values_with_defaults = {
52 "prepend_bos": {
53 "default_location": "model.cfg.default_prepend_bos",
54 "valid_values": [USE_DEFAULT_VALUE, True, False],
55 "skip_overriding": False,
56 "default_value_to_restore": None, # Will be set later
57 },
58 "padding_side": {
59 "default_location": "model.tokenizer.padding_side",
60 "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
61 "skip_overriding": model.tokenizer is None, # Do not override if tokenizer is None
62 "default_value_to_restore": None, # Will be set later
63 },
64 }
66 # Ensure provided overrides are defined in the dictionary above
67 for override in overrides:
68 assert override in self.values_with_defaults, (
69 f"{override} is not a valid parameter to override. "
70 f"Valid parameters are {self.values_with_defaults.keys()}."
71 )
73 def __enter__(self):
74 """
75 Override default values upon entering the context.
76 """
77 for property, override in self.overrides.items():
78 info = self.values_with_defaults[property]
79 if info["skip_overriding"]:
80 continue # Skip if overriding for this property is disabled
82 # Ensure the override is a valid value
83 valid_values = info["valid_values"]
84 assert (
85 override in valid_values # type: ignore
86 ), f"{property} must be one of {valid_values}, but got {override}."
88 # Fetch current default and store it to restore later
89 default_location = info["default_location"]
90 default_value = get_nested_attr(self, default_location)
91 info["default_value_to_restore"] = deepcopy(default_value)
93 # Override the default value
94 locally_overriden_value = override_or_use_default_value(default_value, override)
95 set_nested_attr(self, default_location, locally_overriden_value)
97 def __exit__(self, exc_type, exc_val, exc_tb):
98 """
99 Restore default values upon exiting the context.
100 """
101 for property in self.overrides:
102 info = self.values_with_defaults[property]
103 if info["skip_overriding"]:
104 continue
106 # Restore the default value from before the context was entered
107 default_location = info["default_location"]
108 default_value = info["default_value_to_restore"]
109 set_nested_attr(self, default_location, default_value)