Coverage for transformer_lens/components/attention.py: 75%
20 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
"""Hooked Transformer Attention Component.

This module contains the :class:`Attention` component.
"""
5from typing import Dict, Optional, Union
7import torch
8import torch.nn as nn
9from transformers.utils import is_bitsandbytes_available
11from transformer_lens.components import AbstractAttention
12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
14if is_bitsandbytes_available(): 14 ↛ 15line 14 didn't jump to line 15, because the condition on line 14 was never true
15 from bitsandbytes.nn.modules import Params4bit
18# Attention
class Attention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax.

        Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos]

        This constructor defines the key/value parameters (W_K, W_V, b_K, b_V); the rest of
        the setup is done by ``AbstractAttention.__init__``. When ``cfg.load_in_4bit`` is set,
        W_K and W_V are allocated as ``uint8`` ``Params4bit`` placeholders of shape
        [d_model * d_model // 2, 1] instead of the per-head layout described above.

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config; normalised with
                ``HookedTransformerConfig.unwrap``.
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.

        Raises:
            ImportError: If ``cfg.load_in_4bit`` is set but ``Params4bit`` was never imported
                because bitsandbytes is unavailable.
        """
        super().__init__(cfg, attn_type, layer_id)
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            # Params4bit is only bound at module level when the guarded bitsandbytes
            # import ran; turn the otherwise-opaque NameError into an actionable error.
            try:
                params4bit_cls = Params4bit
            except NameError as err:
                raise ImportError(
                    "cfg.load_in_4bit=True requires bitsandbytes, which is not "
                    "available in this environment."
                ) from err
            # 4-bit quantization convention: d_model * d_model weights stored as uint8,
            # presumably two 4-bit values per byte (hence // 2). Integer division is
            # exact, unlike the float round-trip of int(x / 2).
            # NOTE(review): the square d_model x d_model size assumes
            # n_heads * d_head == d_model — confirm for models where that does not hold.
            nq = self.cfg.d_model * self.cfg.d_model // 2
            self.W_K = params4bit_cls(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_V = params4bit_cls(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            # Per-head projection weights. torch.empty leaves them uninitialized, so
            # their values presumably get filled later (weight init or checkpoint load).
            self.W_K = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype
                )
            )
            self.W_V = nn.Parameter(
                torch.empty(
                    self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=self.cfg.dtype
                )
            )
        # Biases are never quantized: they are full-precision zeros in both modes.
        self.b_K = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )
        self.b_V = nn.Parameter(
            torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
        )