Coverage for transformer_lens/model_bridge/supported_architectures/bert.py: 80%
23 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""BERT architecture adapter.
3This module provides the architecture adapter for BERT models.
4"""
6from typing import Any
8from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
9from transformer_lens.conversion_utils.param_processing_conversion import (
10 ParamProcessingConversion,
11)
12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
13from transformer_lens.model_bridge.generalized_components import (
14 AttentionBridge,
15 BlockBridge,
16 EmbeddingBridge,
17 LinearBridge,
18 MLPBridge,
19 NormalizationBridge,
20 PosEmbedBridge,
21 UnembeddingBridge,
22)
25class BertArchitectureAdapter(ArchitectureAdapter):
26 """Architecture adapter for BERT models."""
28 supports_generation: bool = False
30 def __init__(self, cfg: Any) -> None:
31 """Initialize the BERT architecture adapter.
33 Args:
34 cfg: The configuration object.
35 """
36 super().__init__(cfg)
38 # Set config variables for weight processing
39 self.cfg.normalization_type = "LN"
40 self.cfg.positional_embedding_type = "standard"
41 self.cfg.final_rms = False
42 self.cfg.gated_mlp = False
43 self.cfg.attn_only = False
45 # BERT uses post-LN (LayerNorm after residual, not before sublayer).
46 # fold_ln assumes pre-LN (LN before sublayer) and folds ln1 into attention
47 # QKV and ln2 into MLP. For post-LN, ln1 output feeds MLP (not attention)
48 # and ln2 output feeds next block's attention (not MLP), so folding into
49 # the wrong sublayer produces incorrect results.
50 self.supports_fold_ln = False
52 n_heads = self.cfg.n_heads
54 self.weight_processing_conversions = {
55 "blocks.{i}.attn.q.weight": ParamProcessingConversion(
56 tensor_conversion=RearrangeTensorConversion(
57 "(h d_head) d_model -> h d_model d_head", h=n_heads
58 ),
59 ),
60 "blocks.{i}.attn.k.weight": ParamProcessingConversion(
61 tensor_conversion=RearrangeTensorConversion(
62 "(h d_head) d_model -> h d_model d_head", h=n_heads
63 ),
64 ),
65 "blocks.{i}.attn.v.weight": ParamProcessingConversion(
66 tensor_conversion=RearrangeTensorConversion(
67 "(h d_head) d_model -> h d_model d_head", h=n_heads
68 ),
69 ),
70 "blocks.{i}.attn.q.bias": ParamProcessingConversion(
71 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
72 ),
73 "blocks.{i}.attn.k.bias": ParamProcessingConversion(
74 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
75 ),
76 "blocks.{i}.attn.v.bias": ParamProcessingConversion(
77 tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
78 ),
79 "blocks.{i}.attn.o.weight": ParamProcessingConversion(
80 tensor_conversion=RearrangeTensorConversion(
81 "d_model (h d_head) -> h d_head d_model", h=n_heads
82 ),
83 ),
84 }
86 # Set up component mapping
87 # MLM defaults; prepare_model() adjusts for other task heads (e.g., NSP).
88 self.component_mapping = {
89 "embed": EmbeddingBridge(name="bert.embeddings.word_embeddings"),
90 "pos_embed": PosEmbedBridge(name="bert.embeddings.position_embeddings"),
91 "blocks": BlockBridge(
92 name="bert.encoder.layer",
93 # BERT has no single MLP module (intermediate.dense and output.dense
94 # are siblings in BertLayer), so the MLPBridge forward is never called
95 # and mlp.hook_out never fires. Redirect hook_mlp_out to the actual
96 # MLP output hook (output of the "out" linear layer).
97 hook_alias_overrides={
98 "hook_mlp_out": "mlp.out.hook_out",
99 "hook_mlp_in": "mlp.in.hook_in",
100 },
101 submodules={
102 "ln1": NormalizationBridge(
103 name="attention.output.LayerNorm",
104 config=self.cfg,
105 use_native_layernorm_autograd=True,
106 ),
107 "ln2": NormalizationBridge(
108 name="output.LayerNorm",
109 config=self.cfg,
110 use_native_layernorm_autograd=True,
111 ),
112 "attn": AttentionBridge(
113 name="attention",
114 config=self.cfg,
115 submodules={
116 "q": LinearBridge(name="self.query"),
117 "k": LinearBridge(name="self.key"),
118 "v": LinearBridge(name="self.value"),
119 "o": LinearBridge(name="output.dense"),
120 },
121 ),
122 "mlp": MLPBridge(
123 name=None,
124 config=self.cfg,
125 submodules={
126 "in": LinearBridge(name="intermediate.dense"),
127 "out": LinearBridge(name="output.dense"),
128 },
129 ),
130 },
131 ),
132 "unembed": UnembeddingBridge(name="cls.predictions.decoder"),
133 "ln_final": NormalizationBridge(
134 name="cls.predictions.transform.LayerNorm",
135 config=self.cfg,
136 use_native_layernorm_autograd=True,
137 ),
138 }
140 def prepare_model(self, hf_model: Any) -> None:
141 """Adjust component mapping based on the actual HF model variant.
143 BertForMaskedLM has cls.predictions (MLM head).
144 BertForNextSentencePrediction has cls.seq_relationship (NSP head)
145 and no MLM-specific LayerNorm.
146 """
147 if hasattr(hf_model, "cls") and hasattr(hf_model.cls, "seq_relationship"):
148 # NSP model — swap head components
149 assert self.component_mapping is not None
150 self.component_mapping["unembed"] = UnembeddingBridge(name="cls.seq_relationship")
151 self.component_mapping.pop("ln_final", None)