Coverage for transformer_lens/model_bridge/supported_architectures/cohere.py: 75%
44 statements
1"""Cohere architecture adapter.
3Supports CohereForCausalLM models (Command-R family) with:
4- Parallel attention+MLP sharing a single input_layernorm (no post_attention_layernorm)
5- True LayerNorm (CohereLayerNorm) with weight but no bias
6- GQA (grouped-query attention) with separate Q/K/V/O projections
7- Gated SwiGLU MLP (gate_proj, up_proj, down_proj)
8- Logit scaling: output logits multiplied by config.logit_scale (default 1/16)
9- Tied embed/unembed weights by default (tie_word_embeddings=True)
10- Interleaved RoPE via CohereRotaryEmbedding (delegated to HF module)
11"""
from typing import Any

import torch

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    EmbeddingBridge,
    GatedMLPBridge,
    LinearBridge,
    NormalizationBridge,
    ParallelBlockBridge,
    PositionEmbeddingsAttentionBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)
class CohereArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Cohere models (CohereForCausalLM).

    Architectural quirks vs. standard decoder-only models:
    - Single input_layernorm per block; NO post_attention_layernorm.
      Attention and MLP both read the SAME normed hidden states (parallel).
    - CohereLayerNorm is true LayerNorm (mean-subtracting), NOT RMSNorm.
      It has a weight parameter but NO bias parameter.
    - Logit scale: CohereForCausalLM.forward multiplies logits by logit_scale
      (default 0.0625 = 1/16). Folded into unembed.weight via preprocess_weights.
    - Rotary embeddings use repeat_interleave instead of cat-split (delegated to HF).

    Optional parameters (absent from state_dict by default):
    - blocks.{i}.attn.b_Q/b_K/b_V/b_O — no bias on projections (attention_bias=False)
    - blocks.{i}.mlp.b_gate/b_in/b_out — no bias on MLP projections
    - blocks.{i}.ln1.b — CohereLayerNorm has no bias
    - ln_final.b — CohereLayerNorm has no bias
    """
    def __init__(self, cfg: Any) -> None:
        """Initialize the Cohere architecture adapter."""
        super().__init__(cfg)

        # --- Normalization ---
        # CohereLayerNorm is true LayerNorm (subtracts mean), NOT RMSNorm.
        # uses_rms_norm=False tells NormalizationBridge to subtract the mean.
        # eps_attr="variance_epsilon": CohereLayerNorm stores eps as self.variance_epsilon.
        self.cfg.normalization_type = "LN"
        self.cfg.uses_rms_norm = False
        self.cfg.eps_attr = "variance_epsilon"
        self.cfg.final_rms = False
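        # A hedged sketch of the norm this config selects (weight-only true LN;
        # `w` is CohereLayerNorm.weight, `eps` is its variance_epsilon, no bias):
        #
        #     mu = x.mean(-1, keepdim=True)
        #     var = x.var(-1, keepdim=True, unbiased=False)
        #     y = (x - mu) * torch.rsqrt(var + eps) * w
        #
        # RMSNorm would skip the mean subtraction; uses_rms_norm=False keeps it.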
        # --- Position embeddings and MLP ---
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        # --- Parallel block: single norm, no post_attention_layernorm ---
        self.cfg.parallel_attn_mlp = True

        # --- Tokenizer: BOS is prepended by default ---
        # CohereTokenizerFast has add_bos_token=False but HF's __call__ with
        # add_special_tokens=True (the default) prepends BOS. Verified against
        # trl-internal-testing/tiny-CohereForCausalLM.
        self.cfg.default_prepend_bos = True
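        # Illustrative check (a sketch, not executed here; the repo name comes
        # from the comment above):
        #
        #     tok = AutoTokenizer.from_pretrained(
        #         "trl-internal-testing/tiny-CohereForCausalLM"
        #     )
        #     ids = tok("hello")["input_ids"]
        #     assert ids[0] == tok.bos_token_id  # BOS prepended by default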
        # --- GQA: n_key_value_heads ---
        # sources/transformers.py copies num_key_value_heads generically.
        # Re-read here to ensure it's set on cfg for _qkvo_weight_conversions.
        n_kv = getattr(cfg, "n_key_value_heads", None)
        if n_kv is not None:
            self.cfg.n_key_value_heads = n_kv
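        # In GQA, each of the n_key_value_heads K/V heads serves a group of
        # n_heads // n_key_value_heads query heads. A hedged shape sketch,
        # with d_head = d_model // n_heads:
        #
        #     q_proj out_features: n_heads * d_head
        #     k_proj out_features: n_key_value_heads * d_head
        #     v_proj out_features: n_key_value_heads * d_head
        #
        # so the Q/K/V rearrangements below must know n_key_value_heads to
        # reshape K/V correctly.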
        # --- Weight processing conversions ---
        # Standard GQA-aware Q/K/V/O rearrangements (same as Llama/Qwen2).
        # n_kv is already set on self.cfg; _qkvo_weight_conversions reads it via
        # getattr(self.cfg, "n_key_value_heads", None) when called with no args.
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        # --- Logit scale ---
        # CohereConfig.logit_scale is typed float | None; apply explicit None-check
        # so cfg.logit_scale is always a plain float (never None).
        # logit_scale is not a declared field on TransformerBridgeConfig; it is a
        # Cohere-specific dynamic attribute accessed later in preprocess_weights.
        _ls = getattr(cfg, "logit_scale", None)
        self.cfg.logit_scale = float(_ls) if _ls is not None else 0.0625  # type: ignore[attr-defined]
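        # Why folding works (a sketch): HF computes logits = (h @ W_U.T + b) * s,
        # and the scalar multiply distributes over the affine map, so scaling
        # both the weight and any bias once gives the same result:
        #
        #     h @ (s * W_U.T) + s * b == s * (h @ W_U.T + b)
        #
        # preprocess_weights below applies exactly this to unembed.weight/bias.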
        # --- RoPE theta (informational metadata) ---
        # CohereRotaryEmbedding reads config.rope_parameters["rope_theta"] directly;
        # store it in cfg.rotary_base so TL config accurately reflects the model.
        # TransformerBridgeConfig stores rotary_base as int, matching its declared type.
        _rope_params = getattr(cfg, "rope_parameters", None) or {}
        if isinstance(_rope_params, dict):
            _theta = _rope_params.get("rope_theta", getattr(cfg, "default_theta", 10000.0))
        else:
            _theta = getattr(cfg, "default_theta", 10000.0)
        self.cfg.rotary_base = int(_theta)
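        # Cohere's interleaved RoPE rotates adjacent dim pairs (x0,x1), (x2,x3),
        # ... rather than splitting each head in half as Llama's cat-split does.
        # A hedged sketch of the frequency layout:
        #
        #     inv_freq = 1.0 / theta ** (torch.arange(0, d_head, 2) / d_head)
        #     emb = torch.repeat_interleave(inv_freq, 2)  # interleaved pairs
        #
        # The adapter delegates this to CohereRotaryEmbedding instead of
        # re-implementing it in TL.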
        # --- Component mapping ---
        # Block structure follows Falcon's parallel_attn=True, num_ln_in_parallel_attn=1
        # mode: single ln1 feeds both attn and MLP; NO ln2.
        # Submodule shapes follow Llama: separate q/k/v/o projections and SwiGLU MLP.
        # Rotary and attention both delegate to HF modules, preserving Cohere's
        # repeat_interleave RoPE convention without re-implementing it in TL.
        self.component_mapping = {
            # Embedding: model.embed_tokens (same root as Llama, not transformer.* like Falcon)
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            # Rotary embedding: top-level, delegates to CohereRotaryEmbedding.
            # Pattern matches llama.py:75 and falcon.py:154 — NOT inside blocks.
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": ParallelBlockBridge(
                name="model.layers",
                submodules={
                    # Single pre-norm only — Cohere has no post_attention_layernorm.
                    # NormalizationBridge handles weight-only CohereLayerNorm correctly:
                    # it checks `hasattr(original_component, "bias") and bias is not None`
                    # before adding bias, so the missing bias attribute is silently skipped.
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    # No "ln2" — parallel block, same normed input goes to attn AND mlp.
                    # Optional use_qk_norm is handled transparently by HF's
                    # CohereAttention.forward delegation (no extra submodules needed).
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # GatedMLPBridge: gate/in/out matches Llama's gate_proj/up_proj/down_proj.
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            # Final LayerNorm (CohereLayerNorm, weight-only) at model.norm
            "ln_final": NormalizationBridge(name="model.norm", config=self.cfg),
            # Unembed: lm_head. logit_scale is folded into weight in preprocess_weights.
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
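        # Resulting TL path -> HF path correspondence (illustrative, derived
        # from the dict above):
        #
        #     embed               -> model.embed_tokens
        #     blocks.{i}.ln1      -> model.layers.{i}.input_layernorm
        #     blocks.{i}.attn.q   -> model.layers.{i}.self_attn.q_proj
        #     blocks.{i}.mlp.gate -> model.layers.{i}.mlp.gate_proj
        #     ln_final            -> model.norm
        #     unembed             -> lm_head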
    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold logit_scale into unembed weights before ProcessWeights runs.

        bridge.py lines 726-732 clone unembed.weight before calling this, so
        scaling does not affect the tied embed.weight.
        logit_scale=1.0 is a no-op (skipped for efficiency).
        """
        scale: float = getattr(self.cfg, "logit_scale")  # always set by __init__
        if scale != 1.0:
            for key in ("unembed.weight", "unembed.bias"):
                if key in state_dict:
                    orig_dtype = state_dict[key].dtype
                    state_dict[key] = (state_dict[key].float() * scale).to(orig_dtype)
        return state_dict
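    # A hedged sanity-check sketch for preprocess_weights, assuming `adapter`
    # is a constructed CohereArchitectureAdapter with cfg.logit_scale = 0.0625.
    # The float() round-trip above guards low-precision dtypes when multiplying
    # by small scales:
    #
    #     sd = {"unembed.weight": torch.ones(4, 8, dtype=torch.bfloat16)}
    #     out = adapter.preprocess_weights(sd)
    #     assert out["unembed.weight"].dtype == torch.bfloat16
    #     assert torch.allclose(
    #         out["unembed.weight"].float(), torch.full((4, 8), 0.0625)
    #     )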
    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set rotary embedding reference on attention bridges for component testing.

        CohereRotaryEmbedding lives at hf_model.model.rotary_emb. The bridge
        delegates to it directly, preserving the repeat_interleave RoPE convention
        without re-implementing it in TL.

        Pattern matches llama.py and qwen2.py.
        """
        rotary_emb = hf_model.model.rotary_emb

        # Set on actual bridge instances in the live model (if available)
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template so get_generalized_component() calls work
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
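    # Illustrative call pattern (a sketch; `hf_model` and `bridge` are assumed
    # to be an already-loaded CohereForCausalLM and its bridge counterpart):
    #
    #     adapter.setup_component_testing(hf_model, bridge_model=bridge)
    #     attn = adapter.get_generalized_component("blocks.0.attn")
    #     # attn now holds a reference to hf_model.model.rotary_emb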