1"""Bloom architecture adapter."""
3from typing import Any
5import torch
7from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
8from transformer_lens.conversion_utils.param_processing_conversion import (
9 ParamProcessingConversion,
10)
11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
12from transformer_lens.model_bridge.generalized_components import (
13 BloomAttentionBridge,
14 BloomBlockBridge,
15 BloomMLPBridge,
16 EmbeddingBridge,
17 LinearBridge,
18 NormalizationBridge,
19 UnembeddingBridge,
20)
23class BloomArchitectureAdapter(ArchitectureAdapter):
24 """Architecture adapter for Bloom models."""
26 def __init__(self, cfg: Any) -> None:
27 """Initialize the Bloom architecture adapter."""
28 super().__init__(cfg)
30 # Set config variables for weight processing
31 self.cfg.normalization_type = "LN"
32 self.cfg.positional_embedding_type = "alibi"
33 self.cfg.final_rms = False
34 self.cfg.gated_mlp = False
35 self.cfg.attn_only = False
37 self.cfg.default_prepend_bos = False
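        # Note: Bloom encodes position via ALiBi attention biases rather than learned
        # or rotary positional embeddings, hence positional_embedding_type = "alibi".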
        # After split_qkv_matrix, Q/K/V are individual [n_heads*d_head, d_model] weights.
        # Convert to TL format [n_heads, d_model, d_head].
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "m (n h) -> n h m", n=self.cfg.n_heads
                ),
            ),
        }
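        # Shape sketch of the rearranges above (hypothetical toy sizes, n_heads=2,
        # d_head=3, d_model=4):
        #   q/k/v weight [n_heads*d_head, d_model] = [6, 4]
        #     --"(n h) m -> n m h"--> [2, 4, 3] = [n_heads, d_model, d_head]
        #   o weight [d_model, n_heads*d_head] = [4, 6]
        #     --"m (n h) -> n h m"--> [2, 3, 4] = [n_heads, d_head, d_model]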
        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.word_embeddings"),
            "embed_ln": NormalizationBridge(
                name="transformer.word_embeddings_layernorm", config=self.cfg
            ),
            "blocks": BloomBlockBridge(
                name="transformer.h",
                config=self.cfg,
                submodules={
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": BloomAttentionBridge(
                        name="self_attention",
                        config=self.cfg,
                        split_qkv_matrix=self.split_qkv_matrix,
                        submodules={
                            "qkv": LinearBridge(name="query_key_value"),
                            "o": LinearBridge(name="dense"),
                        },
                    ),
                    "mlp": BloomMLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="dense_h_to_4h"),
                            "out": LinearBridge(name="dense_4h_to_h"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
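        # The mapping above ties TL-style paths to HF Bloom module names, e.g.
        # blocks.{i}.attn <-> transformer.h.{i}.self_attention and
        # blocks.{i}.mlp.in <-> transformer.h.{i}.mlp.dense_h_to_4h.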
    def split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split the fused QKV matrix into separate Q, K, and V linear transformations.

        Args:
            original_attention_component: The original attention layer component.

        Returns:
            Tuple of nn.Linear modules for the Q, K, and V transformations.
        """
        # Keep mypy happy
        assert original_attention_component is not None
        assert original_attention_component.query_key_value is not None

        qkv_weights = original_attention_component.query_key_value.weight

        # Keep mypy happy
        assert isinstance(qkv_weights, torch.Tensor)

        # Bloom QKV weights are interleaved: [Q0,K0,V0, Q1,K1,V1, ...]
        # i.e. layout is (n_heads, 3, d_head), not (3, n_heads*d_head).
        # Reshape to [d_model, n_heads, 3, d_head] to correctly deinterleave.
        W_split = qkv_weights.T.reshape(self.cfg.d_model, self.cfg.n_heads, 3, self.cfg.d_head)
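        # Shape trace (same hypothetical toy sizes as above, n_heads=2, d_head=3,
        # d_model=4): qkv_weights [3*2*3, 4] = [18, 4] -> .T [4, 18]
        # -> reshape [4, 2, 3, 3] = [d_model, n_heads, qkv, d_head].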
        # W_Q/K/V shape: [d_model, n_heads, d_head]
        W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]

        qkv_bias = original_attention_component.query_key_value.bias

        # Keep mypy happy
        assert isinstance(qkv_bias, torch.Tensor)

        # Same interleaved layout for bias: reshape to [n_heads, 3, d_head]
        qkv_bias = qkv_bias.reshape(self.cfg.n_heads, 3, self.cfg.d_head)

        # b_Q/K/V shape: [n_heads, d_head]
        b_Q, b_K, b_V = qkv_bias[:, 0, :], qkv_bias[:, 1, :], qkv_bias[:, 2, :]
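        # e.g. with the toy sizes: bias [18] -> [2, 3, 3], so b_Q = qkv_bias[:, 0, :]
        # has shape [2, 3] = [n_heads, d_head].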
        # Create nn.Linear modules
        # W_Q shape is [d_model, n_heads, d_head] -> flatten to [d_model, n_heads*d_head]
        # nn.Linear expects weight shape [out_features, in_features] = [n_heads*d_head, d_model]
        d_out = self.cfg.n_heads * self.cfg.d_head

        W_Q_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True)
        W_Q_transformation.weight = torch.nn.Parameter(W_Q.reshape(self.cfg.d_model, d_out).T)
        W_Q_transformation.bias = torch.nn.Parameter(b_Q.reshape(d_out))

        W_K_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True)
        W_K_transformation.weight = torch.nn.Parameter(W_K.reshape(self.cfg.d_model, d_out).T)
        W_K_transformation.bias = torch.nn.Parameter(b_K.reshape(d_out))

        W_V_transformation = torch.nn.Linear(self.cfg.d_model, d_out, bias=True)
        W_V_transformation.weight = torch.nn.Parameter(W_V.reshape(self.cfg.d_model, d_out).T)
        W_V_transformation.bias = torch.nn.Parameter(b_V.reshape(d_out))

        return W_Q_transformation, W_K_transformation, W_V_transformation
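        # Usage sketch (hypothetical; assumes `attn` is a HF BloomAttention module
        # and `adapter` an instance of this class):
        #   q_lin, k_lin, v_lin = adapter.split_qkv_matrix(attn)
        #   x = torch.randn(1, adapter.cfg.d_model)
        #   q = q_lin(x)  # [1, n_heads*d_head], deinterleaved from the fused projection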