Coverage for transformer_lens/model_bridge/supported_architectures/mpt.py: 100%

26 statements  


1"""MPT (MPTForCausalLM) adapter — ALiBi, fused Wqkv, weight-only LayerNorm, no biases.""" 

2 

3from typing import Any 

4 

5import torch 

6import torch.nn as nn 

7 

8from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

9from transformer_lens.model_bridge.generalized_components import ( 

10 BlockBridge, 

11 EmbeddingBridge, 

12 LinearBridge, 

13 MLPBridge, 

14 NormalizationBridge, 

15 UnembeddingBridge, 

16) 

17from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import ( 

18 MPTALiBiAttentionBridge, 

19) 

20 

21 

22class MPTArchitectureAdapter(ArchitectureAdapter): 

23 """MPT adapter: ALiBi bias; all layers bias-free (no b_Q/b_K/b_V/b_O/b_in/b_out/ln bias).""" 

24 

25 def __init__(self, cfg: Any) -> None: 

26 super().__init__(cfg) 

27 

28 self.cfg.normalization_type = "LN" 

29 self.cfg.positional_embedding_type = "alibi" 

30 self.cfg.final_rms = False 

31 self.cfg.gated_mlp = False 

32 self.cfg.attn_only = False 

33 self.cfg.default_prepend_bos = False 

34 

35 # Pure MHA: split_qkv yields [d_model, d_model] per head; standard rearrangements apply. 

36 self.weight_processing_conversions = { 

37 **self._qkvo_weight_conversions(), 

38 } 

39 

40 self.component_mapping = { 

41 "embed": EmbeddingBridge(name="transformer.wte"), 

42 "blocks": BlockBridge( 

43 name="transformer.blocks", 

44 submodules={ 

45 "ln1": NormalizationBridge(name="norm_1", config=self.cfg), 

46 "attn": MPTALiBiAttentionBridge( 

47 name="attn", 

48 config=self.cfg, 

49 split_qkv_matrix=self._split_mpt_qkv, 

50 submodules={ 

51 "qkv": LinearBridge(name="Wqkv"), 

52 "o": LinearBridge(name="out_proj"), 

53 }, 

54 ), 

55 "ln2": NormalizationBridge(name="norm_2", config=self.cfg), 

56 "mlp": MLPBridge( 

57 name="ffn", 

58 submodules={ 

59 "in": LinearBridge(name="up_proj"), 

60 "out": LinearBridge(name="down_proj"), 

61 }, 

62 ), 

63 }, 

64 ), 

65 "ln_final": NormalizationBridge(name="transformer.norm_f", config=self.cfg), 

66 "unembed": UnembeddingBridge(name="lm_head"), 

67 } 
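        # The bridge names above are assumed to line up with the Hugging Face MPT module tree;
        # a hedged sanity-check sketch (checkpoint name illustrative, may need trust_remote_code=True):
        #   from transformers import AutoModelForCausalLM
        #   hf = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b")
        #   print(hf.transformer.blocks[0])  # expect norm_1, attn (Wqkv, out_proj), norm_2, ffn (up_proj, down_proj)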

    def _split_mpt_qkv(self, attn_component: Any) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
        """Split fused Wqkv into Q, K, V — row-wise chunk (NOT interleaved like BLOOM)."""
        w = attn_component.Wqkv.weight.detach().clone()
        w_q, w_k, w_v = torch.chunk(w, 3, dim=0)
        d_model = self.cfg.d_model

        def make_linear(weight: torch.Tensor) -> nn.Linear:
            lin = nn.Linear(d_model, d_model, bias=False, device=weight.device, dtype=weight.dtype)
            lin.weight = nn.Parameter(weight.contiguous())
            return lin

        return make_linear(w_q), make_linear(w_k), make_linear(w_v)
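

# A minimal sketch (not part of the original adapter) of the row-wise, non-interleaved Wqkv
# layout assumed by _split_mpt_qkv, using illustrative dimensions; run this module directly to check.
if __name__ == "__main__":
    _d_model = 8
    _fused = nn.Linear(_d_model, 3 * _d_model, bias=False)  # stand-in for MPT's fused Wqkv
    _q, _k, _v = torch.chunk(_fused.weight.detach(), 3, dim=0)
    # Row-wise: rows [0, d_model) are Q, [d_model, 2*d_model) are K, [2*d_model, 3*d_model) are V.
    # A BLOOM-style fused projection interleaves Q/K/V per head instead, so a plain chunk would mix heads.
    assert _q.shape == _k.shape == _v.shape == (_d_model, _d_model)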