transformer_lens.model_bridge.generalized_components.linear module

Linear bridge component for wrapping linear layers with hook points.

class transformer_lens.model_bridge.generalized_components.linear.LinearBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for linear layers.

This component wraps a linear layer (nn.Linear) and provides hook points for intercepting the input and output activations.

Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead.

forward(input: Tensor, *args: Any, **kwargs: Any) Tensor

Forward pass through the linear layer with hooks.

Parameters:
  • input – Input tensor

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns:

Output tensor after linear transformation
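
The idea of a forward pass with hooks on the input and output can be illustrated with a minimal sketch in plain PyTorch. This is not the LinearBridge implementation: the class HookedLinearSketch and its hook_in / hook_out attributes are hypothetical names used only for illustration, and the real hook points may be exposed differently.

```python
from typing import Callable, Dict, Optional

import torch
from torch import nn

Hook = Callable[[torch.Tensor], torch.Tensor]


class HookedLinearSketch(nn.Module):
    """Hypothetical stand-in for a linear wrapper with input/output hooks."""

    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.original_component = linear
        self.hook_in: Optional[Hook] = None   # assumed name, illustration only
        self.hook_out: Optional[Hook] = None  # assumed name, illustration only

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Let the input hook observe or replace the activation.
        if self.hook_in is not None:
            input = self.hook_in(input)
        # Delegate the actual computation to the wrapped layer.
        output = self.original_component(input)
        # Let the output hook observe or replace the result.
        if self.hook_out is not None:
            output = self.hook_out(output)
        return output


captured: Dict[str, torch.Tensor] = {}


def save_input(x: torch.Tensor) -> torch.Tensor:
    captured["input"] = x.detach().clone()  # record, pass through unchanged
    return x


def zero_output(y: torch.Tensor) -> torch.Tensor:
    return torch.zeros_like(y)  # ablate the layer's output


layer = HookedLinearSketch(nn.Linear(4, 3))
layer.hook_in = save_input
layer.hook_out = zero_output

x = torch.randn(2, 4)
y = layer(x)
```

Here the input hook only records the activation, while the output hook replaces it, which mirrors the two common uses of hook points: caching and intervention.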

set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None

Set the processed weights by loading them into the original component.

This loads the processed weights directly into the original_component’s parameters, so when forward() delegates to original_component, it uses the processed weights.

The target nn.Linear layer stores its weight in [out, in] layout. 3D weights of shape [n_heads, d_model, d_head] are also accepted and are flattened to 2D before loading.

Parameters:
  • weights – Dictionary containing:

    • weight: The processed weight tensor. Can be:
      • 2D [in, out] format (will be transposed to [out, in] for Linear)

      • 3D [n_heads, d_model, d_head] format (will be flattened to 2D)

    • bias: The processed bias tensor (optional). Can be:
      • 1D [out] format

      • 2D [n_heads, d_head] format (will be flattened to 1D)

  • verbose – If True, print detailed information about weight setting
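
The shape conversions described for weights can be sketched in plain PyTorch. This is an illustration under stated assumptions, not the method's actual code: in particular, the head-major ordering used when flattening [n_heads, d_model, d_head] to [d_model, n_heads * d_head] is an assumption, and the variable names are invented for the example.

```python
import torch
from torch import nn

n_heads, d_model, d_head = 2, 8, 3
linear = nn.Linear(d_model, n_heads * d_head)  # weight is [out, in] = [6, 8]

# Processed parameters in per-head layout.
w_3d = torch.randn(n_heads, d_model, d_head)
b_2d = torch.randn(n_heads, d_head)

# Assumed flattening: move d_model first, then merge (n_heads, d_head)
# into one output axis, giving a 2D [in, out] matrix.
w_in_out = w_3d.permute(1, 0, 2).reshape(d_model, n_heads * d_head)

# nn.Linear expects [out, in], so transpose the 2D [in, out] matrix.
w_out_in = w_in_out.T.contiguous()

# Flatten the 2D bias [n_heads, d_head] to 1D [out].
b_flat = b_2d.reshape(-1)

# Load into the existing parameters in place.
with torch.no_grad():
    linear.weight.copy_(w_out_in)
    linear.bias.copy_(b_flat)

# Check against a direct per-head computation.
x = torch.randn(4, d_model)
per_head = torch.einsum("bm,nmh->bnh", x, w_3d) + b_2d
expected = per_head.reshape(4, n_heads * d_head)
matches = torch.allclose(linear(x), expected, atol=1e-5)
```

Copying into the existing parameters, rather than replacing them, keeps any delegation to the original layer working unchanged after the weights are loaded.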