Coverage for transformer_lens/model_bridge/supported_architectures/bert.py: 79% (22 statements)

1"""BERT architecture adapter.
3This module provides the architecture adapter for BERT models.
4"""
6from typing import Any
8from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
9from transformer_lens.conversion_utils.param_processing_conversion import (
10 ParamProcessingConversion,
11)
12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
13from transformer_lens.model_bridge.generalized_components import (
14 AttentionBridge,
15 BlockBridge,
16 EmbeddingBridge,
17 LinearBridge,
18 MLPBridge,
19 NormalizationBridge,
20 PosEmbedBridge,
21 UnembeddingBridge,
22)


class BertArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for BERT models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the BERT architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # BERT uses post-LN (LayerNorm after the residual add, not before the
        # sublayer). fold_ln assumes pre-LN (LN before the sublayer) and folds
        # ln1 into the attention QKV and ln2 into the MLP. For post-LN, ln1's
        # output feeds the MLP (not attention) and ln2's output feeds the next
        # block's attention (not the MLP), so folding into the wrong sublayer
        # produces incorrect results.
        self.supports_fold_ln = False
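
        # Folding algebra sketch (illustrative only; not executed by this
        # adapter): pre-LN feeds the sublayer LN(x) = w_ln * x_hat + b_ln,
        # so a linear layer satisfies
        #   W @ LN(x) + b == (W * w_ln) @ x_hat + (b + W @ b_ln),
        # letting w_ln / b_ln be absorbed into W and b. Post-LN places the
        # LayerNorm after the residual add, so there is no single downstream
        # W into which it can be folded.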

        n_heads = self.cfg.n_heads

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (h d_head) -> h d_head d_model", h=n_heads
                ),
            ),
        }
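
        # Shape sketch for the rearranges above (illustrative; assumes
        # HF-style packed weights with d_model == n_heads * d_head):
        #   q/k/v weight: (h * d_head, d_model) -> (h, d_model, d_head)
        #   q/k/v bias:   (h * d_head,)         -> (h, d_head)
        #   o weight:     (d_model, h * d_head) -> (h, d_head, d_model)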

        # Set up component mapping
        # MLM defaults; prepare_model() adjusts for other task heads (e.g., NSP).
        self.component_mapping = {
            "embed": EmbeddingBridge(name="bert.embeddings.word_embeddings"),
            "pos_embed": PosEmbedBridge(name="bert.embeddings.position_embeddings"),
            "blocks": BlockBridge(
                name="bert.encoder.layer",
                # BERT has no single MLP module (intermediate.dense and
                # output.dense are siblings in BertLayer), so the MLPBridge
                # forward is never called and its hooks never fire. Redirect
                # hook_mlp_in / hook_mlp_out to the hooks on the actual
                # linear layers ("in" and "out" below).
                hook_alias_overrides={
                    "hook_mlp_out": "mlp.out.hook_out",
                    "hook_mlp_in": "mlp.in.hook_in",
                },
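                # Usage sketch (hedged; assumes the standard TransformerLens
                # run_with_cache API is exposed on the bridged model):
                #   _, cache = model.run_with_cache(tokens)
                #   cache["blocks.0.hook_mlp_out"]  # output of output.dense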
                submodules={
                    "ln1": NormalizationBridge(
                        name="attention.output.LayerNorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="output.LayerNorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="self.query"),
                            "k": LinearBridge(name="self.key"),
                            "v": LinearBridge(name="self.value"),
                            "o": LinearBridge(name="output.dense"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name=None,
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="intermediate.dense"),
                            "out": LinearBridge(name="output.dense"),
                        },
                    ),
                },
            ),
            "unembed": UnembeddingBridge(name="cls.predictions.decoder"),
            "ln_final": NormalizationBridge(
                name="cls.predictions.transform.LayerNorm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
        }
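
        # Resulting path mapping (illustrative, for a bert-base-style
        # BertForMaskedLM checkpoint; all names appear in the mapping above):
        #   blocks.{i}.attn.q -> bert.encoder.layer.{i}.attention.self.query
        #   blocks.{i}.mlp.in -> bert.encoder.layer.{i}.intermediate.dense
        #   unembed           -> cls.predictions.decoder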

    def prepare_model(self, hf_model: Any) -> None:
        """Adjust component mapping based on the actual HF model variant.

        BertForMaskedLM has cls.predictions (MLM head).
        BertForNextSentencePrediction has cls.seq_relationship (NSP head)
        and no MLM-specific LayerNorm.
        """
        if hasattr(hf_model, "cls") and hasattr(hf_model.cls, "seq_relationship"):
            # NSP model: swap head components
            assert self.component_mapping is not None
            self.component_mapping["unembed"] = UnembeddingBridge(name="cls.seq_relationship")
            self.component_mapping.pop("ln_final", None)
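
# A minimal end-to-end sketch (hedged: AutoModelForMaskedLM is the standard
# Hugging Face loader; how `cfg` is constructed for the bridge is assumed
# here, so treat this as wiring pseudocode rather than a tested recipe):
#
#     from transformers import AutoModelForMaskedLM
#
#     hf_model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
#     adapter = BertArchitectureAdapter(cfg)  # cfg: TransformerLens-style config
#     adapter.prepare_model(hf_model)         # swaps the head for NSP variants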