transformer_lens.components.transformer_block#
Hooked Transformer TransformerBlock Component.
This module contains the TransformerBlock component.
- class transformer_lens.components.transformer_block.TransformerBlock(cfg: Union[Dict, HookedTransformerConfig], block_index)#
Bases: Module
- apply_mlp(normalized_resid: Float[Tensor, 'batch pos d_model']) Float[Tensor, 'batch pos d_model'] #
Centralized point where the MLP is applied in the forward pass.
- Returns:
The MLP output, of shape [batch, pos, d_model]
- Return type:
Float[torch.Tensor, "batch pos d_model"]
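apply_mlp is normally invoked from forward, but it can also be called directly on an already-normalized residual stream. A minimal sketch; the config values below are illustrative assumptions, not values taken from this page:

    import torch
    from transformer_lens import HookedTransformerConfig
    from transformer_lens.components.transformer_block import TransformerBlock

    # Illustrative config: d_model=128 with d_head=32 gives 4 heads;
    # d_mlp defaults to 4 * d_model.
    cfg = HookedTransformerConfig(
        n_layers=1, d_model=128, n_ctx=64, d_head=32, act_fn="gelu"
    )
    block = TransformerBlock(cfg, block_index=0)

    # apply_mlp expects input that has already been normalized (e.g. by ln2).
    normalized_resid = torch.randn(2, 10, cfg.d_model)  # [batch, pos, d_model]
    mlp_out = block.apply_mlp(normalized_resid)
    print(mlp_out.shape)  # torch.Size([2, 10, 128])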
- forward(resid_pre: Float[Tensor, 'batch pos d_model'], shortformer_pos_embed: Optional[Float[Tensor, 'batch pos d_model']] = None, past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) Float[Tensor, 'batch pos d_model'] #
A single Transformer block.
- Parameters:
resid_pre (torch.Tensor) – The residual stream, shape [batch, pos, d_model].
shortformer_pos_embed (torch.Tensor, optional) – Only used when positional_embeddings_type == "shortformer"; the positional embeddings. See HookedTransformerConfig for details. Defaults to None.
past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional) – A cache of previous keys and values, used only when generating text. Defaults to None.
attention_mask (torch.Tensor, optional) – The attention mask for padded tokens, shape [batch, offset_pos]. Defaults to None.
- Returns:
The residual stream after the block has been applied, shape [batch, pos, d_model]
- Return type:
Float[torch.Tensor, "batch pos d_model"]
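A hedged sketch of a full forward pass through one block: only resid_pre is required when the model uses standard positional embeddings and no KV cache. The config is the same illustrative one used above, not taken from the source:

    import torch
    from transformer_lens import HookedTransformerConfig
    from transformer_lens.components.transformer_block import TransformerBlock

    cfg = HookedTransformerConfig(
        n_layers=1, d_model=128, n_ctx=64, d_head=32, act_fn="gelu"
    )
    block = TransformerBlock(cfg, block_index=0)

    resid_pre = torch.randn(2, 16, cfg.d_model)  # [batch, pos, d_model]
    resid_post = block(resid_pre)                # dispatches to forward()
    assert resid_post.shape == resid_pre.shape   # residual stream shape is preserved

During autoregressive generation, a HookedTransformerKeyValueCacheEntry can additionally be passed as past_kv_cache_entry so that keys and values from earlier positions are reused rather than recomputed.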
- ln1: Module#
- ln2: Module#
- mlp: CanBeUsedAsMLP#