Coverage for transformer_lens/model_bridge/component_setup.py: 86%
101 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Component setup utilities for creating and configuring bridged components."""

# NOTE: the module docstring must precede the __future__ import; placed after it,
# the string is just an expression statement and ``__doc__`` stays None.
from __future__ import annotations

import copy
import logging
from typing import TYPE_CHECKING, Any, cast

import torch.nn as nn

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.symbolic import SymbolicBridge
from transformer_lens.model_bridge.types import RemoteModel

if TYPE_CHECKING:
    # Placeholder for type-checking-only imports; nothing is imported here yet.
    pass

logger = logging.getLogger(__name__)
def replace_remote_component(
    replacement_component: nn.Module, remote_path: str, remote_model: RemoteModel
) -> None:
    """Install ``replacement_component`` at a dotted attribute path on a remote model.

    Every segment of ``remote_path`` except the last is resolved with ``getattr``,
    starting from ``remote_model``. The last segment must already exist on the
    resolved parent, and is then overwritten with ``setattr``.

    Args:
        replacement_component: The module to install in place of the existing one.
        remote_path: Dot-separated attribute path inside ``remote_model``.
        remote_model: The remote model, modified in place.

    Raises:
        ValueError: If an intermediate segment cannot be resolved, or if the final
            attribute does not already exist on its parent.
    """
    *parent_segments, leaf_name = remote_path.split(".")
    not_found = object()  # sentinel: distinguishes "absent" from any real value

    parent = remote_model
    for segment in parent_segments:
        child = getattr(parent, segment, not_found)
        if child is not_found:
            raise ValueError(f"Path {remote_path} not found in model")
        parent = child

    if not hasattr(parent, leaf_name):
        raise ValueError(f"Attribute {leaf_name} not found in {parent}")
    setattr(parent, leaf_name, replacement_component)
def set_original_components(
    bridge_module: nn.Module, architecture_adapter: ArchitectureAdapter, original_model: RemoteModel
) -> None:
    """Attach the original (remote) components to the bridge's pre-created components.

    Thin entry point: it asks the adapter for its component mapping and delegates to
    ``setup_components``, which wires each mapped component into ``bridge_module``.

    Args:
        bridge_module: The bridge module to configure (mutated in place).
        architecture_adapter: Adapter that supplies the component mapping.
        original_model: The original model that components are resolved from.
    """
    setup_components(
        architecture_adapter.get_component_mapping(),
        bridge_module,
        architecture_adapter,
        original_model,
    )
def setup_submodules(
    component: GeneralizedComponent,
    architecture_adapter: ArchitectureAdapter,
    original_model: RemoteModel,
) -> None:
    """Set up the submodules of a bridge component, recursing into each one.

    Each entry of ``component.submodules`` is handled according to its kind:

    * list items (``is_list_item``) are expanded by ``setup_blocks_bridge`` into an
      ``nn.ModuleList`` of per-item bridges, registered on ``component``;
    * ``SymbolicBridge`` entries have no real component of their own, so their
      children are resolved against this call's ``original_model`` and their
      ``real_components`` are re-exported under ``"<module_name>.<child>"`` keys;
    * every other entry gets its original component resolved (unless one is already
      set), is set up recursively, and replaces its counterpart in ``original_model``.

    Submodules flagged ``optional`` whose remote path cannot be resolved are skipped
    and removed from ``component.submodules`` once the loop finishes.

    Args:
        component: The bridge component whose submodules are set up (mutated in place).
        architecture_adapter: The architecture adapter used for remote lookups.
        original_model: The original model that submodule paths are resolved against.

    Raises:
        ValueError: If a list-item submodule has no ``name``.
        AttributeError: Propagated from ``architecture_adapter.get_remote_component``
            when a non-optional submodule's path cannot be resolved.
    """
    skipped_optional: list[str] = []
    for module_name, submodule in component.submodules.items():
        if submodule.is_list_item:
            if submodule.name is None:
                raise ValueError(f"List item component {module_name} must have a name")
            # setup_blocks_bridge already installs the list on the remote model at
            # submodule.name, so no separate replace_remote_component call is needed.
            bridged_list = setup_blocks_bridge(submodule, architecture_adapter, original_model)
            component.add_module(module_name, bridged_list)
            component.real_components[module_name] = (submodule.name, list(bridged_list))
            continue

        if isinstance(submodule, SymbolicBridge):
            # No real component: set up the children via the parent's model.
            setup_submodules(submodule, architecture_adapter, original_model)
            # Register it for structural access (e.g. blocks[i].mlp.in).
            if module_name not in component._modules:
                component.add_module(module_name, submodule)
            # Re-export the symbolic bridge's real components under prefixed keys.
            for sub_name, (sub_path, sub_comp) in submodule.real_components.items():
                component.real_components[f"{module_name}.{sub_name}"] = (sub_path, sub_comp)
            continue

        if submodule.original_component is None:
            if submodule.name is None:
                original_subcomponent = original_model
            else:
                remote_path = submodule.name
                is_optional = getattr(submodule, "optional", False)
                # Fast path: an optional submodule whose first path segment is absent
                # (or None) on the remote model is skipped without a full lookup.
                first_segment = remote_path.split(".")[0]
                if is_optional and getattr(original_model, first_segment, None) is None:
                    logger.debug(
                        "Optional '%s' (path '%s') absent on %s",
                        module_name,
                        remote_path,
                        getattr(component, "name", "?"),
                    )
                    skipped_optional.append(module_name)
                    continue
                # Full resolution catches deeper path failures
                # (e.g. a stub self_attn that is missing q_proj).
                try:
                    original_subcomponent = architecture_adapter.get_remote_component(
                        original_model, remote_path
                    )
                except AttributeError:
                    if not is_optional:
                        raise
                    logger.debug(
                        "Optional '%s' (path '%s') partially absent on %s",
                        module_name,
                        remote_path,
                        getattr(component, "name", "?"),
                    )
                    skipped_optional.append(module_name)
                    continue
            submodule.set_original_component(original_subcomponent)
            setup_submodules(submodule, architecture_adapter, original_subcomponent)
            if submodule.name is not None:
                replace_remote_component(submodule, submodule.name, original_model)

        if module_name not in component._modules:
            component.add_module(module_name, submodule)
        # List items were handled above, so only the name needs checking here.
        if submodule.name is not None:
            component.real_components[module_name] = (submodule.name, submodule)

    # Drop skipped entries so later architecture_adapter traversal won't find them.
    # NOTE(review): entries are removed from `submodules` only; if a skipped submodule
    # could already be registered in `component._modules`, it would remain there.
    for name in skipped_optional:
        component.submodules.pop(name, None)
def setup_components(
    components: dict[str, Any],
    bridge_module: nn.Module,
    architecture_adapter: ArchitectureAdapter,
    original_model: RemoteModel,
) -> None:
    """Set up top-level bridge components on the bridge module.

    For each ``(tl_path, bridge_component)`` pair, the component is registered on
    ``bridge_module`` under ``tl_path`` and installed in ``original_model`` at
    ``bridge_component.name``:

    * list items are expanded by ``setup_blocks_bridge`` into an ``nn.ModuleList``
      (which that helper also installs on the remote model);
    * other components get their original component resolved, their submodules set
      up, and then replace their counterpart in the remote model.

    If ``bridge_module`` has a ``real_components`` mapping, ``tl_path`` is recorded
    in it alongside the remote path and the bridged component(s).

    Args:
        components: Mapping of bridge-side attribute name to bridge component.
        bridge_module: The bridge module to configure (mutated in place).
        architecture_adapter: The architecture adapter used for remote lookups.
        original_model: The original model to get components from.

    Raises:
        ValueError: If a list-item component has no ``name`` (remote path), matching
            the check performed by ``setup_submodules``.
    """
    for tl_path, bridge_component in components.items():
        remote_path = bridge_component.name
        real_entry: tuple[Any, Any]
        if bridge_component.is_list_item:
            if remote_path is None:
                raise ValueError(f"List item component {tl_path} must have a name")
            # setup_blocks_bridge already installs the list on the remote model at
            # remote_path, so no separate replace_remote_component call is needed.
            bridged_list = setup_blocks_bridge(
                bridge_component, architecture_adapter, original_model
            )
            bridge_module.add_module(tl_path, bridged_list)
            real_entry = (remote_path, list(bridged_list))
        else:
            original_component = architecture_adapter.get_remote_component(
                original_model, remote_path
            )
            bridge_component.set_original_component(original_component)
            setup_submodules(bridge_component, architecture_adapter, original_component)
            bridge_module.add_module(tl_path, bridge_component)
            replace_remote_component(bridge_component, remote_path, original_model)
            real_entry = (remote_path, bridge_component)
        # Only bridge modules that expose a real_components mapping record the entry.
        if hasattr(bridge_module, "real_components"):
            bridge_module.real_components[tl_path] = real_entry  # type: ignore[index, assignment, operator]
def setup_blocks_bridge(
    blocks_template: Any, architecture_adapter: ArchitectureAdapter, original_model: RemoteModel
) -> nn.ModuleList:
    """Build one bridge per remote block and install them as an ``nn.ModuleList``.

    The remote container at ``blocks_template.name`` is fetched through the adapter
    and iterated. For every element, a deep copy of ``blocks_template`` is renamed to
    ``"<template name>.<index>"``, bound to that element, and has its submodules set
    up. The finished list then replaces the container inside ``original_model``.

    Args:
        blocks_template: Template bridge component that is copied for each block.
        architecture_adapter: The architecture adapter used for remote lookups.
        original_model: The original model to get components from.

    Returns:
        ModuleList of bridged block components, in the remote container's order.

    Raises:
        TypeError: If the remote container has no ``__iter__``.
    """
    container_path = blocks_template.name
    remote_container = architecture_adapter.get_remote_component(
        original_model, container_path
    )
    if not hasattr(remote_container, "__iter__"):
        raise TypeError(f"Component {container_path} is not iterable")

    block_bridges = nn.ModuleList()
    for index, remote_block in enumerate(cast(Any, remote_container)):
        bridge = copy.deepcopy(blocks_template)
        bridge.name = f"{container_path}.{index}"
        bridge.set_original_component(remote_block)
        setup_submodules(bridge, architecture_adapter, remote_block)
        block_bridges.append(bridge)

    replace_remote_component(block_bridges, container_path, original_model)
    return block_bridges