Coverage for transformer_lens/components/attention.py: 75% (20 statements)
« prev ^ index » next — coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Hooked Transformer Attention Component.

This module contains the :class:`Attention` component.
"""
6from typing import Dict, Optional, Union
8import torch
9import torch.nn as nn
10from transformers.utils import is_bitsandbytes_available
12from transformer_lens.components import AbstractAttention
13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
# Params4bit is only importable when bitsandbytes is installed; guard the
# import so the module still loads in environments without 4-bit support.
if is_bitsandbytes_available():
    from bitsandbytes.nn.modules import Params4bit
19# Attention
class Attention(AbstractAttention):
    def __init__(
        self,
        cfg: Union[Dict, HookedTransformerConfig],
        attn_type: str = "global",
        layer_id: Optional[int] = None,
    ):
        """Attention Block.

        Parameters have shape [head_index, d_model, d_head] (or
        [head_index, d_head, d_model] for W_O) and multiply on the right.
        attn_scores refers to the query-key dot product immediately before the
        attention softmax.

        Convention: all attention-pattern-style matrices have shape
        [batch, head_index, query_pos, key_pos].

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo.
                Local attention means the model can only attend back
                cfg.window_size tokens (here, 256). Not used by any other model
                at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by
                the Mistral models (labelled here as stanford-gpt2) to scale
                down attention scores pre-softmax for numerical stability
                reasons by 1/(layer_id+1). Defaults to None.
        """
        super().__init__(cfg, attn_type, layer_id)
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if self.cfg.load_in_4bit:
            # 4-bit quantization convention — presumably two packed values per
            # byte, hence d_model * d_model / 2 uint8 entries per matrix.
            packed_numel = int((self.cfg.d_model * self.cfg.d_model) / 2)
            for weight_name in ("W_K", "W_V"):
                setattr(
                    self,
                    weight_name,
                    Params4bit(
                        torch.empty(packed_numel, 1, dtype=torch.uint8),
                        requires_grad=False,
                    ),
                )
        else:
            # Full-precision key/value projections: [n_heads, d_model, d_head].
            for weight_name in ("W_K", "W_V"):
                setattr(
                    self,
                    weight_name,
                    nn.Parameter(
                        torch.empty(
                            self.cfg.n_heads,
                            self.cfg.d_model,
                            self.cfg.d_head,
                            dtype=self.cfg.dtype,
                        )
                    ),
                )
        # Key/value biases are zero-initialised: [n_heads, d_head].
        for bias_name in ("b_K", "b_V"):
            setattr(
                self,
                bias_name,
                nn.Parameter(
                    torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
                ),
            )