Coverage for transformer_lens/utilities/bridge_components.py: 83%
41 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Utilities for traversing and applying functions to every component in a TransformerBridge model."""
3from typing import Any, Callable
5import torch.nn as nn
7from transformer_lens.model_bridge.bridge import TransformerBridge
8from transformer_lens.model_bridge.generalized_components.base import (
9 GeneralizedComponent,
10)
def collect_all_submodules_of_component(
    model: TransformerBridge,
    component: GeneralizedComponent,
    submodules: dict,
    block_prefix: str = "",
) -> dict:
    """Walk ``component.submodules`` recursively and register every named descendant.

    Args:
        model: The TransformerBridge that owns ``component``; passed through to
            the block-bridge collector for list-item children.
        component: The component whose descendants are gathered.
        submodules: Dictionary that receives the descendants (mutated in place
            and also returned).
        block_prefix: String prepended to each child's name when building its
            key; used when the descendants belong to a specific block.

    Returns:
        The same ``submodules`` dictionary (or whatever the nested collectors
        returned), now containing the discovered descendants.
    """
    for child in component.submodules.values():
        child_name = child.name
        # Children without a name (e.g. OPT's MLP container) are not registered
        # themselves, but their own descendants are still visited below.
        if child_name is not None:
            # NOTE(review): the key is a plain concatenation with no separator;
            # confirm block_prefix is expected to already end with one.
            submodules[block_prefix + child_name] = child

        # A list-item child stands for a list of blocks: gather every block.
        if child.is_list_item:
            submodules = collect_components_of_block_bridge(model, child, submodules)

        # Descend into the child's own submodules, keeping the same prefix.
        if child.submodules:
            submodules = collect_all_submodules_of_component(
                model, child, submodules, block_prefix
            )

    return submodules
def collect_components_of_block_bridge(
    model: TransformerBridge, component: GeneralizedComponent, components: dict
) -> dict:
    """Register every block of a list-item (BlockBridge) component and its descendants.

    Args:
        model: The TransformerBridge used to resolve the remote block list when
            the component has no cached ``original_component``.
        component: The BlockBridge component whose blocks are gathered.
        components: Dictionary that receives the blocks and their descendants
            (mutated in place and also returned).

    Returns:
        The ``components`` dictionary with each block registered under
        ``block.name`` plus that block's descendants.

    Raises:
        ValueError: If ``component.name`` is None.
    """
    list_name = component.name
    if list_name is None:
        raise ValueError("Block bridge component must have a name")

    # Nested list items (relative names) carry a cached original_component;
    # otherwise resolve the block list from the root model via the adapter.
    block_list = component.original_component
    if block_list is None:
        try:
            block_list = model.adapter.get_remote_component(model.original_model, list_name)
        except AttributeError:
            # Relative name not reachable from root; already set up during boot.
            return components

    # Only an nn.ModuleList can be iterated as a list of blocks.
    if not isinstance(block_list, nn.ModuleList):
        return components

    for block in block_list:
        components[block.name] = block
        # Descendants of each block are keyed with the block's name as prefix.
        components = collect_all_submodules_of_component(model, block, components, block.name)

    return components
def collect_all_components(model: TransformerBridge, components: dict) -> dict:
    """Gather every component of a TransformerBridge into a name-keyed dictionary.

    Starts from the adapter's component mapping, registers each top-level
    component under its name, then adds its descendants. For list-item
    components it also adds each individual block.

    Args:
        model: The TransformerBridge whose components are gathered.
        components: Dictionary that receives the components (mutated in place
            and also returned).

    Returns:
        The ``components`` dictionary mapping names to components.
    """
    component_mapping = model.adapter.get_component_mapping()

    for top_level in component_mapping.values():
        components[top_level.name] = top_level
        components = collect_all_submodules_of_component(model, top_level, components)

        # List-item components stand for several blocks; each block must be
        # visited too (e.g. so compatibility mode reaches every block).
        if top_level.is_list_item:
            components = collect_components_of_block_bridge(model, top_level, components)

    return components
def apply_fn_to_all_components(
    model: TransformerBridge,
    fn: Callable[[GeneralizedComponent], Any],
    components: dict | None = None,
) -> dict[str, Any]:
    """Call ``fn`` on each component of the TransformerBridge and collect the results.

    Args:
        model: The TransformerBridge whose components are visited.
        fn: Callable invoked once per component.
        components: Optional dictionary of components to visit. When None, the
            full set is gathered with ``collect_all_components``.

    Returns:
        A dictionary mapping each component's ``name`` to ``fn``'s return value.
    """
    targets = collect_all_components(model, {}) if components is None else components

    results = {}
    for target in targets.values():
        # fn runs before the name is read, matching a plain
        # ``results[target.name] = fn(target)`` assignment.
        outcome = fn(target)
        # NOTE(review): results are keyed by target.name, not by the collection
        # key; components sharing a name keep only the last return value.
        results[target.name] = outcome

    return results