Coverage for transformer_lens/components/pos_embed.py: 100%
23 statements
coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Hooked Transformer POS Embed Component.
3This module contains all the component :class:`PosEmbed`.
4"""
from typing import Dict, Optional, Union

import einops
import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utils import get_offset_position_ids


# Positional Embeddings
class PosEmbed(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.W_pos = nn.Parameter(
            torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=self.cfg.dtype)
        )
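        # Note: torch.empty leaves W_pos uninitialised; the parent model is
        # expected to initialise or load these weights before use.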

    def forward(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch new_pos d_model"]:
        """
        Forward pass for positional embeddings.

        Args:
            tokens (Int[torch.Tensor, "batch pos"]): Input tokens.
            past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0.
            attention_mask (Int[torch.Tensor, "batch offset_pos"], optional): The attention mask for padded tokens.
                Defaults to None.

        Returns:
            Float[torch.Tensor, "batch new_pos d_model"]: Absolute position embeddings.
        """
        tokens_length = tokens.size(-1)

        if attention_mask is None:
            pos_embed = self.W_pos[
                past_kv_pos_offset : tokens_length + past_kv_pos_offset, :
            ]  # [pos, d_model]
            batch_pos_embed = einops.repeat(
                pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0)
            )

        else:
            # Separated from the no padding case for computational efficiency
            # (this code is a bit slower than the code above)

            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            pos_embed = self.W_pos[offset_position_ids]  # [batch, pos, d_model]

            # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice)
            padding_mask = ~attention_mask.bool()  # [batch, tokens_length]
            offset_padding_mask = padding_mask[
                :, past_kv_pos_offset : tokens_length + past_kv_pos_offset
            ].unsqueeze(-1)  # [batch, pos, 1]
            batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed)

        return batch_pos_embed.clone()
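
Below is a minimal usage sketch, not part of the measured module above. The toy config values, the manual W_pos initialisation, and the example tensors are illustrative assumptions chosen only to exercise both branches of forward; only the shape of tokens matters to PosEmbed.

# --- Usage sketch (illustrative only; not part of the covered file) ---
import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.pos_embed import PosEmbed

# Assumed toy config; any config providing n_ctx and d_model would do here.
cfg = HookedTransformerConfig(n_layers=1, d_model=8, n_ctx=16, d_head=4, attn_only=True)
pos_embed = PosEmbed(cfg)
torch.nn.init.normal_(pos_embed.W_pos)  # W_pos starts uninitialised (torch.empty)

tokens = torch.zeros(2, 5, dtype=torch.long)  # [batch=2, pos=5]

# No attention mask: rows 0..4 of W_pos, broadcast over the batch.
out = pos_embed(tokens)
print(out.shape)  # torch.Size([2, 5, 8])

# Left-padded batch: pad positions get a zero embedding and the first
# attended token in each row gets position 0.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])
out_masked = pos_embed(tokens, attention_mask=attention_mask)
print(out_masked.shape)  # torch.Size([2, 5, 8])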