Coverage for transformer_lens/components/attention.py: 75%

20 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Hooked Transformer Attention Component. 

2 

3This module contains the :class:`Attention` component. 

4""" 

5from typing import Dict, Optional, Union 

6 

7import torch 

8import torch.nn as nn 

9from transformers.utils import is_bitsandbytes_available 

10 

11from transformer_lens.components import AbstractAttention 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13 

14if is_bitsandbytes_available(): 14 ↛ 15line 14 didn't jump to line 15, because the condition on line 14 was never true

15 from bitsandbytes.nn.modules import Params4bit 

16 

17 

18# Attention 

class Attention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Create the key/value weights and biases of an attention block.

        Weight params have shape [head_index, d_model, d_head] (or
        [head_index, d_head, d_model] for W_O) and multiply on the right.
        attn_scores refers to the query-key dot product immediately before the
        attention softmax.

        Convention: all attention pattern-style matrices have shape
        [batch, head_index, query_pos, key_pos].

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config.
            attn_type (str, optional): "global" or "local", used by GPT-Neo.
                Local attention means the model can only attend back
                cfg.window_size tokens (here, 256). Not used by any other model
                at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by
                the Mistral models (labelled here as stanford-gpt2) to scale
                down attention scores pre softmax for numerical stability
                reasons by 1/(layer_id+1). Defaults to None.
        """
        super().__init__(cfg, attn_type, layer_id)
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        config = self.cfg

        if not config.load_in_4bit:
            # Regular path: one [d_model, d_head] projection per head.
            kv_weight_shape = (config.n_heads, config.d_model, config.d_head)
            self.W_K = nn.Parameter(torch.empty(kv_weight_shape, dtype=config.dtype))
            self.W_V = nn.Parameter(torch.empty(kv_weight_shape, dtype=config.dtype))
        else:
            # 4-bit quantization convention: weights live in a single uint8
            # column holding d_model * d_model / 2 entries.
            # NOTE(review): presumably two 4-bit values per byte - confirm
            # against the bitsandbytes Params4bit layout.
            packed_len = int((config.d_model * config.d_model) / 2)
            self.W_K = Params4bit(
                torch.empty(packed_len, 1, dtype=torch.uint8), requires_grad=False
            )
            self.W_V = Params4bit(
                torch.empty(packed_len, 1, dtype=torch.uint8), requires_grad=False
            )

        # Biases are created zero-filled in both the quantized and regular path.
        bias_shape = (config.n_heads, config.d_head)
        self.b_K = nn.Parameter(torch.zeros(bias_shape, dtype=config.dtype))
        self.b_V = nn.Parameter(torch.zeros(bias_shape, dtype=config.dtype))