Coverage for transformer_lens/model_bridge/supported_architectures/bloom.py: 49%

39 statements  


1"""Bloom architecture adapter.""" 

2 

3from typing import Any 

4 

5import torch 

6 

7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

8from transformer_lens.conversion_utils.param_processing_conversion import ( 

9 ParamProcessingConversion, 

10) 

11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

12from transformer_lens.model_bridge.generalized_components import ( 

13 BloomAttentionBridge, 

14 BloomBlockBridge, 

15 BloomMLPBridge, 

16 EmbeddingBridge, 

17 LinearBridge, 

18 NormalizationBridge, 

19 UnembeddingBridge, 

20) 

21 

22 

23class BloomArchitectureAdapter(ArchitectureAdapter): 

24 """Architecture adapter for Bloom models.""" 

25 

26 def __init__(self, cfg: Any) -> None: 

27 """Initialize the Bloom architecture adapter.""" 

28 super().__init__(cfg) 

29 

30 # Set config variables for weight processing 

31 self.cfg.normalization_type = "LN" 

32 self.cfg.positional_embedding_type = "alibi" 

33 self.cfg.final_rms = False 

34 self.cfg.gated_mlp = False 

35 self.cfg.attn_only = False 

36 

37 self.cfg.default_prepend_bos = False 

38 # After split_qkv_matrix, Q/K/V are individual [n_heads*d_head, d_model] weights. 

39 # Convert to TL format [n_heads, d_model, d_head]. 

40 self.weight_processing_conversions = { 

41 "blocks.{i}.attn.q": ParamProcessingConversion( 

42 tensor_conversion=RearrangeTensorConversion( 

43 "(n h) m -> n m h", 

44 n=self.cfg.n_heads, 

45 ), 

46 ), 

47 "blocks.{i}.attn.k": ParamProcessingConversion( 

48 tensor_conversion=RearrangeTensorConversion( 

49 "(n h) m -> n m h", 

50 n=self.cfg.n_heads, 

51 ), 

52 ), 

53 "blocks.{i}.attn.v": ParamProcessingConversion( 

54 tensor_conversion=RearrangeTensorConversion( 

55 "(n h) m -> n m h", 

56 n=self.cfg.n_heads, 

57 ), 

58 ), 

59 "blocks.{i}.attn.o": ParamProcessingConversion( 

60 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), 

61 ), 

62 } 

63 

64 self.component_mapping = { 

65 "embed": EmbeddingBridge(name="transformer.word_embeddings"), 

66 "embed_ln": NormalizationBridge( 

67 name="transformer.word_embeddings_layernorm", config=self.cfg 

68 ), 

69 "blocks": BloomBlockBridge( 

70 name="transformer.h", 

71 config=self.cfg, 

72 submodules={ 

73 "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg), 

74 "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg), 

75 "attn": BloomAttentionBridge( 

76 name="self_attention", 

77 config=self.cfg, 

78 split_qkv_matrix=self.split_qkv_matrix, 

79 submodules={ 

80 "qkv": LinearBridge(name="query_key_value"), 

81 "o": LinearBridge(name="dense"), 

82 }, 

83 ), 

84 "mlp": BloomMLPBridge( 

85 name="mlp", 

86 submodules={ 

87 "in": LinearBridge(name="dense_h_to_4h"), 

88 "out": LinearBridge(name="dense_4h_to_h"), 

89 }, 

90 ), 

91 }, 

92 ), 

93 "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg), 

94 "unembed": UnembeddingBridge(name="lm_head"), 

95 } 

96 

    def split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split the QKV matrix into separate linear transformations.

        Args:
            original_attention_component: The original attention layer component.

        Returns:
            Tuple of nn.Linear modules for the Q, K, and V transformations.
        """

        # Keep mypy happy
        assert original_attention_component is not None
        assert original_attention_component.query_key_value is not None

        qkv_weights = original_attention_component.query_key_value.weight

        # Keep mypy happy
        assert isinstance(qkv_weights, torch.Tensor)

        # Bloom QKV weights are interleaved: [Q0, K0, V0, Q1, K1, V1, ...],
        # i.e. the layout is (n_heads, 3, d_head), not (3, n_heads*d_head).
        # Reshape to [d_model, n_heads, 3, d_head] to correctly deinterleave.
        W_split = qkv_weights.T.reshape(self.cfg.d_model, self.cfg.n_heads, 3, self.cfg.d_head)

        # W_Q/K/V shape: [d_model, n_heads, d_head]
        W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]

        qkv_bias = original_attention_component.query_key_value.bias

        # Keep mypy happy
        assert isinstance(qkv_bias, torch.Tensor)

        # Same interleaved layout for the bias: reshape to [n_heads, 3, d_head]
        qkv_bias = qkv_bias.reshape(self.cfg.n_heads, 3, self.cfg.d_head)

        # b_Q/K/V shape: [n_heads, d_head]
        b_Q, b_K, b_V = qkv_bias[:, 0, :], qkv_bias[:, 1, :], qkv_bias[:, 2, :]

        # Create nn.Linear modules.
        # W_Q shape is [d_model, n_heads, d_head] -> flatten to [d_model, n_heads*d_head];
        # nn.Linear expects weight shape [out_features, in_features] = [n_heads*d_head, d_model].
        d_out = self.cfg.n_heads * self.cfg.d_head

        W_Q_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True)
        W_Q_transformation.weight = torch.nn.Parameter(W_Q.reshape(self.cfg.d_model, d_out).T)
        W_Q_transformation.bias = torch.nn.Parameter(b_Q.reshape(d_out))

        W_K_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True)
        W_K_transformation.weight = torch.nn.Parameter(W_K.reshape(self.cfg.d_model, d_out).T)
        W_K_transformation.bias = torch.nn.Parameter(b_K.reshape(d_out))

        W_V_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True)
        W_V_transformation.weight = torch.nn.Parameter(W_V.reshape(self.cfg.d_model, d_out).T)
        W_V_transformation.bias = torch.nn.Parameter(b_V.reshape(d_out))

        return W_Q_transformation, W_K_transformation, W_V_transformation
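
The interleaved layout handled by split_qkv_matrix is easy to get wrong, so here is a minimal standalone sketch of the same de-interleaving reshape, run against a toy fused projection rather than a real Bloom checkpoint. The sizes (n_heads = 2, d_head = 4) and all variable names are invented for the example; the only assumption carried over from the code above is Bloom's per-head [Q_i, K_i, V_i] row ordering in query_key_value.

import torch

n_heads, d_head = 2, 4  # toy sizes, not Bloom's real config
d_model = n_heads * d_head

# Stand-in for query_key_value: weight rows are grouped per head as
# [Q_i, K_i, V_i], which is the layout the adapter assumes.
fused = torch.nn.Linear(d_model, 3 * n_heads * d_head, bias=True)
x = torch.randn(5, d_model)

# The same de-interleaving reshape as split_qkv_matrix.
W_split = fused.weight.T.reshape(d_model, n_heads, 3, d_head)
b_split = fused.bias.reshape(n_heads, 3, d_head)

q = x @ W_split[..., 0, :].reshape(d_model, n_heads * d_head) + b_split[:, 0, :].reshape(-1)
k = x @ W_split[..., 1, :].reshape(d_model, n_heads * d_head) + b_split[:, 1, :].reshape(-1)
v = x @ W_split[..., 2, :].reshape(d_model, n_heads * d_head) + b_split[:, 2, :].reshape(-1)

# Viewed per head, the fused output should contain exactly the same numbers.
fused_out = fused(x).reshape(-1, n_heads, 3, d_head)
assert torch.allclose(q.reshape(-1, n_heads, d_head), fused_out[:, :, 0, :], atol=1e-6)
assert torch.allclose(k.reshape(-1, n_heads, d_head), fused_out[:, :, 1, :], atol=1e-6)
assert torch.allclose(v.reshape(-1, n_heads, d_head), fused_out[:, :, 2, :], atol=1e-6)

If the weight were instead reshaped under a (3, n_heads*d_head) block assumption, the asserts above would fail; that difference is exactly what the layout comment in split_qkv_matrix is warning about.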