Coverage for transformer_lens/components/pos_embed.py: 100% (23 statements)
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""Hooked Transformer POS Embed Component.
3This module contains all the component :class:`PosEmbed`.
4"""
from typing import Dict, Optional, Union

import einops
import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utils import get_offset_position_ids
# Positional Embeddings
class PosEmbed(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.W_pos = nn.Parameter(
            torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=self.cfg.dtype)
        )
    def forward(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        Forward pass for positional embeddings.

        Args:
            tokens (Int[torch.Tensor, "batch pos"]): Input tokens.
            past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0.
            attention_mask (Int[torch.Tensor, "batch offset_pos"], optional): The attention mask for padded tokens.
                Defaults to None.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings.
        """
        tokens_length = tokens.size(-1)

        if attention_mask is None:
            pos_embed = self.W_pos[
                past_kv_pos_offset : tokens_length + past_kv_pos_offset, :
            ]  # [pos, d_model]
            batch_pos_embed = einops.repeat(
                pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0)
            )
        else:
            # Separated from the no-padding case for computational efficiency
            # (this branch is a bit slower than the one above)
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            pos_embed = self.W_pos[offset_position_ids]  # [batch, pos, d_model]

            # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice)
            padding_mask = ~attention_mask.bool()  # [batch, tokens_length]
            offset_padding_mask = padding_mask[
                :, past_kv_pos_offset : tokens_length + past_kv_pos_offset
            ].unsqueeze(-1)  # [batch, pos, 1]
            batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed)

        return batch_pos_embed.clone()
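For context, here is a minimal usage sketch (not part of the covered file) showing how PosEmbed behaves with and without an attention mask. The config values are illustrative assumptions; in real use PosEmbed is built and initialised inside HookedTransformer, so the manual W_pos initialisation below is only for the sake of a standalone example.

# Minimal usage sketch. Assumptions: illustrative config values; W_pos is
# initialised manually here because it is created with torch.empty.
import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.pos_embed import PosEmbed

cfg = HookedTransformerConfig(
    n_layers=1, d_model=8, n_ctx=16, d_head=4, attn_only=True  # illustrative values
)
pos_embed = PosEmbed(cfg)
torch.nn.init.normal_(pos_embed.W_pos)  # W_pos starts uninitialised (torch.empty)

tokens = torch.zeros(2, 5, dtype=torch.long)  # [batch=2, pos=5]; token values are irrelevant here

# No padding: position i simply gets W_pos[i + past_kv_pos_offset].
out = pos_embed(tokens)
print(out.shape)  # torch.Size([2, 5, 8])

# With left padding: pad positions get a zero embedding, and real tokens are
# re-indexed so the first non-pad token in each row gets W_pos[0].
attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
out_masked = pos_embed(tokens, past_kv_pos_offset=0, attention_mask=attention_mask)
print(out_masked.shape)  # torch.Size([2, 5, 8])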