Coverage for transformer_lens/components/bert_pooler.py: 100% (19 statements)
coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Hooked Encoder Bert Pooler Component.
3This module contains all the component :class:`BertPooler`.
4"""
6from typing import Dict, Union
8import torch
9import torch.nn as nn
10from jaxtyping import Float
12from transformer_lens.hook_points import HookPoint
13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class BertPooler(nn.Module):
    """Transforms the [CLS] token representation into a fixed-size sequence embedding.

    Converts a variable-length sequence input into a single vector representation
    suitable for downstream tasks (e.g. Next Sentence Prediction).
    """
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # Dense projection applied to the [CLS] residual stream vector.
        self.W = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype)
        )
        self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
        self.activation = nn.Tanh()
        # Hook point for caching or intervening on the pooled output.
        self.hook_pooler_out = HookPoint()
    def forward(
        self, resid: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch d_model"]:
        # Take the residual stream at the first position (the [CLS] token).
        first_token_tensor = resid[:, 0]
        # Linear projection followed by tanh, exposed through the hook point.
        pooled_output = torch.matmul(first_token_tensor, self.W) + self.b
        pooled_output = self.hook_pooler_out(self.activation(pooled_output))
        return pooled_output
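
For reference, a minimal usage sketch (not part of the covered source). The config values below are illustrative assumptions; any valid HookedTransformerConfig with d_model and dtype set would work, and note that W holds uninitialized values from torch.empty until real weights are loaded, so only the output shape is meaningful here.

import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.bert_pooler import BertPooler

# Illustrative config; only d_model and dtype matter to the pooler itself.
cfg = HookedTransformerConfig(
    n_layers=1, d_model=768, n_ctx=512, d_head=64, act_fn="gelu", dtype=torch.float32
)
pooler = BertPooler(cfg)

resid = torch.randn(2, 128, cfg.d_model)  # [batch, pos, d_model]
pooled = pooler(resid)                    # [batch, d_model] pooled [CLS] embedding
assert pooled.shape == (2, cfg.d_model)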