transformer_lens.components.transformer_block#

Hooked Transformer Transformer Block Component.

This module contains the TransformerBlock component.

class transformer_lens.components.transformer_block.TransformerBlock(cfg: Union[Dict, HookedTransformerConfig], block_index)#

Bases: Module

apply_mlp(normalized_resid: Float[Tensor, 'batch pos d_model']) Float[Tensor, 'batch pos d_model']#

Centralized point where the MLP is applied in the forward pass.

Returns:

The output of the MLP applied to the normalized residual stream.

Return type:

Float[torch.Tensor, “batch pos d_model”]
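
A minimal sketch (not part of the official docs) of constructing a block from a small HookedTransformerConfig and calling apply_mlp directly; all hyperparameters below are illustrative assumptions, not library defaults:

    import torch
    from transformer_lens import HookedTransformerConfig
    from transformer_lens.components.transformer_block import TransformerBlock

    # Illustrative toy config; the values here are assumptions for the sketch.
    cfg = HookedTransformerConfig(
        n_layers=1, d_model=64, n_ctx=32, d_head=16, act_fn="gelu"
    )
    block = TransformerBlock(cfg, block_index=0)

    # apply_mlp expects an already-normalized residual stream of shape [batch, pos, d_model].
    normalized_resid = torch.randn(2, 10, cfg.d_model)
    mlp_out = block.apply_mlp(normalized_resid)
    print(mlp_out.shape)  # torch.Size([2, 10, 64])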

forward(resid_pre: Float[Tensor, 'batch pos d_model'], shortformer_pos_embed: Optional[Float[Tensor, 'batch pos d_model']] = None, past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) Float[Tensor, 'batch pos d_model']#

A single Transformer block.

Parameters:
  • resid_pre (torch.Tensor) – The residual stream input to the block, shape [batch, pos, d_model]

  • shortformer_pos_embed (torch.Tensor, optional) – Only used when positional_embedding_type == "shortformer". The positional embeddings; see HookedTransformerConfig for details. Defaults to None.

  • past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional) – Cached keys and values for this block, used only when generating text. Defaults to None.

  • attention_mask (torch.Tensor, optional) – The attention mask for padded tokens, shape [batch, offset_pos]. Defaults to None.

Returns:

The updated residual stream after the block (resid_post).

Return type:

Float[torch.Tensor, “batch pos d_model”]
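
A minimal sketch (not part of the official docs) of running a single block on the residual stream produced by the embeddings of a pretrained model; the model name and prompt are assumptions:

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")  # assumed example model
    tokens = model.to_tokens("Hello world")            # shape [batch, pos]

    # For standard positional embeddings, resid_pre of block 0 is token embed + pos embed.
    resid_pre = model.embed(tokens) + model.pos_embed(tokens)  # [batch, pos, d_model]

    # model.blocks[0] is a TransformerBlock; calling it runs forward().
    resid_post = model.blocks[0](resid_pre)
    print(resid_post.shape)  # same shape as resid_pre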

ln1: Module#
ln2: Module#
mlp: CanBeUsedAsMLP#
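
A minimal sketch (not part of the official docs) of inspecting these submodules on a block taken from a pretrained model; "gpt2" is an assumed example:

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")
    block = model.blocks[0]

    # For pre-norm models, ln1 normalizes the input to attention, ln2 the input to
    # the MLP, and mlp is the feed-forward network (anything satisfying CanBeUsedAsMLP).
    print(type(block.ln1).__name__, type(block.ln2).__name__, type(block.mlp).__name__)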