Coverage for transformer_lens/components/attention.py: 75%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer Attention Component. 

2 

3This module contains all the component :class:`Attention`. 

4""" 

5 

6from typing import Dict, Optional, Union 

7 

8import torch 

9import torch.nn as nn 

10from transformers.utils import is_bitsandbytes_available 

11 

12from transformer_lens.components import AbstractAttention 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15if is_bitsandbytes_available(): 15 ↛ 16line 15 didn't jump to line 16 because the condition on line 15 was never true

16 from bitsandbytes.nn.modules import Params4bit 

17 

18 

# Attention

class Attention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Standard attention block.

        Weight matrices have shape [head_index, d_model, d_head] (W_O is
        [head_index, d_head, d_model]) and multiply on the right. attn_scores
        refers to the query-key dot product immediately before the attention
        softmax.

        Convention: all attention-pattern-style matrices have shape
        [batch, head_index, query_pos, key_pos].

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config.
            attn_type (str, optional): "global" or "local", used by GPT-Neo.
                Local attention means the model can only attend back
                cfg.window_size tokens (here, 256). Not used by any other
                model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by
                the Mistral models (labelled here as stanford-gpt2) to scale
                down attention scores pre-softmax for numerical stability
                reasons by 1/(layer_id+1). Defaults to None.
        """
        super().__init__(cfg, attn_type, layer_id)
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            # bitsandbytes packs two 4-bit values into each uint8 byte, hence
            # the flattened (d_model * d_model / 2, 1) storage shape.
            packed_numel = int((self.cfg.d_model * self.cfg.d_model) / 2)
            self.W_K = Params4bit(
                torch.empty(packed_numel, 1, dtype=torch.uint8), requires_grad=False
            )
            self.W_V = Params4bit(
                torch.empty(packed_numel, 1, dtype=torch.uint8), requires_grad=False
            )
        else:
            # Per-head projection weights: [head_index, d_model, d_head].
            weight_shape = (self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head)
            self.W_K = nn.Parameter(torch.empty(weight_shape, dtype=self.cfg.dtype))
            self.W_V = nn.Parameter(torch.empty(weight_shape, dtype=self.cfg.dtype))

        # Biases are always kept in full precision: [head_index, d_head].
        bias_shape = (self.cfg.n_heads, self.cfg.d_head)
        self.b_K = nn.Parameter(torch.zeros(bias_shape, dtype=self.cfg.dtype))
        self.b_V = nn.Parameter(torch.zeros(bias_shape, dtype=self.cfg.dtype))