Coverage for transformer_lens/utilities/bridge_components.py: 83%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Utilities for traversing and applying functions to every component in a TransformerBridge model.""" 

2 

3from typing import Any, Callable 

4 

5import torch.nn as nn 

6 

7from transformer_lens.model_bridge.bridge import TransformerBridge 

8from transformer_lens.model_bridge.generalized_components.base import ( 

9 GeneralizedComponent, 

10) 

11 

12 

def collect_all_submodules_of_component(
    model: TransformerBridge,
    component: GeneralizedComponent,
    submodules: dict,
    block_prefix: str = "",
) -> dict:
    """Walk a component's submodule tree and register every named entry.

    Args:
        model: The TransformerBridge that owns ``component``; it is only forwarded
            to the block-bridge collector and to the recursive calls
        component: The component whose ``submodules`` mapping is traversed
        submodules: Dictionary that receives the entries (modified in-place)
        block_prefix: String prepended verbatim (no separator is inserted) to each
            submodule name to build its key; needed for components that are part
            of a block bridge

    Returns:
        The populated dictionary, keyed by ``block_prefix + submodule.name``
    """
    for child in component.submodules.values():
        child_name = child.name
        # Unnamed containers (e.g., OPT's MLP container) get no entry of their
        # own, but they are still expanded and traversed below.
        if child_name is not None:
            submodules[block_prefix + child_name] = child

        # A list item stands for a whole list of blocks; collect those as well.
        if child.is_list_item:
            submodules = collect_components_of_block_bridge(model, child, submodules)

        # Nothing nested underneath: move on to the next sibling.
        if not child.submodules:
            continue

        # Descend into the nested submodules, keeping the same key prefix.
        submodules = collect_all_submodules_of_component(
            model, child, submodules, block_prefix
        )
    return submodules

43 

44 

def collect_components_of_block_bridge(
    model: TransformerBridge, component: GeneralizedComponent, components: dict
) -> dict:
    """Register every block behind a list-item component, plus each block's submodules.

    Args:
        model: The TransformerBridge model to collect components from
        component: The BlockBridge component whose blocks are collected
        components: A dictionary to populate with components (modified in-place)

    Returns:
        Dictionary mapping component names to their respective components

    Raises:
        ValueError: If ``component.name`` is None
    """
    if component.name is None:
        raise ValueError("Block bridge component must have a name")

    cached_list = component.original_component
    if cached_list is None:
        # Nothing cached: resolve the module list through the adapter by name.
        try:
            block_list = model.adapter.get_remote_component(
                model.original_model, component.name
            )
        except AttributeError:
            # Relative name not reachable from root; already set up during boot
            return components
    else:
        # Use cached original_component for nested list items (relative names)
        block_list = cached_list

    # Anything other than a ModuleList cannot be iterated block by block.
    if not isinstance(block_list, nn.ModuleList):
        return components

    for block in block_list:
        block_name = block.name
        components[block_name] = block
        # The block's own name is the key prefix for everything nested inside it.
        components = collect_all_submodules_of_component(model, block, components, block_name)
    return components

79 

80 

def collect_all_components(model: TransformerBridge, components: dict) -> dict:
    """Gather every component of a TransformerBridge into one dictionary.

    Each entry of the adapter's component mapping is stored under its own name,
    followed by its submodules and, for list items, by its blocks.

    Args:
        model: The TransformerBridge model to collect components from
        components: A dictionary to populate with components (modified in-place)

    Returns:
        Dictionary mapping component names to their respective components
    """
    top_level = model.adapter.get_component_mapping()
    for entry in top_level.values():
        components[entry.name] = entry
        components = collect_all_submodules_of_component(model, entry, components)

        if not entry.is_list_item:
            continue
        # A list item stands for a whole list of blocks; collect each block too.
        components = collect_components_of_block_bridge(model, entry, components)
    return components

100 

101 

def apply_fn_to_all_components(
    model: TransformerBridge,
    fn: Callable[[GeneralizedComponent], Any],
    components: dict | None = None,
) -> dict[str, Any]:
    """Call ``fn`` on every component and gather the results by component name.

    Args:
        model: The TransformerBridge model to apply the function to; only used
            to collect the components when ``components`` is None
        fn: The function to apply to each component
        components: Optional dictionary of components to apply the function to,
            if None, all components are collected

    Returns:
        A dictionary mapping component names to the return values of the function.
        Keys come from each component's ``name`` attribute, not from its key in
        ``components``, so components sharing a name keep only the last result.
    """
    targets = collect_all_components(model, {}) if components is None else components

    results: dict[str, Any] = {}
    for target in targets.values():
        # Run fn first and read the name afterwards, so a rename done by fn is
        # reflected in the key.
        outcome = fn(target)
        results[target.name] = outcome
    return results

123 

124 return return_values