Coverage for transformer_lens/components/bert_block.py: 89%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer Bert Block Component. 

2 

3This module contains the :class:`BertBlock` component. 

4""" 

5 

6from typing import Optional 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11 

12from transformer_lens.components import Attention, LayerNorm 

13from transformer_lens.factories.mlp_factory import MLPFactory 

14from transformer_lens.hook_points import HookPoint 

15from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

16from transformer_lens.utils import repeat_along_head_dimension 

17 

18 

class BertBlock(nn.Module):
    """
    A single BERT-style transformer block.

    Unlike the TransformerBlock, each LayerNorm here is applied *after* its
    sub-layer: ``ln1`` normalizes the residual stream once the attention output
    has been added, and ``ln2`` normalizes it once the MLP output has been added.
    The value returned by :meth:`forward` is therefore already normalized.
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg

        # Sub-layers. The assignment order is kept as-is because nn.Module
        # registers children in the order they are assigned.
        self.attn = Attention(cfg)
        self.ln1 = LayerNorm(cfg)
        self.mlp = MLPFactory.create_mlp(cfg)
        self.ln2 = LayerNorm(cfg)

        # Hooks on the per-head attention inputs; only invoked in forward()
        # when cfg.use_split_qkv_input is set.
        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]

        # Hooks on sub-layer outputs and on the residual stream.
        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]
        self.hook_normalized_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        Run the block on the incoming residual stream.

        Args:
            resid_pre: The residual stream entering the block.
            additive_attention_mask: Optional mask forwarded unchanged to the
                attention layer as ``additive_attention_mask``.

        Returns:
            The output of ``ln2`` applied to the post-MLP residual, after passing
            through ``hook_normalized_resid_post``.
        """
        resid_pre = self.hook_resid_pre(resid_pre)

        if self.cfg.use_split_qkv_input:
            # Give every head its own copy of the input so that the q/k/v
            # inputs can be hooked separately.
            head_count = self.cfg.n_heads
            q_in = self.hook_q_input(repeat_along_head_dimension(resid_pre, head_count))
            k_in = self.hook_k_input(repeat_along_head_dimension(resid_pre, head_count))
            v_in = self.hook_v_input(repeat_along_head_dimension(resid_pre, head_count))
        else:
            q_in = k_in = v_in = resid_pre

        attn_result = self.attn(
            q_in,
            k_in,
            v_in,
            additive_attention_mask=additive_attention_mask,
        )
        attn_out = self.hook_attn_out(attn_result)
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)

        if self.cfg.use_hook_mlp_in:
            # The hook receives a clone rather than resid_mid itself.
            mlp_in = self.hook_mlp_in(resid_mid.clone())
        else:
            mlp_in = resid_mid

        # Post-LN: normalize first, then feed the MLP. Note that the second
        # residual connection adds to the *normalized* stream.
        ln1_out = self.ln1(mlp_in)
        mlp_out = self.hook_mlp_out(self.mlp(ln1_out))
        resid_post = self.hook_resid_post(ln1_out + mlp_out)

        return self.hook_normalized_resid_post(self.ln2(resid_post))