Coverage for transformer_lens/utilities/bridge_components.py: 84%
44 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""Utilities for traversing and applying functions to every component in a TransformerBridge model."""
3from typing import Any, Callable, cast
5import torch.nn as nn
7from transformer_lens.model_bridge.bridge import TransformerBridge
8from transformer_lens.model_bridge.generalized_components.base import (
9 GeneralizedComponent,
10)
def collect_all_submodules_of_component(
    model: TransformerBridge,
    component: GeneralizedComponent,
    submodules: dict,
    block_prefix: str = "",
) -> dict:
    """Register every named descendant of ``component``, depth-first.

    Args:
        model: The TransformerBridge model that owns ``component``; only
            forwarded to the helpers that are called from here.
        component: The component whose ``submodules`` mapping is walked.
        submodules: Dictionary that receives the results (modified in-place).
        block_prefix: String prepended to each child's name to build its key.
            Needed for components that live inside a block bridge.

    Returns:
        The populated dictionary, keyed by ``block_prefix + child.name``.

    Note:
        The three checks below are independent of each other: a child
        without a name is not registered itself, but it is still expanded
        as a list item and/or recursed into.
    """
    registry = submodules
    for child in component.submodules.values():
        child_name = child.name
        # Children without names (e.g., OPT's MLP container) get no entry of their own.
        if child_name is not None:
            registry[block_prefix + child_name] = child

        # A list item stands for a block bridge: expand its blocks as well.
        # The current block_prefix is not passed along to that helper.
        if child.is_list_item:
            registry = collect_components_of_block_bridge(model, child, registry)

        # Descend into nested submodules, keeping the same prefix.
        nested = child.submodules
        if nested:
            registry = collect_all_submodules_of_component(
                model, child, registry, block_prefix
            )
    return registry
def collect_components_of_block_bridge(
    model: TransformerBridge, component: GeneralizedComponent, components: dict
) -> dict:
    """Collect every block of a BlockBridge component, plus each block's submodules.

    Args:
        model: The TransformerBridge model to collect components from.
        component: The BlockBridge component whose blocks are collected.
        components: Dictionary to populate with components (modified in-place).

    Returns:
        Dictionary mapping component names to their respective components.
        It is returned unchanged when the remote lookup raises
        ``AttributeError`` or when the resolved object is not an
        ``nn.ModuleList``.

    Raises:
        ValueError: If ``component`` itself, or any block inside the resolved
            module list, has no name.
    """
    if component.name is None:
        raise ValueError("Block bridge component must have a name")

    # Use cached original_component for nested list items (relative names);
    # otherwise ask the adapter for the remote module list.
    remote_module_list = component.original_component
    if remote_module_list is None:
        try:
            remote_module_list = model.adapter.get_remote_component(
                model.original_model, component.name
            )
        except AttributeError:
            # Relative name not reachable from root; already set up during boot
            return components

    # Only a ModuleList can be iterated block by block.
    if not isinstance(remote_module_list, nn.ModuleList):
        return components

    for block in remote_module_list:
        block_component = cast(GeneralizedComponent, block)
        block_name = block_component.name
        # Raise instead of assert: asserts are stripped under ``python -O``,
        # which would let ``None`` become a dictionary key and a prefix.
        if block_name is None:
            raise ValueError("Block bridge component must have a name")
        components[block_name] = block_component
        # The block's own name is the key prefix for everything inside it.
        components = collect_all_submodules_of_component(
            model, block_component, components, block_name
        )
    return components
def collect_all_components(model: TransformerBridge, components: dict) -> dict:
    """Gather every component of a TransformerBridge into one dictionary.

    Each entry of the adapter's component mapping is stored under its own
    ``name``, followed by all of its submodules. If the entry is a list
    item, the blocks of its block bridge are collected too.

    Args:
        model: The TransformerBridge model to collect components from.
        components: Dictionary to populate with components (modified in-place).

    Returns:
        Dictionary mapping component names to their respective components.
    """
    registry = components
    top_level = model.adapter.get_component_mapping()
    for entry in top_level.values():
        registry[entry.name] = entry
        registry = collect_all_submodules_of_component(model, entry, registry)

        # A list item stands for several blocks; each of them must be covered too.
        if entry.is_list_item:
            registry = collect_components_of_block_bridge(model, entry, registry)
    return registry
107def apply_fn_to_all_components(
108 model: TransformerBridge,
109 fn: Callable[[GeneralizedComponent], Any],
110 components: dict | None = None,
111) -> dict[str, Any]:
112 """Applies a function to all components in the TransformerBridge model.
113 Args:
114 model: The TransformerBridge model to apply the function to
115 fn: The function to apply to each component
116 components: Optional dictionary of components to apply the function to, if None, all components are collected
117 Returns:
118 return_values: A dictionary mapping component names to the return values of the function
119 """
121 if components is None: 121 ↛ 124line 121 didn't jump to line 124 because the condition on line 121 was always true
122 components = collect_all_components(model, {})
124 return_values = {}
126 for component in components.values():
127 return_values[component.name] = fn(component)
129 return return_values