Coverage for transformer_lens/utilities/aliases.py: 87%

59 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Utilities for handling hook aliases in the bridge system.""" 

2 

3import warnings 

4from typing import Any, Dict, List, Optional, Set, Union 

5 

6 

7def resolve_alias( 

8 target_object: Any, 

9 requested_name: str, 

10 aliases: Dict[str, str] | Dict[str, Union[str, List[str]]], 

11) -> Optional[Any]: 

12 """Resolve a hook alias to the actual hook object. 

13 

14 Args: 

15 target_object: The object to get the resolved attribute from 

16 requested_name: The name being requested (potentially an alias) 

17 aliases: Dictionary mapping alias names to target names 

18 

19 Returns: 

20 The resolved hook object if alias found, None otherwise 

21 """ 

22 if requested_name in aliases: 

23 target_name = aliases[requested_name] 

24 

25 if hasattr(target_object, "disable_warnings") and target_object.disable_warnings == False: 

26 warnings.warn( 

27 f"Hook '{requested_name}' is deprecated and will be removed in a future version. " 

28 f"Use '{target_name}' instead.", 

29 FutureWarning, 

30 stacklevel=3, # Adjusted for utility function call 

31 ) 

32 

33 def _resolve_single_target(target_name: str) -> Any: 

34 """Helper function to resolve a single target name.""" 

35 target_name_split = target_name.split(".") 

36 # Resolve dotted paths; list-based aliases try each option 

37 if len(target_name_split) > 1: 

38 current_attr = target_object 

39 for i in range(len(target_name_split) - 1): 

40 if not hasattr(current_attr, target_name_split[i]): 

41 # Raise so list-based aliases can try next option 

42 raise AttributeError( 

43 f"'{type(current_attr).__name__}' object has no attribute '{target_name_split[i]}'" 

44 ) 

45 current_attr = getattr(current_attr, target_name_split[i]) 

46 

47 # Check if the final attribute exists 

48 if not hasattr(current_attr, target_name_split[-1]): 48 ↛ 49line 48 didn't jump to line 49 because the condition on line 48 was never true

49 raise AttributeError( 

50 f"'{type(current_attr).__name__}' object has no attribute '{target_name_split[-1]}'" 

51 ) 

52 next_attr = getattr(current_attr, target_name_split[-1]) 

53 return next_attr 

54 else: 

55 # Check if the target attribute exists before getting it 

56 if not hasattr(target_object, target_name): 

57 raise AttributeError( 

58 f"'{type(target_object).__name__}' object has no attribute '{target_name}'" 

59 ) 

60 # Return the target hook 

61 return getattr(target_object, target_name) 

62 

63 # if the target_name is a list, we check all elements 

64 if isinstance(target_name, list): 64 ↛ 65line 64 didn't jump to line 65 because the condition on line 64 was never true

65 for target_name_item in target_name: 

66 try: 

67 result = _resolve_single_target(target_name_item) 

68 return result 

69 except AttributeError: 

70 continue 

71 # If we get here, none of the targets in the list were found 

72 raise AttributeError( 

73 f"None of the target names {target_name} could be resolved on '{type(target_object).__name__}' object" 

74 ) 

75 else: 

76 return _resolve_single_target(target_name) 

77 return None 

78 

79 

80def _collect_aliases_from_module( 

81 module: Any, path: str, aliases: Dict[str, str], visited: Set[int] = set() 

82) -> None: 

83 """Helper function to collect all aliases from a single module. 

84 Args: 

85 module: The module to collect aliases from 

86 path: Current path prefix for building full names 

87 aliases: Dictionary to populate with aliases (modified in-place) 

88 visited: Set of already visited module IDs to prevent infinite recursion 

89 """ 

90 mod_id = id(module) 

91 if mod_id in visited: 

92 return 

93 visited.add(mod_id) 

94 

95 if hasattr(module, "hook_aliases"): 

96 for alias_name, target_name in module.hook_aliases.items(): 

97 if alias_name == "": 

98 # Empty string creates cache alias: embed -> embed.hook_out 

99 if path: # Only add if we have a meaningful path 

100 aliases[path] = f"{path}.{target_name}" 

101 else: 

102 # Named hook alias: embed.hook_embed -> embed.hook_out 

103 # Handle special case, hook_pos_embed and hook_embed should not be prefixed 

104 if path and not (alias_name == "hook_pos_embed" or alias_name == "hook_embed"): 

105 full_alias = f"{path}.{alias_name}" 

106 full_target = f"{path}.{target_name}" 

107 else: 

108 full_alias = alias_name 

109 full_target = f"{path}.{target_name}" if path else target_name 

110 

111 aliases[full_alias] = full_target 

112 

113 # Recursively collect from submodules, excluding original_model 

114 if hasattr(module, "named_children"): 

115 for child_name, child_module in module.named_children(): 

116 # Skip the original_model to avoid collecting hooks from HuggingFace model 

117 if child_name == "original_model" or child_name == "_original_component": 

118 continue 

119 

120 child_path = f"{path}.{child_name}" if path else child_name 

121 _collect_aliases_from_module(child_module, child_path, aliases, visited) 

122 

123 

def collect_aliases_recursive(module: Any, prefix: str = "") -> Dict[str, str]:
    """Gather every hook alias defined on ``module`` and on its descendants.

    Two kinds of entries end up in the result:
    - Named hook aliases: old_hook_name -> new_hook_name
    - Cache aliases: component_name -> component_name.hook_out, produced by
      empty-string keys in a component's ``hook_aliases``

    Args:
        module: Root module whose alias tree is walked.
        prefix: Path prefix prepended when building fully qualified names.

    Returns:
        A new dictionary mapping every alias name to its target name.
    """
    collected: Dict[str, str] = {}
    # A fresh visited set per top-level call keeps separate walks independent.
    seen_ids: Set[int] = set()
    _collect_aliases_from_module(module, prefix, collected, seen_ids)
    return collected