Coverage for transformer_lens/model_bridge/component_setup.py: 86%

101 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

"""Component setup utilities for creating and configuring bridged components."""

# NOTE: the module docstring must be the first statement for it to become
# ``__doc__``; ``from __future__`` imports may only be preceded by the
# docstring, comments and blank lines.
from __future__ import annotations

import copy
import logging
from typing import TYPE_CHECKING, Any, cast

import torch.nn as nn

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.symbolic import SymbolicBridge
from transformer_lens.model_bridge.types import RemoteModel

if TYPE_CHECKING:
    # NOTE(review): currently empty; kept as the place for annotation-only imports.
    pass

logger = logging.getLogger(__name__)

21 

22 

def replace_remote_component(
    replacement_component: nn.Module, remote_path: str, remote_model: RemoteModel
) -> None:
    """Swap the attribute at a dotted path inside a remote model for a new module.

    Every segment except the last is walked with ``getattr``; the last segment is
    then overwritten with ``setattr`` on the object reached. Numeric segments
    (e.g. ``"layers.0"``) work for ``nn.Module`` containers because their children
    are reachable as attributes.

    Args:
        replacement_component: The new component to install
        remote_path: Dotted path to the component in the remote model
        remote_model: The remote model to modify

    Raises:
        ValueError: If an intermediate segment is missing, or if the final
            attribute does not already exist on its parent.
    """
    *parent_segments, leaf = remote_path.split(".")

    parent = remote_model
    for segment in parent_segments:
        if not hasattr(parent, segment):
            raise ValueError(f"Path {remote_path} not found in model")
        parent = getattr(parent, segment)

    # Only replace attributes that already exist; never create new ones.
    if not hasattr(parent, leaf):
        raise ValueError(f"Attribute {leaf} not found in {parent}")
    setattr(parent, leaf, replacement_component)

45 

46 

def set_original_components(
    bridge_module: nn.Module, architecture_adapter: ArchitectureAdapter, original_model: RemoteModel
) -> None:
    """Attach original (remote) components to the bridge components the adapter declares.

    The adapter's component mapping is fetched and handed, together with the
    other arguments, to :func:`setup_components`, which does the actual wiring.

    Args:
        bridge_module: The bridge module to configure
        architecture_adapter: The architecture adapter supplying the component mapping
        original_model: The original model to get components from
    """
    setup_components(
        architecture_adapter.get_component_mapping(),
        bridge_module,
        architecture_adapter,
        original_model,
    )

59 

60 

def setup_submodules(
    component: GeneralizedComponent,
    architecture_adapter: ArchitectureAdapter,
    original_model: RemoteModel,
) -> None:
    """Set up submodules for a bridge component using proper component setup.

    Each entry of ``component.submodules`` is handled by kind:

    * list items (``is_list_item``) are expanded into an ``nn.ModuleList`` via
      :func:`setup_blocks_bridge`. The list is registered on ``component``,
      swapped into ``original_model`` and recorded in ``real_components`` as a
      plain list.
    * :class:`SymbolicBridge` entries have no real component of their own.
      Their children are resolved against ``original_model`` (the parent's
      model), the symbolic bridge is registered for structural access, and its
      ``real_components`` are merged into the parent's under
      ``"<module_name>.<sub_name>"`` keys.
    * Any other bridge first gets its original component resolved and attached,
      unless one is already set. It is then set up recursively, swapped into
      ``original_model`` (when it has a name) and registered on ``component``.

    Submodules flagged ``optional`` are skipped when their remote path is absent,
    cannot be fully resolved, or resolves to ``None``. Skipped names are removed
    from ``component.submodules`` afterwards.

    Args:
        component: The bridge component to set up submodules for
        architecture_adapter: The architecture adapter
        original_model: The original model to get components from

    Raises:
        ValueError: If a list-item submodule has no ``name`` (or propagated from
            :func:`replace_remote_component`).
        AttributeError: Re-raised from ``architecture_adapter.get_remote_component``
            when a non-optional submodule's path cannot be resolved.
    """
    skipped_optional: list[str] = []
    for module_name, submodule in component.submodules.items():
        if submodule.is_list_item:
            if submodule.name is None:
                raise ValueError(f"List item component {module_name} must have a name")
            bridged_list = setup_blocks_bridge(submodule, architecture_adapter, original_model)
            component.add_module(module_name, bridged_list)
            replace_remote_component(bridged_list, submodule.name, original_model)
            # Add to real_components mapping
            component.real_components[module_name] = (submodule.name, list(bridged_list))
        elif isinstance(submodule, SymbolicBridge):
            # SymbolicBridge: no real component; set up submodules via parent's model
            setup_submodules(submodule, architecture_adapter, original_model)

            # Add the symbolic bridge as a module (for structural access like blocks[i].mlp.in)
            if module_name not in component._modules:
                component.add_module(module_name, submodule)

            # Add symbolic bridge's real_components to parent's mapping with prefixed keys
            for sub_name, (sub_path, sub_comp) in submodule.real_components.items():
                component.real_components[f"{module_name}.{sub_name}"] = (sub_path, sub_comp)
        else:
            # Set up original_component if not already set
            if submodule.original_component is None:
                if submodule.name is None:
                    original_subcomponent = original_model
                else:
                    remote_path = submodule.name
                    is_optional = getattr(submodule, "optional", False)
                    # Fast path: first segment absent or None → skip
                    first_segment = remote_path.split(".")[0]
                    first_value = getattr(original_model, first_segment, None)
                    if is_optional and first_value is None:
                        logger.debug(
                            "Optional '%s' (path '%s') absent on %s",
                            module_name,
                            remote_path,
                            getattr(component, "name", "?"),
                        )
                        skipped_optional.append(module_name)
                        continue
                    # Full resolution — catches deeper path failures (e.g. stub self_attn missing q_proj)
                    try:
                        original_subcomponent = architecture_adapter.get_remote_component(
                            original_model, remote_path
                        )
                    except AttributeError:
                        if is_optional:
                            logger.debug(
                                "Optional '%s' (path '%s') partially absent on %s",
                                module_name,
                                remote_path,
                                getattr(component, "name", "?"),
                            )
                            skipped_optional.append(module_name)
                            continue
                        raise
                    # A deeper attribute that exists but is None (e.g. `self.q_norm = None`)
                    # is as absent as a None first segment. Without this check, a bridge
                    # wrapping None would replace the remote None attribute below.
                    if is_optional and original_subcomponent is None:
                        logger.debug(
                            "Optional '%s' (path '%s') absent on %s",
                            module_name,
                            remote_path,
                            getattr(component, "name", "?"),
                        )
                        skipped_optional.append(module_name)
                        continue
                submodule.set_original_component(original_subcomponent)
                setup_submodules(submodule, architecture_adapter, original_subcomponent)
                if submodule.name is not None:
                    replace_remote_component(submodule, submodule.name, original_model)

            # Add to _modules if not already present
            if module_name not in component._modules:
                component.add_module(module_name, submodule)

            # Add to real_components mapping. List items were handled by the first
            # branch, so only the name needs checking here.
            if submodule.name is not None:
                component.real_components[module_name] = (submodule.name, submodule)

    # Clean up so architecture_adapter traversal won't find stale entries
    for name in skipped_optional:
        component.submodules.pop(name, None)

