transformer_lens.model_bridge.generalized_components.t5_block module

T5-specific block bridge component.

This module contains the bridge component for T5 blocks, which differ in structure from standard transformer blocks (decoder blocks have 3 sublayers rather than 2).

class transformer_lens.model_bridge.generalized_components.t5_block.T5BlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, is_decoder: bool = False)

Bases: GeneralizedComponent

Bridge component for T5 transformer blocks.

T5 has two types of blocks:

  • Encoder blocks: 2 layers (self-attention, feed-forward)

  • Decoder blocks: 3 layers (self-attention, cross-attention, feed-forward)

This bridge handles both types based on the presence of cross-attention.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, is_decoder: bool = False)

Initialize the T5 block bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration

  • submodules – Dictionary of submodules to register

  • is_decoder – Whether this is a decoder block (has cross-attention)
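The encoder/decoder distinction controlled by is_decoder can be sketched as follows. This is an illustrative stand-in, not the real bridge: the actual T5BlockBridge registers GeneralizedComponent submodules rather than plain strings, and the function name here is hypothetical.

```python
def t5_block_sublayers(is_decoder: bool) -> list[str]:
    """Return the sublayer names a T5 block contains.

    Illustrative sketch only: the real T5BlockBridge registers
    bridge submodules, not plain strings.
    """
    sublayers = ["self_attention", "feed_forward"]
    if is_decoder:
        # Decoder blocks insert cross-attention between
        # self-attention and the feed-forward layer.
        sublayers.insert(1, "cross_attention")
    return sublayers

print(t5_block_sublayers(is_decoder=False))  # encoder: 2 sublayers
print(t5_block_sublayers(is_decoder=True))   # decoder: 3 sublayers
```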

forward(*args: Any, **kwargs: Any) → Any

Forward pass through the block bridge.

Parameters:
  • *args – Input arguments

  • **kwargs – Input keyword arguments

Returns:

The output from the original component

get_expected_parameter_names(prefix: str = '') → list[str]

Get the expected TransformerLens parameter names for this block.

Parameters:

prefix – Prefix to add to parameter names (e.g., “blocks.0”)

Returns:

List of expected parameter names in TransformerLens format
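The shape of the returned names can be sketched as below. The specific parameter set (W_Q, W_K, etc.) and the helper name are assumptions for illustration; the real method derives names from the block's registered submodules.

```python
def expected_parameter_names(prefix: str = "", is_decoder: bool = False) -> list[str]:
    """Sketch of TransformerLens-style parameter names for one T5 block.

    Hypothetical name set for illustration only.
    """
    attn = ["W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V", "b_O"]
    mlp = ["W_in", "W_out", "b_in", "b_out"]
    names = [f"attn.{p}" for p in attn]
    if is_decoder:
        # Decoder blocks also carry cross-attention parameters.
        names += [f"cross_attn.{p}" for p in attn]
    names += [f"mlp.{p}" for p in mlp]
    if prefix:
        # e.g. prefix "blocks.0" yields "blocks.0.attn.W_Q"
        names = [f"{prefix}.{n}" for n in names]
    return names

print(expected_parameter_names("blocks.0")[:2])
```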

get_list_size() → int

Get the number of transformer blocks.

Returns:

Number of transformer blocks in the model

hook_aliases: Dict[str, str | List[str]] = {'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}

is_list_item: bool = True

real_components: Dict[str, tuple]
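The hook_aliases mapping means that the standard residual-stream hook names resolve to this bridge's input and output hooks. A minimal sketch of that resolution (the real lookup also handles list-valued aliases, and the helper name here is hypothetical):

```python
# The alias table from the class attribute above.
hook_aliases = {"hook_resid_post": "hook_out", "hook_resid_pre": "hook_in"}

def resolve_hook_name(name: str, aliases: dict) -> str:
    """Map a TransformerLens-style hook name to the bridge's own hook.

    Names without an alias pass through unchanged.
    """
    return aliases.get(name, name)

print(resolve_hook_name("hook_resid_pre", hook_aliases))   # -> hook_in
print(resolve_hook_name("hook_resid_post", hook_aliases))  # -> hook_out
```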
set_original_component(component: Module)

Set the original component and monkey-patch its forward method.

Parameters:

component – The original PyTorch module to wrap

training: bool