Coverage for transformer_lens/utilities/aliases.py: 87%
59 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Utilities for handling hook aliases in the bridge system."""
3import warnings
4from typing import Any, Dict, List, Optional, Set, Union
7def resolve_alias(
8 target_object: Any,
9 requested_name: str,
10 aliases: Dict[str, str] | Dict[str, Union[str, List[str]]],
11) -> Optional[Any]:
12 """Resolve a hook alias to the actual hook object.
14 Args:
15 target_object: The object to get the resolved attribute from
16 requested_name: The name being requested (potentially an alias)
17 aliases: Dictionary mapping alias names to target names
19 Returns:
20 The resolved hook object if alias found, None otherwise
21 """
22 if requested_name in aliases:
23 target_name = aliases[requested_name]
25 if hasattr(target_object, "disable_warnings") and target_object.disable_warnings == False:
26 warnings.warn(
27 f"Hook '{requested_name}' is deprecated and will be removed in a future version. "
28 f"Use '{target_name}' instead.",
29 FutureWarning,
30 stacklevel=3, # Adjusted for utility function call
31 )
33 def _resolve_single_target(target_name: str) -> Any:
34 """Helper function to resolve a single target name."""
35 target_name_split = target_name.split(".")
36 # Resolve dotted paths; list-based aliases try each option
37 if len(target_name_split) > 1:
38 current_attr = target_object
39 for i in range(len(target_name_split) - 1):
40 if not hasattr(current_attr, target_name_split[i]):
41 # Raise so list-based aliases can try next option
42 raise AttributeError(
43 f"'{type(current_attr).__name__}' object has no attribute '{target_name_split[i]}'"
44 )
45 current_attr = getattr(current_attr, target_name_split[i])
47 # Check if the final attribute exists
48 if not hasattr(current_attr, target_name_split[-1]): 48 ↛ 49line 48 didn't jump to line 49 because the condition on line 48 was never true
49 raise AttributeError(
50 f"'{type(current_attr).__name__}' object has no attribute '{target_name_split[-1]}'"
51 )
52 next_attr = getattr(current_attr, target_name_split[-1])
53 return next_attr
54 else:
55 # Check if the target attribute exists before getting it
56 if not hasattr(target_object, target_name):
57 raise AttributeError(
58 f"'{type(target_object).__name__}' object has no attribute '{target_name}'"
59 )
60 # Return the target hook
61 return getattr(target_object, target_name)
63 # if the target_name is a list, we check all elements
64 if isinstance(target_name, list): 64 ↛ 65line 64 didn't jump to line 65 because the condition on line 64 was never true
65 for target_name_item in target_name:
66 try:
67 result = _resolve_single_target(target_name_item)
68 return result
69 except AttributeError:
70 continue
71 # If we get here, none of the targets in the list were found
72 raise AttributeError(
73 f"None of the target names {target_name} could be resolved on '{type(target_object).__name__}' object"
74 )
75 else:
76 return _resolve_single_target(target_name)
77 return None
80def _collect_aliases_from_module(
81 module: Any, path: str, aliases: Dict[str, str], visited: Set[int] = set()
82) -> None:
83 """Helper function to collect all aliases from a single module.
84 Args:
85 module: The module to collect aliases from
86 path: Current path prefix for building full names
87 aliases: Dictionary to populate with aliases (modified in-place)
88 visited: Set of already visited module IDs to prevent infinite recursion
89 """
90 mod_id = id(module)
91 if mod_id in visited:
92 return
93 visited.add(mod_id)
95 if hasattr(module, "hook_aliases"):
96 for alias_name, target_name in module.hook_aliases.items():
97 if alias_name == "":
98 # Empty string creates cache alias: embed -> embed.hook_out
99 if path: # Only add if we have a meaningful path
100 aliases[path] = f"{path}.{target_name}"
101 else:
102 # Named hook alias: embed.hook_embed -> embed.hook_out
103 # Handle special case, hook_pos_embed and hook_embed should not be prefixed
104 if path and not (alias_name == "hook_pos_embed" or alias_name == "hook_embed"):
105 full_alias = f"{path}.{alias_name}"
106 full_target = f"{path}.{target_name}"
107 else:
108 full_alias = alias_name
109 full_target = f"{path}.{target_name}" if path else target_name
111 aliases[full_alias] = full_target
113 # Recursively collect from submodules, excluding original_model
114 if hasattr(module, "named_children"):
115 for child_name, child_module in module.named_children():
116 # Skip the original_model to avoid collecting hooks from HuggingFace model
117 if child_name == "original_model" or child_name == "_original_component":
118 continue
120 child_path = f"{path}.{child_name}" if path else child_name
121 _collect_aliases_from_module(child_module, child_path, aliases, visited)
def collect_aliases_recursive(module: Any, prefix: str = "") -> Dict[str, str]:
    """Gather every alias declared on ``module`` and on its children.

    Two kinds of entries end up in the result:
        - Named hook aliases: old_hook_name -> new_hook_name
        - Cache aliases: component_name -> component_name.hook_out
          (produced from empty-string alias keys)

    Args:
        module: Root module whose aliases are collected.
        prefix: Path prefix used when building full alias and target names.

    Returns:
        A new dictionary mapping each alias name to its target name.
    """
    collected: Dict[str, str] = {}
    # Hand the helper a fresh visited-set so each traversal starts clean.
    _collect_aliases_from_module(module, prefix, collected, set())
    return collected