Coverage for transformer_lens/model_bridge/generalized_components/mlp.py: 86%
23 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""MLP bridge component.

This module contains the bridge component for MLP layers.
"""
5from typing import Any, Dict, Optional
7import torch
9from transformer_lens.model_bridge.generalized_components.base import (
10 GeneralizedComponent,
11)
class MLPBridge(GeneralizedComponent):
    """Bridge component for MLP layers.

    This component wraps an MLP layer from a remote model and provides a consistent interface
    for accessing its weights and performing MLP operations.
    """

    # TransformerLens-style hook names resolved against submodule hooks:
    # hook_pre maps to the input projection's output hook, hook_post to the
    # output projection's input hook.
    hook_aliases = {"hook_pre": "in.hook_out", "hook_post": "out.hook_in"}
    # TransformerLens-style weight/bias names resolved against submodule parameters.
    property_aliases = {
        "W_gate": "gate.weight",
        "b_gate": "gate.bias",
        "W_in": "in.weight",
        "b_in": "in.bias",
        "W_out": "out.weight",
        "b_out": "out.bias",
    }

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        optional: bool = False,
    ):
        """Initialize the MLP bridge.

        Args:
            name: The name of the component in the model (None if no container exists)
            config: Optional configuration (unused for MLPBridge)
            submodules: Dictionary of submodules to register (e.g., gate_proj, up_proj,
                down_proj). Defaults to an empty dict.
            optional: If True, setup skips this bridge when absent (hybrid architectures).
        """
        # The previous signature used a mutable default ({}), which is shared
        # across all calls; normalize None to a fresh dict per instance instead.
        super().__init__(
            name,
            config,
            submodules={} if submodules is None else submodules,
            optional=optional,
        )

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass through the MLP bridge.

        Args:
            *args: Positional arguments for the original component; args[0] must be
                the input hidden-states tensor.
            **kwargs: Keyword arguments for the original component

        Returns:
            Output hidden states

        Raises:
            RuntimeError: If set_original_component() has not been called yet.
        """
        hidden_states = self.hook_in(args[0])
        # The input projection may be registered as "in" (a Python keyword, hence
        # getattr) or "input", depending on the wrapped architecture's naming.
        in_module = getattr(self, "in", None) or getattr(self, "input", None)
        if in_module is not None and hasattr(in_module, "hook_in"):
            hidden_states = in_module.hook_in(hidden_states)  # type: ignore[misc]
        original_component = self.original_component
        if original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        output = original_component(hidden_states, *args[1:], **kwargs)
        output = self.hook_out(output)
        # Mirror the input-side lookup for the output projection's trailing hook.
        out_module = getattr(self, "out", None)
        if out_module is not None and hasattr(out_module, "hook_out"):
            output = out_module.hook_out(output)
        return output