Coverage for transformer_lens/model_bridge/supported_architectures/bert.py: 80%

23 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""BERT architecture adapter. 

2 

3This module provides the architecture adapter for BERT models. 

4""" 

5 

6from typing import Any 

7 

8from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

9from transformer_lens.conversion_utils.param_processing_conversion import ( 

10 ParamProcessingConversion, 

11) 

12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

13from transformer_lens.model_bridge.generalized_components import ( 

14 AttentionBridge, 

15 BlockBridge, 

16 EmbeddingBridge, 

17 LinearBridge, 

18 MLPBridge, 

19 NormalizationBridge, 

20 PosEmbedBridge, 

21 UnembeddingBridge, 

22) 

23 

24 

class BertArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for BERT models.

    Maps Hugging Face BERT modules (``bert.embeddings.*``, ``bert.encoder.layer``
    and the ``cls.*`` task heads) onto TransformerLens bridge components, and
    declares the per-head weight rearrangements used during weight processing.
    Generation is not supported (``supports_generation = False``).
    """

    supports_generation: bool = False

    def __init__(self, cfg: Any) -> None:
        """Initialize the BERT architecture adapter.

        Args:
            cfg: The configuration object. Its normalization, positional-embedding
                and MLP flags are overwritten below to describe BERT, and
                ``cfg.n_heads`` is read to split attention weights per head.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # BERT uses post-LN (LayerNorm after residual, not before sublayer).
        # fold_ln assumes pre-LN (LN before sublayer) and folds ln1 into attention
        # QKV and ln2 into MLP. For post-LN, ln1 output feeds MLP (not attention)
        # and ln2 output feeds next block's attention (not MLP), so folding into
        # the wrong sublayer produces incorrect results.
        self.supports_fold_ln = False

        n_heads = self.cfg.n_heads

        # HF Linear weights are stored as (out_features, in_features); split the
        # fused head dimension so each head's Q/K/V/O slice is addressable.
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (h d_head) -> h d_head d_model", h=n_heads
                ),
            ),
        }

        # Set up component mapping
        # MLM defaults; prepare_model() adjusts for other task heads (e.g., NSP).
        self.component_mapping = {
            "embed": EmbeddingBridge(name="bert.embeddings.word_embeddings"),
            "pos_embed": PosEmbedBridge(name="bert.embeddings.position_embeddings"),
            "blocks": BlockBridge(
                name="bert.encoder.layer",
                # BERT has no single MLP module (intermediate.dense and output.dense
                # are siblings in BertLayer), so the MLPBridge forward is never called
                # and its own hooks never fire. Redirect hook_mlp_in / hook_mlp_out to
                # the input of the "in" linear layer and the output of the "out"
                # linear layer respectively.
                hook_alias_overrides={
                    "hook_mlp_out": "mlp.out.hook_out",
                    "hook_mlp_in": "mlp.in.hook_in",
                },
                submodules={
                    # Post-LN: this LayerNorm sits after attention + residual.
                    "ln1": NormalizationBridge(
                        name="attention.output.LayerNorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    # Post-LN: this LayerNorm sits after the MLP + residual.
                    "ln2": NormalizationBridge(
                        name="output.LayerNorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="self.query"),
                            "k": LinearBridge(name="self.key"),
                            "v": LinearBridge(name="self.value"),
                            # Resolved relative to "attention" -> attention.output.dense
                            "o": LinearBridge(name="output.dense"),
                        },
                    ),
                    "mlp": MLPBridge(
                        # No wrapping module: submodule names resolve relative to the layer.
                        name=None,
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="intermediate.dense"),
                            # Resolved relative to the layer -> <layer>.output.dense
                            "out": LinearBridge(name="output.dense"),
                        },
                    ),
                },
            ),
            "unembed": UnembeddingBridge(name="cls.predictions.decoder"),
            "ln_final": NormalizationBridge(
                name="cls.predictions.transform.LayerNorm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
        }

    def prepare_model(self, hf_model: Any) -> None:
        """Adjust component mapping based on the actual HF model variant.

        - BertForMaskedLM has ``cls.predictions`` (MLM head): the MLM defaults
          set in ``__init__`` are kept.
        - BertForPreTraining has both ``cls.predictions`` and
          ``cls.seq_relationship``: the MLM head is present, so the MLM defaults
          are kept as well.
        - BertForNextSentencePrediction has only ``cls.seq_relationship`` (NSP
          head) and no MLM-specific LayerNorm: ``unembed`` is re-pointed at the
          NSP head and ``ln_final`` is removed.

        Args:
            hf_model: The loaded Hugging Face BERT model.
        """
        head = getattr(hf_model, "cls", None)
        if head is None:
            # NOTE(review): variants without a ``cls`` head module are not
            # adjusted here; the MLM defaults are left untouched.
            return

        has_nsp_head = hasattr(head, "seq_relationship")
        has_mlm_head = hasattr(head, "predictions")
        # Checking only for seq_relationship would also match BertForPreTraining
        # and wrongly replace its MLM decoder with the NSP classifier.
        if has_nsp_head and not has_mlm_head:
            # NSP-only model — swap head components
            assert self.component_mapping is not None
            self.component_mapping["unembed"] = UnembeddingBridge(name="cls.seq_relationship")
            self.component_mapping.pop("ln_final", None)