Coverage for transformer_lens/model_bridge/supported_architectures/bert.py: 79%

22 statements  


1"""BERT architecture adapter. 

2 

3This module provides the architecture adapter for BERT models. 

4""" 

5 

6from typing import Any 

7 

8from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

9from transformer_lens.conversion_utils.param_processing_conversion import ( 

10 ParamProcessingConversion, 

11) 

12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

13from transformer_lens.model_bridge.generalized_components import ( 

14 AttentionBridge, 

15 BlockBridge, 

16 EmbeddingBridge, 

17 LinearBridge, 

18 MLPBridge, 

19 NormalizationBridge, 

20 PosEmbedBridge, 

21 UnembeddingBridge, 

22) 


class BertArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for BERT models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the BERT architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # BERT uses post-LN (LayerNorm after residual, not before sublayer).
        # fold_ln assumes pre-LN (LN before sublayer) and folds ln1 into attention
        # QKV and ln2 into MLP. For post-LN, ln1 output feeds MLP (not attention)
        # and ln2 output feeds next block's attention (not MLP), so folding into
        # the wrong sublayer produces incorrect results.
        self.supports_fold_ln = False
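
        # Illustrative sketch of the two conventions (schematic, not executed):
        #   pre-LN  block: x = x + attn(ln1(x)); x = x + mlp(ln2(x))
        #   post-LN block: x = ln1(x + attn(x)); x = ln2(x + mlp(x))
        # Folding LN scale/bias into adjacent weights is only sound in the
        # pre-LN layout, where each LN directly precedes the folded sublayer.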

        n_heads = self.cfg.n_heads

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(h d_head) d_model -> h d_model d_head", h=n_heads
                ),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(h d_head) -> h d_head", h=n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (h d_head) -> h d_head d_model", h=n_heads
                ),
            ),
        }
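
        # Shape check (illustrative, assuming bert-base: d_model=768,
        # n_heads=12, d_head=64): HF nn.Linear stores the query weight as
        # [out_features, in_features] = [n_heads * d_head, d_model] = [768, 768];
        # the rearrange above yields the TransformerLens layout
        # [n_heads, d_model, d_head] = [12, 768, 64]. Biases go from [768] to
        # [12, 64], and the o.weight pattern turns [768, 768] into [12, 64, 768].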

        # Set up component mapping
        # MLM defaults; prepare_model() adjusts for other task heads (e.g., NSP).
        self.component_mapping = {
            "embed": EmbeddingBridge(name="bert.embeddings.word_embeddings"),
            "pos_embed": PosEmbedBridge(name="bert.embeddings.position_embeddings"),
            "blocks": BlockBridge(
                name="bert.encoder.layer",
                # BERT has no single MLP module (intermediate.dense and output.dense
                # are siblings in BertLayer), so the MLPBridge forward is never called
                # and mlp.hook_out never fires. Redirect hook_mlp_out to the actual
                # MLP output hook (output of the "out" linear layer); see the note
                # after this mapping.
                hook_alias_overrides={
                    "hook_mlp_out": "mlp.out.hook_out",
                    "hook_mlp_in": "mlp.in.hook_in",
                },
                submodules={
                    "ln1": NormalizationBridge(
                        name="attention.output.LayerNorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="output.LayerNorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="attention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="self.query"),
                            "k": LinearBridge(name="self.key"),
                            "v": LinearBridge(name="self.value"),
                            "o": LinearBridge(name="output.dense"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name=None,
                        config=self.cfg,
                        submodules={
                            "in": LinearBridge(name="intermediate.dense"),
                            "out": LinearBridge(name="output.dense"),
                        },
                    ),
                },
            ),
            "unembed": UnembeddingBridge(name="cls.predictions.decoder"),
            "ln_final": NormalizationBridge(
                name="cls.predictions.transform.LayerNorm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
        }
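
        # With the overrides above, the canonical name blocks.{i}.hook_mlp_out
        # resolves to blocks.{i}.mlp.out.hook_out, i.e. the output of the
        # block's output.dense projection, so downstream code can keep using
        # the standard hook name even though BertLayer has no unified MLP
        # module. (Illustrative; exact resolution depends on the bridge's
        # alias handling.)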

    def prepare_model(self, hf_model: Any) -> None:
        """Adjust component mapping based on the actual HF model variant.

        BertForMaskedLM has cls.predictions (MLM head).
        BertForNextSentencePrediction has cls.seq_relationship (NSP head)
        and no MLM-specific LayerNorm.
        """
        if hasattr(hf_model, "cls") and hasattr(hf_model.cls, "seq_relationship"):
            # NSP model: swap head components
            assert self.component_mapping is not None
            self.component_mapping["unembed"] = UnembeddingBridge(name="cls.seq_relationship")
            self.component_mapping.pop("ln_final", None)
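
    # Note (assuming standard HF transformers): BertOnlyNSPHead's
    # cls.seq_relationship is a Linear(hidden_size, 2), so after the swap the
    # "unembed" maps d_model to 2 NSP logits instead of vocab_size, and there
    # is no MLM transform LayerNorm to expose as ln_final.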