Coverage for transformer_lens/model_bridge/supported_architectures/mistral.py: 100%
15 statements
1"""Mistral architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 AttentionBridge,
8 BlockBridge,
9 EmbeddingBridge,
10 GatedMLPBridge,
11 LinearBridge,
12 RMSNormalizationBridge,
13 RotaryEmbeddingBridge,
14 UnembeddingBridge,
15)
18class MistralArchitectureAdapter(ArchitectureAdapter):
19 """Architecture adapter for Mistral models."""
21 def __init__(self, cfg: Any) -> None:
22 """Initialize the Mistral architecture adapter."""
23 super().__init__(cfg)
25 # Set config variables for weight processing
26 self.cfg.normalization_type = "RMS"
27 self.cfg.positional_embedding_type = "rotary"
28 self.cfg.final_rms = False
29 self.cfg.gated_mlp = True
30 self.cfg.attn_only = False
32 self.default_config = {
33 "d_model": cfg.d_model,
34 "d_head": cfg.d_model // cfg.n_heads,
35 "n_heads": cfg.n_heads,
36 "n_layers": cfg.n_layers,
37 "d_vocab": cfg.d_vocab,
38 "n_key_value_heads": cfg.n_key_value_heads,
39 }
41 self.cfg.uses_rms_norm = True
43 self.weight_processing_conversions = {
44 **self._qkvo_weight_conversions(),
45 }
47 self.component_mapping = {
48 "embed": EmbeddingBridge(name="model.embed_tokens"),
49 "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
50 "blocks": BlockBridge(
51 name="model.layers",
52 submodules={
53 "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
54 "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
55 "attn": AttentionBridge(
56 name="self_attn",
57 config=self.cfg,
58 requires_position_embeddings=True,
59 requires_attention_mask=True,
60 submodules={
61 "q": LinearBridge(name="q_proj"),
62 "k": LinearBridge(name="k_proj"),
63 "v": LinearBridge(name="v_proj"),
64 "o": LinearBridge(name="o_proj"),
65 },
66 ),
67 "mlp": GatedMLPBridge(
68 name="mlp",
69 submodules={
70 "gate": LinearBridge(name="gate_proj"),
71 "in": LinearBridge(name="up_proj"),
72 "out": LinearBridge(name="down_proj"),
73 },
74 ),
75 },
76 ),
77 "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
78 "unembed": UnembeddingBridge(name="lm_head"),
79 }
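
For reference, a minimal sketch of how this adapter might be instantiated. It assumes the base ArchitectureAdapter needs no config fields beyond the ones read in __init__ above (d_model, n_heads, n_layers, d_vocab, n_key_value_heads), which may not hold for the real class; the config values shown are the published Mistral-7B-v0.1 hyperparameters.

from types import SimpleNamespace

from transformer_lens.model_bridge.supported_architectures.mistral import (
    MistralArchitectureAdapter,
)

# Hypothetical minimal config: only the fields __init__ reads are set.
# The real base class may require additional attributes.
cfg = SimpleNamespace(
    d_model=4096,
    n_heads=32,
    n_layers=32,
    d_vocab=32000,
    n_key_value_heads=8,  # grouped-query attention: fewer KV heads than query heads
)

adapter = MistralArchitectureAdapter(cfg)

# component_mapping translates TransformerLens component names into the
# module paths used by the Hugging Face MistralForCausalLM implementation,
# e.g. "embed" -> "model.embed_tokens" (assuming each bridge exposes the
# name it was constructed with as a .name attribute).
print(adapter.component_mapping["embed"].name)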