Coverage for transformer_lens/components/bert_block.py: 89% (45 statements)
coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Hooked Transformer Bert Block Component.
3This module contains all the component :class:`BertBlock`.
4"""
from typing import Optional

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components import Attention, LayerNorm
from transformer_lens.factories.mlp_factory import MLPFactory
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utils import repeat_along_head_dimension


class BertBlock(nn.Module):
    """
    BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before.
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg

        self.attn = Attention(cfg)
        self.ln1 = LayerNorm(cfg)
        self.mlp = MLPFactory.create_mlp(self.cfg)
        self.ln2 = LayerNorm(cfg)

        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]

        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]
        self.hook_normalized_resid_post = HookPoint()  # [batch, pos, d_model]
    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        resid_pre = self.hook_resid_pre(resid_pre)

        query_input = resid_pre
        key_input = resid_pre
        value_input = resid_pre

        if self.cfg.use_split_qkv_input:  # coverage: this branch was never taken (condition never true in the test run)
            n_heads = self.cfg.n_heads
            query_input = self.hook_q_input(repeat_along_head_dimension(query_input, n_heads))
            key_input = self.hook_k_input(repeat_along_head_dimension(key_input, n_heads))
            value_input = self.hook_v_input(repeat_along_head_dimension(value_input, n_heads))

        attn_out = self.hook_attn_out(
            self.attn(
                query_input,
                key_input,
                value_input,
                additive_attention_mask=additive_attention_mask,
            )
        )
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)

        # Post-LN ordering (as in BERT): LayerNorm is applied after each residual addition.
        mlp_in = resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
        normalized_resid_mid = self.ln1(mlp_in)
        mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))
        resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out)
        normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post))

        return normalized_resid_post