Coverage for transformer_lens/model_bridge/supported_architectures/mpt.py: 100%
26 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""MPT (MPTForCausalLM) adapter — ALiBi, fused Wqkv, weight-only LayerNorm, no biases."""
3from typing import Any
5import torch
6import torch.nn as nn
8from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
9from transformer_lens.model_bridge.generalized_components import (
10 BlockBridge,
11 EmbeddingBridge,
12 LinearBridge,
13 MLPBridge,
14 NormalizationBridge,
15 UnembeddingBridge,
16)
17from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import (
18 MPTALiBiAttentionBridge,
19)


class MPTArchitectureAdapter(ArchitectureAdapter):
    """MPT adapter: ALiBi bias; all layers bias-free (no b_Q/b_K/b_V/b_O/b_in/b_out/ln bias)."""

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "alibi"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.default_prepend_bos = False

        # Pure MHA: split_qkv yields a [d_model, d_model] matrix for each of Q/K/V;
        # the standard per-head rearrangements apply.
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }
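
        # Shape sketch (assumed, not enforced here): the fused Wqkv weight is
        # [3 * d_model, d_model]; _split_mpt_qkv below chunks it row-wise into three
        # [d_model, d_model] Linears, which these conversions then reshape into the
        # per-head layout TransformerLens expects.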

        self.component_mapping = {
            "embed": EmbeddingBridge(name="transformer.wte"),
            "blocks": BlockBridge(
                name="transformer.blocks",
                submodules={
                    "ln1": NormalizationBridge(name="norm_1", config=self.cfg),
                    "attn": MPTALiBiAttentionBridge(
                        name="attn",
                        config=self.cfg,
                        split_qkv_matrix=self._split_mpt_qkv,
                        submodules={
                            "qkv": LinearBridge(name="Wqkv"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(name="norm_2", config=self.cfg),
                    "mlp": MLPBridge(
                        name="ffn",
                        submodules={
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="transformer.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def _split_mpt_qkv(self, attn_component: Any) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split the fused Wqkv into Q, K, V via a row-wise chunk (not interleaved like BLOOM)."""
        w = attn_component.Wqkv.weight.detach().clone()
        w_q, w_k, w_v = torch.chunk(w, 3, dim=0)
        d_model = self.cfg.d_model

        def make_linear(weight: torch.Tensor) -> nn.Linear:
            lin = nn.Linear(d_model, d_model, bias=False, device=weight.device, dtype=weight.dtype)
            lin.weight = nn.Parameter(weight.contiguous())
            return lin

        return make_linear(w_q), make_linear(w_k), make_linear(w_v)
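
A minimal standalone sketch (not part of the module above) of why a plain row-wise chunk is the right split for a fused QKV projection whose rows are stacked Q/K/V blocks; the toy module name fused and the tiny d_model are illustrative only:

import torch
import torch.nn as nn

d_model = 8

# Stand-in for MPT's fused projection: output features ordered as [Q rows | K rows | V rows]
# (this block layout is the assumption behind the row-wise chunk).
fused = nn.Linear(d_model, 3 * d_model, bias=False)

w = fused.weight.detach()                  # [3 * d_model, d_model]
w_q, w_k, w_v = torch.chunk(w, 3, dim=0)   # each [d_model, d_model]

x = torch.randn(2, d_model)
q, k, v = fused(x).split(d_model, dim=-1)

# The chunked weights reproduce the fused projection's Q/K/V outputs exactly,
# so no per-head de-interleaving (as BLOOM would need) is required.
assert torch.allclose(q, x @ w_q.T)
assert torch.allclose(k, x @ w_k.T)
assert torch.allclose(v, x @ w_v.T)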