147 

148 

def setup_components(
    components: dict[str, Any],
    bridge_module: nn.Module,
    architecture_adapter: ArchitectureAdapter,
    original_model: RemoteModel,
) -> None:
    """Install each top-level bridge component on ``bridge_module``.

    For every ``(tl_path, bridge_component)`` pair:

    * a list-item template is expanded into an ``nn.ModuleList`` of per-block
      bridges by :func:`setup_blocks_bridge`;
    * any other component has its original module fetched from
      ``original_model`` at ``bridge_component.name``, attached with
      ``set_original_component``, and its submodules wired by
      :func:`setup_submodules`.

    The resulting module is then registered on ``bridge_module`` under
    ``tl_path`` and swapped into ``original_model`` at the remote path. When
    ``bridge_module`` has a ``real_components`` attribute, the pair
    ``(remote_path, component)`` is also recorded there. For list items that
    entry is a plain ``list`` of the block bridges.

    Args:
        components: Dictionary of component name to bridge component mappings
        bridge_module: The bridge module to configure
        architecture_adapter: The architecture adapter
        original_model: The original model to get components from
    """
    for tl_path, bridge_component in components.items():
        remote_path = bridge_component.name
        is_list = bridge_component.is_list_item

        installed: nn.Module
        if is_list:
            installed = setup_blocks_bridge(bridge_component, architecture_adapter, original_model)
        else:
            original_component = architecture_adapter.get_remote_component(
                original_model, remote_path
            )
            bridge_component.set_original_component(original_component)
            setup_submodules(bridge_component, architecture_adapter, original_component)
            installed = bridge_component

        bridge_module.add_module(tl_path, installed)
        replace_remote_component(installed, remote_path, original_model)

        # Only bridge modules that expose a real_components mapping get an entry.
        if hasattr(bridge_module, "real_components"):
            recorded = list(installed) if is_list else installed
            bridge_module.real_components[tl_path] = (remote_path, recorded)  # type: ignore[index, assignment, operator]

185 

186 

def setup_blocks_bridge(
    blocks_template: Any, architecture_adapter: ArchitectureAdapter, original_model: RemoteModel
) -> nn.ModuleList:
    """Clone a block bridge template once per original block.

    The steps are:

    1. Resolve the remote container at ``blocks_template.name``.
    2. Deep-copy the template for each element and name the copy
       ``"<name>.<index>"``.
    3. Attach the matching original block to the copy and wire its submodules.
    4. Swap the resulting ``nn.ModuleList`` into ``original_model`` at the
       template's path.

    Args:
        blocks_template: Template bridge component for blocks
        architecture_adapter: The architecture adapter
        original_model: The original model to get components from

    Returns:
        ModuleList of bridged block components

    Raises:
        TypeError: If the resolved remote component has no ``__iter__``.
    """
    container_path = blocks_template.name
    original_blocks = architecture_adapter.get_remote_component(original_model, container_path)
    if not hasattr(original_blocks, "__iter__"):
        raise TypeError(f"Component {container_path} is not iterable")

    bridged_blocks = nn.ModuleList()
    for index, original_block in enumerate(cast(Any, original_blocks)):
        # Each block gets its own independent copy of the template.
        bridge = copy.deepcopy(blocks_template)
        bridge.name = f"{container_path}.{index}"
        bridge.set_original_component(original_block)
        setup_submodules(bridge, architecture_adapter, original_block)
        bridged_blocks.append(bridge)

    replace_remote_component(bridged_blocks, container_path, original_model)
    return bridged_blocks