Coverage for transformer_lens/components/bert_pooler.py: 100%

19 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Hooked Encoder Bert Pooler Component. 

2 

3This module contains all the component :class:`BertPooler`. 

4""" 

5from typing import Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9from jaxtyping import Float 

10 

11from transformer_lens.hook_points import HookPoint 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13 

14 

15class BertPooler(nn.Module): 

16 """ 

17 Transforms the [CLS] token representation into a fixed-size sequence embedding. 

18 The purpose of this module is to convert variable-length sequence inputs into a single vector representation suitable for downstream tasks. 

19 (e.g. Next Sentence Prediction) 

20 """ 

21 

22 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

23 super().__init__() 

24 self.cfg = HookedTransformerConfig.unwrap(cfg) 

25 self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype)) 

26 self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 

27 self.activation = nn.Tanh() 

28 self.hook_pooler_out = HookPoint() 

29 

30 def forward( 

31 self, resid: Float[torch.Tensor, "batch pos d_model"] 

32 ) -> Float[torch.Tensor, "batch d_model"]: 

33 first_token_tensor = resid[:, 0] 

34 pooled_output = torch.matmul(first_token_tensor, self.W) + self.b 

35 pooled_output = self.hook_pooler_out(self.activation(pooled_output)) 

36 return pooled_output
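
A minimal usage sketch, not part of the covered file: it instantiates the pooler with a small HookedTransformerConfig and runs it on a dummy residual stream. The particular config fields passed here (n_layers, d_model, n_ctx, d_head, act_fn) and the manual initialisation of W are assumptions for illustration; in normal use W is populated when pretrained BERT weights are loaded.

import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.bert_pooler import BertPooler

# Assumed minimal config; the exact required fields are an assumption here.
cfg = HookedTransformerConfig(
    n_layers=1, d_model=768, n_ctx=512, d_head=64, act_fn="gelu"
)
pooler = BertPooler(cfg)

# W is created with torch.empty and is normally filled from pretrained weights;
# initialise it so the sketch produces finite values.
torch.nn.init.normal_(pooler.W, std=0.02)

resid = torch.randn(2, 10, cfg.d_model)  # [batch, pos, d_model]
pooled = pooler(resid)                   # tanh(resid[:, 0] @ W + b), shape [batch, d_model]
print(pooled.shape)                      # torch.Size([2, 768])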