Coverage for transformer_lens/model_bridge/supported_architectures/llama.py: 63%
27 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Llama architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 BlockBridge,
8 EmbeddingBridge,
9 GatedMLPBridge,
10 LinearBridge,
11 PositionEmbeddingsAttentionBridge,
12 RMSNormalizationBridge,
13 RotaryEmbeddingBridge,
14 UnembeddingBridge,
15)


class LlamaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Llama models.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    Llama models do NOT have biases on attention and MLP projections:

    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.mlp.b_in - No bias on MLP input (up_proj)
    - blocks.{i}.mlp.b_gate - No bias on MLP gate projection
    - blocks.{i}.mlp.b_out - No bias on MLP output (down_proj)
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias

    Weight processing must handle these missing biases gracefully, either via
    ProcessWeights._safe_get_tensor() or by checking for None values; a sketch of
    this pattern follows the docstring.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Llama architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }
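
        # For a concrete sense of the derived d_head above: Llama-3.1-8B (numbers
        # quoted from the public model card, not read from this config) has
        # d_model=4096 and n_heads=32, so d_head = 4096 // 32 = 128.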

        # Add GQA support for Llama 3.1, 3.2, and later models
        # Must set directly on cfg, not just in default_config
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
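
        # With grouped-query attention, n_key_value_heads < n_heads and several query
        # heads share one K/V head. For example (again not read from this config),
        # Llama-3.1-8B uses n_heads=32 with n_key_value_heads=8, i.e. 4 query heads
        # per K/V head.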

        self.cfg.uses_rms_norm = True
        # Llama uses 'variance_epsilon' instead of 'eps' for RMSNorm
        self.cfg.eps_attr = "variance_epsilon"
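
        # Sketch (not the actual bridge code) of how eps_attr is meant to be consumed:
        # the bridge would read the epsilon off the HF RMSNorm module roughly as
        #     eps = getattr(hf_rmsnorm_module, self.cfg.eps_attr)
        # since HF's LlamaRMSNorm stores it as `variance_epsilon` rather than `eps`.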

        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
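
    # Illustrative path translation implied by component_mapping above (the exact
    # lookup helpers are assumptions, but the name pairing follows directly from the
    # bridges defined in __init__):
    #
    #     TransformerLens-style path      HuggingFace module path
    #     blocks.0.attn.q            <->  model.layers.0.self_attn.q_proj
    #     blocks.0.mlp.gate          <->  model.layers.0.mlp.gate_proj
    #     ln_final                   <->  model.norm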

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Llama component testing.

        Llama uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Llama model instance
            bridge_model: The TransformerBridge model (if provided, rotary_emb is set
                on its actual attention bridge instances)
        """
        # Get the rotary embedding instance from the HuggingFace model
        rotary_emb = hf_model.model.rotary_emb

        # Set rotary_emb on the actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
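

# Hypothetical usage sketch (the model name, `cfg`, and `bridge` wiring are assumptions;
# only AutoModelForCausalLM.from_pretrained and the methods defined above are real calls):
#
#     from transformers import AutoModelForCausalLM
#
#     hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
#     adapter = LlamaArchitectureAdapter(cfg)                          # cfg built elsewhere
#     adapter.setup_component_testing(hf_model, bridge_model=bridge)   # bridge is optional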