1"""Granite architecture adapter.
3Base adapter for the IBM Granite model family. Provides shared config setup and
4helper methods used by GraniteMoe and GraniteMoeHybrid variants.
5"""
7from typing import Any
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 BlockBridge,
12 EmbeddingBridge,
13 GatedMLPBridge,
14 LinearBridge,
15 PositionEmbeddingsAttentionBridge,
16 RMSNormalizationBridge,
17 RotaryEmbeddingBridge,
18 UnembeddingBridge,
19)


class GraniteArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for IBM Granite models (dense).

    Granite is a Llama-like architecture with RMSNorm, rotary position embeddings
    (RoPE), grouped-query attention (GQA), and a gated MLP with SiLU activation.
    Granite-specific scaling multipliers are handled by the HF model's native
    forward pass.

    Optional parameters (may not exist in the state_dict):
    ------------------------------------------------------
    Granite models do NOT have biases on attention and MLP projections:

    - blocks.{i}.attn.b_Q/b_K/b_V/b_O - no bias on attention projections
    - blocks.{i}.mlp.b_in/b_gate/b_out - no bias on MLP projections
    - blocks.{i}.ln1.b, blocks.{i}.ln2.b, ln_final.b - RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Granite architecture adapter."""
        super().__init__(cfg)

        self._setup_common_config(cfg)
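        # The Q/K/V/O weight-layout conversions are inherited helpers shared by
        # Llama-style adapters (defined outside this module).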
        self.weight_processing_conversions = {**self._qkvo_weight_conversions()}
        self.component_mapping = self._build_component_mapping()

    def _setup_common_config(self, cfg: Any) -> None:
        """Set up config variables shared across all Granite variants."""
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False
        self.cfg.eps_attr = "variance_epsilon"
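
        # Minimal derived config: d_head follows the standard convention
        # d_model / n_heads used throughout TransformerLens.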
        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }
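
        # GQA: propagate the reduced key/value head count when the source
        # config provides one, so the K/V projections are sized accordingly.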
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

    def _build_attention_bridge(self, optional: bool = False) -> PositionEmbeddingsAttentionBridge:
        """Build the standard Granite attention bridge."""
        return PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            optional=optional,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )

    def _build_mlp_bridge(self) -> GatedMLPBridge:
        """Build the dense gated MLP bridge."""
        return GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

    def _build_component_mapping(self) -> dict:
        """Build the full component mapping for dense Granite."""
        return {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "attn": self._build_attention_bridge(),
                    "mlp": self._build_mlp_bridge(),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Granite component testing.

        Args:
            hf_model: The HuggingFace Granite model instance.
            bridge_model: The TransformerBridge model (if available).
        """
        if not hasattr(hf_model.model, "rotary_emb"):
            return

        rotary_emb = hf_model.model.rotary_emb

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if "attn" in block._modules:
                    block.attn.set_rotary_emb(rotary_emb)
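
        # Also point the adapter's own view of the first attention bridge at
        # the shared module, tolerating layouts where it cannot be resolved.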
        try:
            attn_bridge = self.get_generalized_component("blocks.0.attn")
            attn_bridge.set_rotary_emb(rotary_emb)
        except (AttributeError, KeyError, ValueError):
            pass
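
# A minimal usage sketch (illustrative, not part of the module): in practice the
# cfg object is built by TransformerBridge from a HF Granite checkpoint. The
# SimpleNamespace and the attribute values below are stand-in assumptions that
# only mirror the fields this adapter reads.
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(
#         d_model=2048, n_heads=32, n_layers=40,
#         d_vocab=49152, n_key_value_heads=8,
#     )
#     adapter = GraniteArchitectureAdapter(cfg)
#     adapter.component_mapping["blocks"]   # BlockBridge over "model.layers"
#     adapter.default_config["d_head"]      # 2048 // 32 == 64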