Coverage for transformer_lens/model_bridge/supported_architectures/cohere.py: 75%
43 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Cohere architecture adapter.
3Supports CohereForCausalLM models (Command-R family) with:
4- Parallel attention+MLP sharing a single input_layernorm (no post_attention_layernorm)
5- True LayerNorm (CohereLayerNorm) with weight but no bias
6- GQA (grouped-query attention) with separate Q/K/V/O projections
7- Gated SwiGLU MLP (gate_proj, up_proj, down_proj)
8- Logit scaling: output logits multiplied by config.logit_scale (default 1/16)
9- Tied embed/unembed weights by default (tie_word_embeddings=True)
10- Interleaved RoPE via CohereRotaryEmbedding (delegated to HF module)
11"""
13from typing import Any
15import torch
17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
18from transformer_lens.model_bridge.generalized_components import (
19 EmbeddingBridge,
20 GatedMLPBridge,
21 LinearBridge,
22 NormalizationBridge,
23 ParallelBlockBridge,
24 PositionEmbeddingsAttentionBridge,
25 RotaryEmbeddingBridge,
26 UnembeddingBridge,
27)
class CohereArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Cohere models (CohereForCausalLM).

    Architectural quirks vs. standard decoder-only models:
    - Single input_layernorm per block; NO post_attention_layernorm.
      Attention and MLP both read the SAME normed hidden states (parallel).
    - CohereLayerNorm is true LayerNorm (mean-subtracting), NOT RMSNorm.
      It has a weight parameter but NO bias parameter.
    - Logit scale: CohereForCausalLM.forward multiplies logits by logit_scale
      (default 0.0625 = 1/16). Folded into unembed.weight via preprocess_weights.
    - Rotary embeddings use repeat_interleave instead of cat-split (delegated to HF).

    Optional parameters (absent from state_dict by default):
    - blocks.{i}.attn.b_Q/b_K/b_V/b_O — no bias on projections (attention_bias=False)
    - blocks.{i}.mlp.b_gate/b_in/b_out — no bias on MLP projections
    - blocks.{i}.ln1.b — CohereLayerNorm has no bias
    - ln_final.b — CohereLayerNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Cohere architecture adapter.

        Sets the Cohere-specific flags on ``self.cfg``, registers the Q/K/V/O
        weight conversions, and builds the component mapping.

        Args:
            cfg: Bridge configuration. Read here for ``n_key_value_heads``,
                ``logit_scale``, ``rope_parameters`` and ``default_theta``;
                all four are optional.
        """
        super().__init__(cfg)

        # --- Normalization ---
        # CohereLayerNorm is true LayerNorm (subtracts mean), NOT RMSNorm.
        # uses_rms_norm=False tells NormalizationBridge to subtract the mean.
        self.cfg.normalization_type = "LN"
        self.cfg.uses_rms_norm = False
        self.cfg.final_rms = False

        # --- Position embeddings and MLP ---
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        # --- Parallel block: single norm, no post_attention_layernorm ---
        self.cfg.parallel_attn_mlp = True

        # --- Tokenizer: BOS is prepended by default ---
        # CohereTokenizerFast has add_bos_token=False but HF's __call__ with
        # add_special_tokens=True (the default) prepends BOS. Verified against
        # trl-internal-testing/tiny-CohereForCausalLM.
        self.cfg.default_prepend_bos = True

        # --- GQA: n_key_value_heads ---
        # sources/transformers.py copies num_key_value_heads generically.
        # Re-read here so it is set on self.cfg before _qkvo_weight_conversions
        # runs below.
        n_kv = getattr(cfg, "n_key_value_heads", None)
        if n_kv is not None:
            self.cfg.n_key_value_heads = n_kv

        # --- Weight processing conversions ---
        # Standard GQA-aware Q/K/V/O rearrangements (same as Llama/Qwen2).
        # Called with no arguments: per the original note, the helper reads
        # n_key_value_heads from self.cfg itself.
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        # --- Logit scale ---
        # CohereConfig.logit_scale is typed float | None; the explicit None
        # check guarantees cfg.logit_scale is always a plain float.
        # logit_scale is not a declared field on TransformerBridgeConfig; it is
        # a Cohere-specific dynamic attribute consumed by preprocess_weights.
        raw_scale = getattr(cfg, "logit_scale", None)
        logit_scale = 0.0625 if raw_scale is None else float(raw_scale)
        self.cfg.logit_scale = logit_scale  # type: ignore[attr-defined]

        # --- RoPE theta (informational metadata) ---
        # CohereRotaryEmbedding reads config.rope_parameters["rope_theta"]
        # directly; mirror it in cfg.rotary_base so the TL config reflects the
        # model. Stored as int, matching TransformerBridgeConfig's declared type.
        self.cfg.rotary_base = self._cohere_rotary_base(cfg)

        # --- Component mapping ---
        # Block structure follows Falcon's parallel_attn=True,
        # num_ln_in_parallel_attn=1 mode: single ln1 feeds both attn and MLP;
        # NO ln2. Submodule shapes follow Llama: separate q/k/v/o projections
        # and SwiGLU MLP. Rotary and attention both delegate to HF modules,
        # preserving Cohere's repeat_interleave RoPE convention without
        # re-implementing it in TL.
        self.component_mapping = {
            # Embedding: model.embed_tokens (same root as Llama, not
            # transformer.* like Falcon).
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            # Rotary embedding: top-level, delegates to CohereRotaryEmbedding.
            # Same placement as the Llama and Falcon adapters — NOT inside blocks.
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": ParallelBlockBridge(
                name="model.layers",
                submodules={
                    # Single pre-norm only — Cohere has no post_attention_layernorm.
                    # Per the original note, NormalizationBridge only adds a bias
                    # when the wrapped module has a non-None `bias`, so the
                    # weight-only CohereLayerNorm needs no special handling.
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    # No "ln2" — parallel block, same normed input goes to attn AND mlp.
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # GatedMLPBridge: gate/in/out matches Llama's
                    # gate_proj/up_proj/down_proj.
                    # Optional use_qk_norm is handled transparently by HF's
                    # CohereAttention.forward delegation (no extra submodules needed).
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            # Final LayerNorm (CohereLayerNorm, weight-only) at model.norm
            "ln_final": NormalizationBridge(name="model.norm", config=self.cfg),
            # Unembed: lm_head. logit_scale is folded into weight in preprocess_weights.
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    @staticmethod
    def _cohere_rotary_base(cfg: Any) -> int:
        """Resolve the RoPE base (theta) to record in ``cfg.rotary_base``.

        Precedence: ``cfg.rope_parameters["rope_theta"]`` (only when
        ``rope_parameters`` is a dict), then ``cfg.default_theta``, then
        10000.0. A value of ``None`` at any step is treated the same as a
        missing value, so an explicit ``"rope_theta": None`` entry falls
        through to the next candidate instead of raising in ``int()``.

        Args:
            cfg: Configuration object; both attributes are optional.

        Returns:
            The selected theta truncated to ``int``.
        """
        rope_params = getattr(cfg, "rope_parameters", None)
        configured = rope_params.get("rope_theta") if isinstance(rope_params, dict) else None
        for candidate in (configured, getattr(cfg, "default_theta", None)):
            if candidate is not None:
                return int(candidate)
        return 10000

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold logit_scale into the unembed parameters before ProcessWeights runs.

        Both ``unembed.weight`` and ``unembed.bias`` (when present) are
        multiplied by ``cfg.logit_scale``. The product is computed in float32
        and cast back to the tensor's original dtype. Each scaled tensor is a
        new tensor stored back under the same key; the tensor previously held
        under that key is not modified in place.

        Per the original note, the bridge clones unembed.weight before calling
        this, so the tied embed.weight is not affected by the scaling.

        Args:
            state_dict: Mapping of parameter names to tensors; updated in place.

        Returns:
            The same ``state_dict`` object. It is returned untouched when
            ``logit_scale`` equals 1.0, since scaling would be a no-op.
        """
        scale: float = getattr(self.cfg, "logit_scale")  # always set by __init__
        if scale == 1.0:
            return state_dict

        for key in ("unembed.weight", "unembed.bias"):
            tensor = state_dict.get(key)
            if tensor is None:
                continue
            state_dict[key] = (tensor.float() * scale).to(tensor.dtype)
        return state_dict

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set rotary embedding reference on attention bridges for component testing.

        CohereRotaryEmbedding lives at hf_model.model.rotary_emb. The bridge
        delegates to it directly, preserving the repeat_interleave RoPE
        convention without re-implementing it in TL.

        Args:
            hf_model: HF model exposing ``model.rotary_emb``.
            bridge_model: Optional live bridge model. When it has ``blocks``,
                every block that has an ``attn`` attribute receives the rotary
                module.
        """
        rotary_emb = hf_model.model.rotary_emb

        # Set on actual bridge instances in the live model (if available).
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template so get_generalized_component() calls work.
        template_attn = self.get_generalized_component("blocks.0.attn")
        template_attn.set_rotary_emb(rotary_emb)