transformer_lens.model_bridge.generalized_components.linear module
Linear bridge component for wrapping linear layers with hook points.
- class transformer_lens.model_bridge.generalized_components.linear.LinearBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)
Bases: GeneralizedComponent
Bridge component for linear layers.
This component wraps a linear layer (nn.Linear) and provides hook points for intercepting the input and output activations.
Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead.
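The bridge exposes hook points on the input and output of the wrapped layer. For comparison, plain PyTorch can intercept the same two activations with its built-in module hooks. The sketch below uses only standard `torch` APIs (`register_forward_pre_hook`, `register_forward_hook`). It illustrates the idea and does not show how LinearBridge is implemented; the `captured` dictionary and hook function names are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(4, 3)  # nn.Linear stores its weight as [out=3, in=4]

captured = {}  # hypothetical store for the intercepted activations


def save_input(module, args):
    # Forward pre-hook: `args` is the tuple of positional inputs to forward().
    captured["input"] = args[0].detach()


def save_output(module, args, output):
    # Forward hook: runs after forward() and sees the transformed output.
    captured["output"] = output.detach()


h_in = linear.register_forward_pre_hook(save_input)
h_out = linear.register_forward_hook(save_output)

x = torch.randn(2, 4)
y = linear(x)

# Remove the hooks so later calls are not intercepted.
h_in.remove()
h_out.remove()

print(captured["input"].shape, captured["output"].shape)
```

Module hooks attach to an existing layer from outside. A bridge-style wrapper instead gives the input and output their own named hook points.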
- forward(input: Tensor, *args: Any, **kwargs: Any) → Tensor
Forward pass through the linear layer with hooks.
- Parameters:
input – Input tensor
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns:
Output tensor after linear transformation
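The forward pass can be pictured as hook on the input, then the original linear layer, then hook on the output. Below is a minimal sketch of that pattern, assuming identity modules serve as hook points. The class `HookedLinearSketch` and the attribute names `hook_in`, `hook_out`, and `original_component` are illustrative assumptions, not LinearBridge's actual code. Because `nn.Linear.forward` takes a single tensor, the sketch accepts and ignores the extra arguments; what the real bridge does with them is not stated here.

```python
from typing import Any

import torch
import torch.nn as nn


class HookedLinearSketch(nn.Module):
    """Hypothetical stand-in for a linear bridge: hook_in -> linear -> hook_out."""

    def __init__(self, original_component: nn.Linear) -> None:
        super().__init__()
        self.original_component = original_component
        # Identity modules act as named hook points: forward hooks registered
        # on them can read the activation or replace it by returning a value.
        self.hook_in = nn.Identity()
        self.hook_out = nn.Identity()

    def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        # This sketch ignores *args/**kwargs; nn.Linear only takes one tensor.
        x = self.hook_in(input)
        y = self.original_component(x)
        return self.hook_out(y)


torch.manual_seed(0)
bridge = HookedLinearSketch(nn.Linear(4, 3))
x = torch.randn(2, 4)

# A forward hook that returns a tensor replaces the hook point's output.
handle = bridge.hook_out.register_forward_hook(lambda mod, args, out: out * 0.0)
zeroed = bridge(x)
handle.remove()

# Without hooks the wrapper matches the wrapped layer exactly.
plain = bridge(x)
print(zeroed.abs().sum().item(), torch.allclose(plain, bridge.original_component(x)))
```

With no hooks attached, the identity hook points leave the computation unchanged. That is why the unhooked output equals a direct call to the wrapped layer.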
- set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) → None
Set the processed weights by loading them into the original component.
This loads the processed weights directly into the original_component’s parameters, so when forward() delegates to original_component, it uses the processed weights.
The target is an nn.Linear layer, whose weight has shape [out, in]. 3D weights of shape [n_heads, d_model, d_head] are also accepted and are flattened to 2D first.
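The docstring says the processed weights are loaded into the original component's existing parameters, so later forward calls that delegate to it use the new values. One standard PyTorch way to do this is an in-place `copy_` under `torch.no_grad()`. The sketch below assumes that approach and does not claim it is what the library does internally. The `processed` dictionary is hypothetical and already uses nn.Linear's [out, in] layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(4, 3)      # stands in for the wrapped original_component
weight_param = linear.weight  # keep a reference to the existing Parameter

# Hypothetical processed weights, already in nn.Linear's [out, in] layout.
processed = {"weight": torch.ones(3, 4), "bias": torch.zeros(3)}

# Copy in place under no_grad: the Parameter objects are reused, so any code
# that delegates to `linear` now computes with the processed values.
with torch.no_grad():
    linear.weight.copy_(processed["weight"])
    if processed["bias"] is not None:
        linear.bias.copy_(processed["bias"])

x = torch.arange(4.0).unsqueeze(0)    # [[0., 1., 2., 3.]]
print(linear(x))                      # every output is 0 + 1 + 2 + 3 = 6
print(linear.weight is weight_param)  # True: same Parameter, new values
```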
- Parameters:
weights – Dictionary containing:
  - weight: The processed weight tensor. Can be:
    - 2D [in, out] format (will be transposed to [out, in] for Linear)
    - 3D [n_heads, d_model, d_head] format (will be flattened to 2D)
  - bias: The processed bias tensor (optional). Can be:
    - 1D [out] format
    - 2D [n_heads, d_head] format (will be flattened to 1D)
verbose – If True, print detailed information about weight setting
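The parameter list describes two shape conversions: 3D per-head weights flattened to 2D, then [in, out] transposed to [out, in], plus per-head biases flattened to 1D. Below is a minimal sketch of both. It assumes head-major flattening, where the head index is outermost, so output row `h * d_head + j` holds head `h`, column `j`. The exact flattening order used by the library is not stated here, and the helper names `to_linear_weight` and `to_linear_bias` are hypothetical.

```python
from typing import Optional

import torch


def to_linear_weight(w: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert a processed weight to nn.Linear's [out, in] layout."""
    if w.ndim == 3:
        # [n_heads, d_model, d_head] -> [d_model, n_heads * d_head] (assumed head-major order)
        n_heads, d_model, d_head = w.shape
        w = w.permute(1, 0, 2).reshape(d_model, n_heads * d_head)
    if w.ndim != 2:
        raise ValueError(f"expected a 2D or 3D weight, got shape {tuple(w.shape)}")
    # [in, out] -> [out, in]
    return w.T.contiguous()


def to_linear_bias(b: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Hypothetical helper: flatten a [n_heads, d_head] bias to [n_heads * d_head]."""
    if b is None:
        return None
    return b.reshape(-1)  # a 1D [out] bias passes through unchanged


torch.manual_seed(0)
n_heads, d_model, d_head = 2, 4, 3
W = torch.randn(n_heads, d_model, d_head)
b = torch.randn(n_heads, d_head)

W_lin = to_linear_weight(W)
b_lin = to_linear_bias(b)
print(W_lin.shape, b_lin.shape)  # torch.Size([6, 4]) torch.Size([6])
```

Under this ordering, applying `torch.nn.functional.linear(x, W_lin, b_lin)` gives the same numbers as computing each head separately with `torch.einsum("bm,hmd->bhd", x, W) + b` and then flattening the head and `d_head` axes together.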