transformer_lens.model_bridge.generalized_components.bloom_mlp module
BLOOM-specific MLP bridge component.
BLOOM MLP requires a special ‘residual’ argument that standard MLPBridge doesn’t handle. This custom component passes the residual argument through to the original component.
- class transformer_lens.model_bridge.generalized_components.bloom_mlp.BloomMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)
Bases: MLPBridge
MLP bridge for BLOOM models that handles residual connections.
BLOOM MLP has a unique forward signature that requires:
  - hidden_states (first positional argument)
  - residual: the residual connection tensor (second positional argument, or the residual keyword)
This bridge ensures the residual argument is properly passed through.
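To see why the residual has to reach the original component, here is a simplified, pure-Python sketch of the computation a BLOOM-style MLP performs. It assumes the Hugging Face BloomMLP structure: dense_h_to_4h, a tanh-approximated GELU, dense_4h_to_h, and then the residual added inside the MLP itself. Biases and dropout are omitted, and matvec, gelu_tanh and bloom_style_mlp are hypothetical helper names, not TransformerLens or Transformers APIs. Because the addition happens inside the module, a bridge that dropped the residual would change the output.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: one row per output feature


def matvec(weight: Matrix, x: Vector) -> Vector:
    """Multiply a row-major weight matrix by a vector (no bias)."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def bloom_style_mlp(
    hidden_states: Vector,
    residual: Vector,
    dense_h_to_4h: Matrix,
    dense_4h_to_h: Matrix,
) -> Vector:
    """Hypothetical sketch: project up, apply GELU, project down, then add the residual."""
    intermediate = [gelu_tanh(v) for v in matvec(dense_h_to_4h, hidden_states)]
    projected = matvec(dense_4h_to_h, intermediate)
    # The residual is added inside the MLP (dropout omitted, as in eval mode),
    # which is why the caller has to supply it.
    return [p + r for p, r in zip(projected, residual)]


# Usage: hidden size 2, intermediate size 8 (4x), all-zero weights.
zeros_up = [[0.0, 0.0] for _ in range(8)]
zeros_down = [[0.0] * 8 for _ in range(2)]
print(bloom_style_mlp([1.0, -1.0], [0.25, 0.75], zeros_up, zeros_down))  # [0.25, 0.75]
```

With all-zero weights the projection contributes nothing, so the output is exactly the residual, which makes the in-module addition easy to see.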
- __init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)
Initialize the BLOOM MLP bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration
submodules – Dictionary of submodules to register (e.g., dense_h_to_4h, dense_4h_to_h)
- forward(*args: Any, **kwargs: Any) → Any
Forward pass through BLOOM MLP with hooks.
BLOOM MLP requires these arguments:
  - hidden_states (first positional argument)
  - residual (second positional argument)
- Parameters:
*args – Input arguments (hidden_states, residual)
**kwargs – Additional keyword arguments (if any)
- Returns:
Output tensor from BLOOM MLP
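The pass-through behaviour of forward can be sketched in plain Python as below. This is not the TransformerLens implementation: HookPointSketch, ResidualPassThroughBridgeSketch and toy_mlp are hypothetical stand-ins, and the hook_in/hook_out attribute names are assumptions. The sketch routes only hidden_states through an input hook, while the residual (second positional argument or residual keyword) reaches the original component unchanged.

```python
from typing import Any, Callable, List


class HookPointSketch:
    """Hypothetical hook point: applies registered functions to a value in order."""

    def __init__(self) -> None:
        self.fns: List[Callable[[Any], Any]] = []

    def __call__(self, value: Any) -> Any:
        for fn in self.fns:
            value = fn(value)
        return value


class ResidualPassThroughBridgeSketch:
    """Hypothetical bridge that hooks hidden_states and forwards everything else as-is."""

    def __init__(self, original_component: Callable[..., Any]) -> None:
        self.original_component = original_component
        self.hook_in = HookPointSketch()
        self.hook_out = HookPointSketch()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        if not args:
            raise ValueError("hidden_states must be the first positional argument")
        hidden_states = self.hook_in(args[0])
        # The residual stays in args[1:] or kwargs["residual"] and is passed through untouched.
        output = self.original_component(hidden_states, *args[1:], **kwargs)
        return self.hook_out(output)


# Usage with a toy "MLP" that just adds its two inputs element-wise.
def toy_mlp(hidden_states: List[float], residual: List[float]) -> List[float]:
    return [h + r for h, r in zip(hidden_states, residual)]


bridge = ResidualPassThroughBridgeSketch(toy_mlp)
print(bridge.forward([1.0, 2.0], [10.0, 20.0]))           # [11.0, 22.0]
print(bridge.forward([1.0, 2.0], residual=[10.0, 20.0]))  # [11.0, 22.0]
bridge.hook_out.fns.append(lambda out: [2.0 * v for v in out])
print(bridge.forward([1.0, 2.0], [10.0, 20.0]))           # [22.0, 44.0]
```

Keeping the signature as *args/**kwargs means the sketch does not need to know whether the caller passes residual positionally or by keyword; both forms reach the wrapped component unchanged.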