Coverage for transformer_lens/components/pos_embed.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Hooked Transformer POS Embed Component. 

2 

This module contains the :class:`PosEmbed` component.

4""" 

5from typing import Dict, Optional, Union 

6 

7import einops 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float, Int 

11 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13from transformer_lens.utils import get_offset_position_ids 

14 

15 

16# Positional Embeddings 

class PosEmbed(nn.Module):
    """Learned absolute positional embeddings.

    Holds a ``W_pos`` parameter with one row of width ``d_model`` for each of the
    ``n_ctx`` positions. On each forward pass it gathers the rows for the
    positions of the current tokens.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Create the positional embedding matrix.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Model config, or a dict that
                ``HookedTransformerConfig.unwrap`` turns into one. ``n_ctx``,
                ``d_model`` and ``dtype`` are read from it.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # torch.empty leaves the values uninitialised. Presumably the weights are
        # initialised or loaded elsewhere (NOTE(review): confirm before relying on it).
        self.W_pos = nn.Parameter(
            torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=self.cfg.dtype)
        )

    def forward(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        Forward pass for positional embeddings.

        Args:
            tokens (Int[torch.Tensor, "batch pos"]): Input tokens. Only the shape is used.
            past_kv_pos_offset (int, optional): The number of tokens already in the
                past_kv_cache. Positions are counted from this offset. Defaults to 0.
            attention_mask (Int[torch.Tensor, "batch offset_pos"], optional): Attention
                mask for padded tokens. Non-zero entries are real tokens and zero entries
                are padding. It is sliced to columns
                ``past_kv_pos_offset : past_kv_pos_offset + pos``, so it needs at least
                that many columns. Defaults to None (no padding).

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings.
            When ``attention_mask`` is given, the rows for padding tokens are zero.

        Raises:
            ValueError: If ``attention_mask`` is None and
                ``past_kv_pos_offset + pos`` exceeds the context length ``n_ctx``.
                Slicing ``W_pos`` would otherwise silently return fewer than ``pos``
                rows.
        """
        tokens_length = tokens.size(-1)
        end_pos = past_kv_pos_offset + tokens_length

        if attention_mask is None:
            n_ctx = self.W_pos.shape[0]
            if end_pos > n_ctx:
                raise ValueError(
                    f"Positions up to {end_pos - 1} were requested "
                    f"(past_kv_pos_offset={past_kv_pos_offset}, "
                    f"tokens_length={tokens_length}), but the context length is "
                    f"n_ctx={n_ctx}."
                )
            pos_embed = self.W_pos[past_kv_pos_offset:end_pos, :]  # [pos, d_model]
            batch_pos_embed = einops.repeat(
                pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0)
            )
            # einops.repeat may return a broadcast view that shares storage with W_pos.
            # Clone so that later in-place edits of the result cannot write into the
            # parameter.
            return batch_pos_embed.clone()

        # Padded case. This is kept separate from the branch above because it is
        # slower. No length check is done here: the position ids come from the mask,
        # so (presumably, since padding does not advance them) a padded sequence may
        # be longer than n_ctx and still be valid.
        offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
        pos_embed = self.W_pos[offset_position_ids]  # [batch, pos, d_model]

        # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice).
        padding_mask = ~attention_mask.bool()  # [batch, offset_pos]
        offset_padding_mask = padding_mask[:, past_kv_pos_offset:end_pos].unsqueeze(
            -1
        )  # [batch, pos, 1]
        # Advanced indexing and torch.where both allocate new tensors, so the result
        # already has its own storage and no clone is needed.
        return torch.where(offset_padding_mask, 0, pos_embed)