transformer_lens.components.pos_embed#

Hooked Transformer Positional Embedding Component.

This module contains the PosEmbed component.

class transformer_lens.components.pos_embed.PosEmbed(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module
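
A minimal construction sketch (not from this page; the config values below are illustrative assumptions). PosEmbed itself only reads n_ctx and d_model from the config, but HookedTransformerConfig requires the other fields as well:

import torch

from transformer_lens import HookedTransformerConfig
from transformer_lens.components import PosEmbed

# Illustrative config values (assumptions, not from this page).
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=64,
    n_ctx=32,
    d_head=16,
    d_vocab=100,
    act_fn="relu",
)

pos_embed = PosEmbed(cfg)

# Standalone construction leaves W_pos uninitialized; inside a
# HookedTransformer, the model's init_weights handles this.
torch.nn.init.normal_(pos_embed.W_pos)

print(pos_embed.W_pos.shape)  # torch.Size([32, 64]) -- one row per context position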

forward(tokens: Int[Tensor, 'batch pos'], past_kv_pos_offset: int = 0, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) → Float[Tensor, 'batch pos d_model']#

Forward pass for positional embeddings.

Parameters:
  • tokens (Int[torch.Tensor, "batch pos"]) – Input tokens.

  • past_kv_pos_offset (int, optional) – The number of tokens already stored in the past_kv_cache. Defaults to 0.

  • attention_mask (Int[torch.Tensor, "batch offset_pos"], optional) – The attention mask for padded tokens, covering the cached positions as well as the current ones. Defaults to None.

Returns:

Absolute position embeddings.

Return type:

Float[torch.Tensor, "batch pos d_model"]
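
A hedged usage sketch for the forward pass, continuing the construction sketch above: a [batch, pos] tensor of token ids yields a [batch, pos, d_model] tensor of absolute position embeddings, and past_kv_pos_offset shifts which rows of W_pos are selected (useful when generating with a KV cache):

import torch

# Token ids of shape [batch=2, pos=10]; for absolute position embeddings
# only the sequence length matters, not the token values themselves.
tokens = torch.randint(0, cfg.d_vocab, (2, 10))

out = pos_embed(tokens)
print(out.shape)  # torch.Size([2, 10, 64]) -- [batch, pos, d_model]

# With a KV cache already holding 10 positions, the single new token
# should be given the embedding for absolute position 10.
new_token = torch.randint(0, cfg.d_vocab, (2, 1))
cached = pos_embed(new_token, past_kv_pos_offset=10)
assert torch.allclose(cached[0, 0], pos_embed.W_pos[10])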