transformer_lens.components.t5_block#
- class transformer_lens.components.t5_block.T5Block(cfg: HookedTransformerConfig, block_index: int, is_decoder: bool)#
Bases:
Module
A single T5 transformer block. Uses T5LayerNorm and T5Attention instead of the standard implementations, and adds cross-attention when is_decoder is True.
- forward(resid_pre: Float[Tensor, 'batch pos d_model'], additive_attention_mask: Optional[Float[Tensor, 'batch 1 1 pos']] = None, encoder_additive_attention_mask: Optional[Float[Tensor, 'batch 1 1 encoder_pos']] = None, position_bias: Optional[Float[Tensor, '1 head_index pos kv_pos']] = None, encoder_hidden_states: Optional[Float[Tensor, 'batch encoder_pos d_model']] = None, past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None) Float[Tensor, 'batch pos d_model'] #
A single Transformer block.
- Parameters:
resid_pre (torch.Tensor) – The input residual stream, shape [batch, pos, d_model]
additive_attention_mask (torch.Tensor, optional) – Additive attention mask for padded tokens in the block's own sequence, shape [batch, 1, 1, pos]. Defaults to None.
encoder_additive_attention_mask (torch.Tensor, optional) – Additive attention mask for padded tokens in the encoder sequence, used by cross-attention, shape [batch, 1, 1, encoder_pos]. Defaults to None.
position_bias (torch.Tensor, optional) – Relative position bias added to the attention scores, shape [1, head_index, pos, kv_pos]. Defaults to None.
encoder_hidden_states (torch.Tensor, optional) – The hidden states of the encoder, used for cross-attention, shape [batch, encoder_pos, d_model]. Defaults to None.
past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional) – A cache of previous keys and values, used only when generating text. Defaults to None.
- Returns:
The residual stream after applying the block, shape [batch, pos, d_model]
- Return type:
torch.Tensor
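Example (a minimal sketch, not a verbatim API guarantee): the snippet below assumes T5 weights are loaded through HookedEncoderDecoder and that its decoder blocks are exposed as a ModuleList at `model.decoder`; both assumptions may differ between TransformerLens versions.

```python
# Sketch: calling T5Block.forward on a decoder block with dummy tensors.
# Assumptions (not confirmed by this page): HookedEncoderDecoder loads T5 weights
# and stores its decoder blocks in a ModuleList at model.decoder.
import torch
from transformer_lens import HookedEncoderDecoder

model = HookedEncoderDecoder.from_pretrained("t5-small")
cfg = model.cfg

batch, pos, encoder_pos = 1, 4, 6

# Dummy residual stream entering the block, and dummy encoder output for cross-attention.
resid_pre = torch.randn(batch, pos, cfg.d_model)
encoder_hidden_states = torch.randn(batch, encoder_pos, cfg.d_model)

decoder_block = model.decoder[0]  # assumed layout of the decoder stack
resid_post = decoder_block(
    resid_pre,
    encoder_hidden_states=encoder_hidden_states,
)
print(resid_post.shape)  # expected: torch.Size([1, 4, cfg.d_model])
```

When generating text, a HookedTransformerKeyValueCacheEntry can be passed as past_kv_cache_entry so that keys and values from earlier positions are reused rather than recomputed.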