Coverage for transformer_lens/model_bridge/supported_architectures/mistral.py: 100%

15 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

"""Mistral architecture adapter."""

from typing import Any

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    AttentionBridge,
    BlockBridge,
    EmbeddingBridge,
    GatedMLPBridge,
    LinearBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)


class MistralArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Mistral models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Mistral architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        # Core dimensions exposed to the bridge layer; d_head is derived as
        # d_model // n_heads, and n_key_value_heads captures Mistral's
        # grouped-query attention layout.
        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
            "n_key_value_heads": cfg.n_key_value_heads,
        }

        self.cfg.uses_rms_norm = True

        # Weight-format conversions for the attention projections (Q/K/V/O).
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        # Map generalized TransformerLens components onto the corresponding
        # Hugging Face Mistral module names.
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": AttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        requires_position_embeddings=True,
                        requires_attention_mask=True,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
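
For orientation, a minimal usage sketch of the adapter follows. It is illustrative only: the SimpleNamespace stand-in assumes the ArchitectureAdapter base class (and the bridges built in __init__) need only the attributes this file reads, which may not hold for the real TransformerLens config object. The dimensions match the published Mistral-7B architecture.

from types import SimpleNamespace

from transformer_lens.model_bridge.supported_architectures.mistral import (
    MistralArchitectureAdapter,
)

# Hypothetical stand-in for the real TransformerLens config object; the
# field names are the ones read in __init__, the values are Mistral-7B's.
cfg = SimpleNamespace(
    d_model=4096,
    n_heads=32,
    n_layers=32,
    d_vocab=32000,
    n_key_value_heads=8,  # grouped-query attention: 8 KV heads for 32 query heads
)

adapter = MistralArchitectureAdapter(cfg)
assert adapter.default_config["d_head"] == 128  # 4096 // 32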