Coverage for transformer_lens/components/pos_embed.py: 100%

23 statements  

coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer POS Embed Component. 

2 

3This module contains all the component :class:`PosEmbed`. 

4""" 

5 

6from typing import Dict, Optional, Union 

7 

8import einops 

9import torch 

10import torch.nn as nn 

11from jaxtyping import Float, Int 

12 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14from transformer_lens.utils import get_offset_position_ids 

15 

16 

17# Positional Embeddings 

18class PosEmbed(nn.Module): 

19 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

20 super().__init__() 

21 self.cfg = HookedTransformerConfig.unwrap(cfg) 

22 self.W_pos = nn.Parameter( 

23 torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=self.cfg.dtype) 

24 ) 

25 

26 def forward( 

27 self, 

28 tokens: Int[torch.Tensor, "batch pos"], 

29 past_kv_pos_offset: int = 0, 

30 attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, 

31 ) -> Float[torch.Tensor, "batch new_pos d_model"]: 

32 """ 

33 Forward pass for positional embeddings. 

34 

35 Args: 

36 tokens (Int[torch.Tensor, "batch pos"]): Input tokens. 

37 past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0. 

38 attention_mask (Int[torch.Tensor, "batch pos"], optional): The attention mask for padded tokens. 

39 Defaults to None. 

40 

41 Returns: 

42 Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings. 

43 """ 

44 tokens_length = tokens.size(-1) 

45 

46 if attention_mask is None: 

47 pos_embed = self.W_pos[ 

48 past_kv_pos_offset : tokens_length + past_kv_pos_offset, : 

49 ] # [pos, d_model] 

50 batch_pos_embed = einops.repeat( 

51 pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0) 

52 ) 

53 

54 else: 

55 # Separated from the no padding case for computational efficiency 

56 # (this code is a bit slower than the code above) 

57 

58 offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) 

59 pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model] 

60 

61 # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice) 

62 padding_mask = ~attention_mask.bool() # [batch, tokens_length] 

63 offset_padding_mask = padding_mask[ 

64 :, past_kv_pos_offset : tokens_length + past_kv_pos_offset 

65 ].unsqueeze( 

66 -1 

67 ) # [batch, pos, 1] 

68 batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed) 

69 

70 return batch_pos_embed.clone()
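
The snippet below is a minimal sketch of exercising PosEmbed on its own, covering the unpadded, offset, and padded code paths above. The config values, dummy tokens, and the explicit W_pos initialisation are illustrative assumptions, not taken from this file (in a full HookedTransformer the positional weights are set when the model is initialised or loaded).

    import torch

    from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
    from transformer_lens.components.pos_embed import PosEmbed

    # Toy config (assumed values); PosEmbed itself only uses n_ctx, d_model and dtype.
    cfg = HookedTransformerConfig(
        n_layers=1, d_model=8, n_ctx=16, d_head=4, n_heads=2, d_vocab=50, act_fn="relu"
    )
    pos_embed = PosEmbed(cfg)
    # W_pos is allocated with torch.empty, so give it values for a readable demo.
    torch.nn.init.normal_(pos_embed.W_pos)

    # Token values are never read; only the [batch, pos] shape matters.
    tokens = torch.randint(0, cfg.d_vocab, (2, 5))

    # No padding: rows 0..4 of W_pos, repeated over the batch.
    out = pos_embed(tokens)
    print(out.shape)  # torch.Size([2, 5, 8])

    # Continuing from a KV cache of length 3: rows 3..7 of W_pos are used instead.
    out_offset = pos_embed(tokens, past_kv_pos_offset=3)

    # Left-padded batch: padded slots get a zero embedding, attended slots use the
    # offset position ids computed by get_offset_position_ids.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
    out_padded = pos_embed(tokens, attention_mask=attention_mask)
    print(out_padded[0, :2].abs().sum())  # tensor(0.) for the padded slots

Zeroing the embeddings of pad tokens is, as the inline comment notes, an arbitrary convention; the padded positions are masked out of attention anyway.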