Coverage for transformer_lens/utilities/bridge_components.py: 84%

44 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Utilities for traversing and applying functions to every component in a TransformerBridge model.""" 

2 

3from typing import Any, Callable, cast 

4 

5import torch.nn as nn 

6 

7from transformer_lens.model_bridge.bridge import TransformerBridge 

8from transformer_lens.model_bridge.generalized_components.base import ( 

9 GeneralizedComponent, 

10) 

11 

12 

def collect_all_submodules_of_component(
    model: TransformerBridge,
    component: GeneralizedComponent,
    submodules: dict,
    block_prefix: str = "",
) -> dict:
    """Register every named submodule found below ``component``.

    The submodule tree is walked depth-first. Each submodule whose ``name`` is
    not ``None`` is stored under the key ``block_prefix + name``. Submodules
    flagged as list items are additionally expanded through
    ``collect_components_of_block_bridge``.

    Args:
        model: The TransformerBridge model the component belongs to
        component: The component whose ``submodules`` mapping is traversed
        submodules: Dictionary that receives the results (modified in-place)
        block_prefix: String prepended to every submodule name to form its key

    Returns:
        The ``submodules`` dictionary, mapping keys to the collected submodules
    """
    for child in component.submodules.values():
        # Nameless children (e.g., OPT's MLP container) are not registered
        # themselves, but they are still inspected below.
        if child.name is not None:
            key = block_prefix + child.name
            submodules[key] = child

        # A list item stands for a list of blocks; gather those blocks as well.
        if child.is_list_item:
            submodules = collect_components_of_block_bridge(model, child, submodules)

        # Descend into children that carry submodules of their own, keeping
        # the same prefix for the whole subtree.
        if child.submodules:
            submodules = collect_all_submodules_of_component(
                model, child, submodules, block_prefix
            )
    return submodules

43 

44 

def collect_components_of_block_bridge(
    model: TransformerBridge, component: GeneralizedComponent, components: dict
) -> dict:
    """Collect every block of a list-style component together with its submodules.

    The list of blocks is taken from ``component.original_component`` when that
    is set; otherwise it is looked up on the original model through the adapter.
    Each block found in an ``nn.ModuleList`` is stored under its own name, and
    its submodules are collected with that name as the key prefix.

    Args:
        model: The TransformerBridge model to collect components from
        component: The BlockBridge component to collect components from
        components: A dictionary to populate with components (modified in-place)

    Returns:
        Dictionary mapping component names to their respective components

    Raises:
        ValueError: If ``component`` or any block in the list has no name
    """
    if component.name is None:
        raise ValueError("Block bridge component must have a name")

    # Prefer the cached original_component (used for nested list items whose
    # names are relative); fall back to resolving the name from the model root.
    if component.original_component is not None:
        remote_module_list = component.original_component
    else:
        try:
            remote_module_list = model.adapter.get_remote_component(
                model.original_model, component.name
            )
        except AttributeError:
            # The name could not be resolved from the root (treated as a
            # relative name); leave the dictionary untouched.
            return components

    # Only a ModuleList can be iterated block by block; anything else is ignored.
    if not isinstance(remote_module_list, nn.ModuleList):
        return components

    for block in remote_module_list:
        block_component = cast("GeneralizedComponent", block)
        block_name = block_component.name
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would silently store the block under key None.
        if block_name is None:
            raise ValueError("Block bridge component must have a name")
        components[block_name] = block_component
        components = collect_all_submodules_of_component(
            model, block_component, components, block_name
        )
    return components

84 

85 

def collect_all_components(model: TransformerBridge, components: dict) -> dict:
    """Gather every component of a TransformerBridge into one dictionary.

    Each entry of the adapter's component mapping is stored under its ``name``,
    followed by all of its submodules. Entries flagged as list items are also
    expanded into their individual blocks.

    Args:
        model: The TransformerBridge model to collect components from
        components: A dictionary to populate with components (modified in-place)

    Returns:
        Dictionary mapping component names to their respective components
    """
    mapping = model.adapter.get_component_mapping()

    for top_level in mapping.values():
        # Register the top-level component, then everything nested below it.
        components[top_level.name] = top_level
        components = collect_all_submodules_of_component(model, top_level, components)

        # List items represent several blocks; collect each of those blocks too.
        if top_level.is_list_item:
            components = collect_components_of_block_bridge(model, top_level, components)

    return components

105 

106 

def apply_fn_to_all_components(
    model: TransformerBridge,
    fn: Callable[[GeneralizedComponent], Any],
    components: dict | None = None,
) -> dict[str, Any]:
    """Call ``fn`` on each component and gather the results by component name.

    Args:
        model: The TransformerBridge model to apply the function to
        fn: The function to apply to each component
        components: Optional dictionary of components to apply the function to;
            when None, all components of ``model`` are collected first

    Returns:
        A dictionary mapping each component's ``name`` to the value ``fn``
        returned for it. Results are keyed by ``component.name``, not by the
        keys of ``components``, so components sharing a name overwrite each
        other and the last one wins.
    """
    targets = collect_all_components(model, {}) if components is None else components

    results = {}
    for target in targets.values():
        # Run the callback before reading the name, as the result is keyed by it.
        outcome = fn(target)
        results[target.name] = outcome

    return results

128 

129 return return_values