transformer_lens.model_bridge.generalized_components.bloom_mlp module

BLOOM-specific MLP bridge component.

The BLOOM MLP requires an extra ‘residual’ argument that the standard MLPBridge doesn’t handle. This custom component passes that argument through to the original component.

class transformer_lens.model_bridge.generalized_components.bloom_mlp.BloomMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: MLPBridge

MLP bridge for BLOOM models that handles residual connections.

BLOOM MLP has a unique forward signature that requires:
  • hidden_states (first positional arg)

  • residual (keyword arg): The residual connection tensor

This bridge ensures the residual argument is properly passed through.
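A minimal sketch of why the residual has to be passed through, using plain Python lists in place of tensors. `ToyBloomMLP`, `SingleInputWrapper`, and `PassThroughWrapper` are hypothetical stand-ins, not TransformerLens or Hugging Face classes; the toy MLP only mimics the two-argument signature described above.

```python
from typing import Any, List


class ToyBloomMLP:
    """Hypothetical stand-in for a BLOOM-style MLP: needs hidden_states AND residual."""

    def forward(self, hidden_states: List[float], residual: List[float]) -> List[float]:
        # Placeholder for the real MLP computation: double each value,
        # then add the residual inside the MLP itself.
        transformed = [2.0 * h for h in hidden_states]
        return [t + r for t, r in zip(transformed, residual)]


class SingleInputWrapper:
    """Hypothetical wrapper that forwards only the first argument."""

    def __init__(self, original: ToyBloomMLP) -> None:
        self.original = original

    def forward(self, hidden_states: List[float], *args: Any, **kwargs: Any) -> List[float]:
        # The residual is dropped here, so the wrapped forward() is missing an argument.
        return self.original.forward(hidden_states)


class PassThroughWrapper:
    """Hypothetical wrapper that forwards every argument unchanged."""

    def __init__(self, original: ToyBloomMLP) -> None:
        self.original = original

    def forward(self, *args: Any, **kwargs: Any) -> List[float]:
        # Positional and keyword arguments both reach the original component.
        return self.original.forward(*args, **kwargs)


mlp = ToyBloomMLP()
hidden, residual = [1.0, 2.0], [0.5, 0.5]

positional = PassThroughWrapper(mlp).forward(hidden, residual)
keyword = PassThroughWrapper(mlp).forward(hidden, residual=residual)

try:
    SingleInputWrapper(mlp).forward(hidden, residual)
    dropped_error = None
except TypeError as exc:
    # The wrapped MLP complains that 'residual' was never supplied.
    dropped_error = exc
```

In this toy setup the pass-through wrapper works whether the residual arrives positionally or by keyword, while the single-input wrapper fails with a `TypeError`.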

__init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the BLOOM MLP bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration

  • submodules – Dictionary of submodules to register (e.g., dense_h_to_4h, dense_4h_to_h)

forward(*args: Any, **kwargs: Any) → Any

Forward pass through BLOOM MLP with hooks.

BLOOM MLP requires these arguments:
  • hidden_states (first positional arg)

  • residual (second positional arg)

Parameters:
  • *args – Input arguments (hidden_states, residual)

  • **kwargs – Additional keyword arguments (if any)

Returns:

Output tensor from BLOOM MLP
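The hook-wrapped call can be sketched roughly as follows. This is an illustration under stated assumptions, not the library's implementation: `ToyHookedMLP`, `toy_bloom_mlp`, `hook_in`, and `hook_out` are hypothetical names, and the sketch assumes the input hook applies only to hidden_states while the residual is forwarded unchanged.

```python
from typing import Any, Callable, List

Vector = List[float]


def toy_bloom_mlp(hidden_states: Vector, residual: Vector) -> Vector:
    """Hypothetical stand-in for the original BLOOM MLP module."""
    return [h * h + r for h, r in zip(hidden_states, residual)]


class ToyHookedMLP:
    """Hypothetical bridge: hook the input, call the original, hook the output."""

    def __init__(self, original: Callable[..., Vector]) -> None:
        self.original = original
        # Identity hooks by default; callers may replace them.
        self.hook_in: Callable[[Vector], Vector] = lambda x: x
        self.hook_out: Callable[[Vector], Vector] = lambda x: x

    def forward(self, *args: Any, **kwargs: Any) -> Vector:
        if not args:
            raise ValueError("hidden_states must be the first positional argument")
        # Assumption: only hidden_states (args[0]) goes through the input hook.
        hidden_states = self.hook_in(args[0])
        # Remaining positional args (e.g. residual) and all kwargs pass through untouched.
        output = self.original(hidden_states, *args[1:], **kwargs)
        return self.hook_out(output)


bridge = ToyHookedMLP(toy_bloom_mlp)
seen: List[Vector] = []


def record(x: Vector) -> Vector:
    # Output hook that records a copy of what it sees and returns it unchanged.
    seen.append(list(x))
    return x


bridge.hook_out = record
result = bridge.forward([1.0, 3.0], [0.5, 0.5])

# Zero-ablate the MLP input: only the residual should remain in the output.
bridge.hook_in = lambda x: [0.0 for _ in x]
ablated = bridge.forward([1.0, 3.0], residual=[0.5, 0.5])
```

Because the toy hooks touch only hidden_states and the final output, zeroing the input leaves the residual contribution intact.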