transformer_lens.components.t5_block#

class transformer_lens.components.t5_block.T5Block(cfg: HookedTransformerConfig, block_index: int, is_decoder: bool)#

Bases: Module

T5 block. Uses T5 LayerNorm and T5 attention instead of the usual ones, and adds a cross-attention sub-layer when is_decoder is True.
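
A minimal construction sketch (not taken from the library docs): it assumes HookedTransformerConfig exposes the relative_attention_num_buckets and relative_attention_max_distance fields used by T5's relative position bias, and all config values below are illustrative rather than real t5-small settings.

    from transformer_lens import HookedTransformerConfig
    from transformer_lens.components.t5_block import T5Block

    # Toy config; the relative_attention_* values are hypothetical.
    cfg = HookedTransformerConfig(
        n_layers=2,
        d_model=64,
        d_head=16,
        n_heads=4,
        d_mlp=256,
        n_ctx=32,
        act_fn="relu",
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
    )

    # Encoder-style block: self-attention + MLP only.
    encoder_block = T5Block(cfg, block_index=0, is_decoder=False)
    # Decoder-style block: adds a cross-attention sub-layer.
    decoder_block = T5Block(cfg, block_index=0, is_decoder=True)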

forward(resid_pre: Float[Tensor, 'batch pos d_model'], additive_attention_mask: Optional[Float[Tensor, 'batch 1 1 pos']] = None, encoder_additive_attention_mask: Optional[Float[Tensor, 'batch 1 1 encoder_pos']] = None, position_bias: Optional[Float[Tensor, '1 head_index pos kv_pos']] = None, encoder_hidden_states: Optional[Float[Tensor, 'batch encoder_pos d_model']] = None, past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None) Float[Tensor, 'batch pos d_model']#

A single Transformer block.

Parameters:
  • resid_pre (torch.Tensor) – The input residual stream - shape [batch, pos, d_model]

  • additive_attention_mask (torch.Tensor, optional) – Additive attention mask for padded tokens in the self-attention input - shape [batch, 1, 1, pos]. Defaults to None.

  • encoder_additive_attention_mask (torch.Tensor, optional) – Additive attention mask for padded tokens in the encoder sequence, applied in cross-attention - shape [batch, 1, 1, encoder_pos]. Defaults to None.

  • position_bias (torch.Tensor, optional) – Relative position bias added to the attention scores - shape [1, head_index, pos, kv_pos]. Defaults to None.

  • encoder_hidden_states (torch.Tensor, optional) – The hidden states of the encoder, used for cross-attention - shape [batch, encoder_pos, d_model]. Defaults to None.

  • past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional) – A cache of previous keys and values, used only when generating text. Defaults to None.

Returns:

The residual stream after the block - shape [batch, pos, d_model]

Return type:

torch.Tensor
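
A hedged sketch of a forward pass through a decoder block, reusing the toy config from the construction sketch above; the encoder hidden states are random placeholders standing in for real encoder output, and the masks, position bias, and KV cache are left at their None defaults.

    import torch
    from transformer_lens import HookedTransformerConfig
    from transformer_lens.components.t5_block import T5Block

    # Same illustrative config as above; values are assumptions, not t5-small's.
    cfg = HookedTransformerConfig(
        n_layers=2, d_model=64, d_head=16, n_heads=4, d_mlp=256, n_ctx=32,
        act_fn="relu",
        relative_attention_num_buckets=32, relative_attention_max_distance=128,
    )
    block = T5Block(cfg, block_index=0, is_decoder=True)

    batch, pos, encoder_pos = 2, 5, 7
    resid_pre = torch.randn(batch, pos, cfg.d_model)
    # Placeholder for the encoder output that cross-attention attends to.
    encoder_hidden_states = torch.randn(batch, encoder_pos, cfg.d_model)

    resid_post = block(resid_pre, encoder_hidden_states=encoder_hidden_states)
    print(resid_post.shape)  # torch.Size([2, 5, 64])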