Coverage for transformer_lens/model_bridge/generalized_components/bloom_mlp.py: 38%
18 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""BLOOM-specific MLP bridge component.
3BLOOM MLP requires a special 'residual' argument that standard MLPBridge doesn't handle.
4This custom component passes the residual argument through to the original component.
5"""
6from typing import Any, Dict, Optional

import torch

from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge


class BloomMLPBridge(MLPBridge):
    """MLP bridge for BLOOM models that handles residual connections.

    BLOOM MLP has a unique forward signature that requires:
    - hidden_states (first positional arg)
    - residual (second positional arg): the residual connection tensor

    This bridge ensures the residual argument is properly passed through.
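
    Example (hypothetical usage; ``set_original_component`` is the bridge API
    referenced in forward(), and ``bloom_block`` stands in for a Hugging Face
    BloomBlock instance):
        bridge = BloomMLPBridge("mlp")
        bridge.set_original_component(bloom_block.mlp)
        output = bridge(hidden_states, residual)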
24 """

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the BLOOM MLP bridge.

        Args:
            name: The name of the component in the model
            config: Optional configuration
            submodules: Dictionary of submodules to register (e.g., dense_h_to_4h, dense_4h_to_h)
        """
        super().__init__(name, config, submodules or {})

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the BLOOM MLP with hooks.

        BLOOM MLP requires these arguments:
        - hidden_states (first positional arg)
        - residual (second positional arg)

        Args:
            *args: Input arguments (hidden_states, residual)
            **kwargs: Additional keyword arguments (if any)

        Returns:
            Output tensor from the BLOOM MLP
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Apply hook_in to hidden_states (first positional argument)
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            args = (hooked_input,) + args[1:]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])

        # BLOOM MLP requires residual as its second positional arg. The original
        # BLOOM block passes it, so we just pass everything through; no need to
        # validate since the original component will handle it.

        # Call the original BLOOM MLP component with all arguments
        output = self.original_component(*args, **kwargs)

        # Apply hook_out
        output = self.hook_out(output)

        return output
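

# Usage sketch (illustrative, not part of the module): a minimal example of
# driving this bridge, assuming the wrapped module follows Hugging Face's
# BloomMLP signature forward(hidden_states, residual). The tensor shapes and
# the ``bloom_block`` name below are hypothetical.
#
#     import torch
#
#     bridge = BloomMLPBridge("mlp")
#     bridge.set_original_component(bloom_block.mlp)
#     hidden_states = torch.randn(1, 8, 1024)
#     residual = torch.randn(1, 8, 1024)
#     out = bridge(hidden_states, residual)  # residual forwarded positionally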