transformer_lens.model_bridge.supported_architectures.cohere module

Cohere architecture adapter.

Supports CohereForCausalLM models (Command-R family) with:

  • Parallel attention+MLP sharing a single input_layernorm (no post_attention_layernorm)

  • True LayerNorm (CohereLayerNorm) with weight but no bias

  • GQA (grouped-query attention) with separate Q/K/V/O projections

  • Gated SwiGLU MLP (gate_proj, up_proj, down_proj)

  • Logit scaling: output logits are multiplied by config.logit_scale (default 1/16)

  • Tied embed/unembed weights by default (tie_word_embeddings=True)

  • Interleaved RoPE via CohereRotaryEmbedding (delegated to the HF module)
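The parallel block layout is easiest to see in a small sketch. The following is a minimal NumPy illustration, not the TransformerLens or HF implementation: attention is replaced by a hypothetical linear stand-in (toy_attn), weights are stored as [in, out] for readability, and eps=1e-5 is an assumed value. It shows one input_layernorm feeding both attention and a gated SwiGLU MLP, with both outputs added to the residual stream in a single step.

```python
import numpy as np

def layer_norm_no_bias(x, weight, eps=1e-5):
    # Mean-subtracting LayerNorm scaled by `weight`; there is no bias term.
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return weight * (x - mean) / np.sqrt(var + eps)

def silu(z):
    return z / (1.0 + np.exp(-z))

def swiglu_mlp(h, w_gate, w_up, w_down):
    # Gated MLP: down(silu(gate(h)) * up(h)), with no bias on any projection.
    return (silu(h @ w_gate) * (h @ w_up)) @ w_down

def parallel_block(x, ln_weight, attn_fn, mlp_weights):
    # A single input_layernorm; attention and MLP both read the same normed input,
    # and both outputs are added to the residual stream together.
    h = layer_norm_no_bias(x, ln_weight)
    return x + attn_fn(h) + swiglu_mlp(h, *mlp_weights)

rng = np.random.default_rng(0)
seq, d_model, d_mlp = 4, 8, 16
x = rng.standard_normal((seq, d_model))
ln_weight = np.ones(d_model)
w_o = 0.1 * rng.standard_normal((d_model, d_model))
toy_attn = lambda h: h @ w_o  # hypothetical stand-in for GQA self-attention
mlp_weights = (
    rng.standard_normal((d_model, d_mlp)),  # gate_proj
    rng.standard_normal((d_model, d_mlp)),  # up_proj
    rng.standard_normal((d_mlp, d_model)),  # down_proj
)
out = parallel_block(x, ln_weight, toy_attn, mlp_weights)
print(out.shape)  # (4, 8)
```

In a sequential (Llama-style) block the MLP would instead read a second, post-attention norm of x + attn(...); here there is no second norm to hook.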

class transformer_lens.model_bridge.supported_architectures.cohere.CohereArchitectureAdapter(cfg: Any)

Bases: ArchitectureAdapter

Architecture adapter for Cohere models (CohereForCausalLM).

Architectural quirks vs. standard decoder-only models:

  • Single input_layernorm per block; NO post_attention_layernorm. Attention and MLP both read the SAME normed hidden states (parallel).

  • CohereLayerNorm is true LayerNorm (mean-subtracting), NOT RMSNorm. It has a weight parameter but NO bias parameter.

  • Logit scale: CohereForCausalLM.forward multiplies logits by logit_scale (default 0.0625 = 1/16). The adapter folds this factor into unembed.weight via preprocess_weights.

  • Rotary embeddings pair dimensions with repeat_interleave instead of the cat-split layout; the computation is delegated to HF.
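To make the LayerNorm vs. RMSNorm distinction concrete, here is a minimal NumPy sketch of both normalizations. The function names and eps=1e-5 are illustrative assumptions, not the CohereLayerNorm source. The LayerNorm variant removes the per-row mean before scaling by weight (and adds no bias), while RMSNorm only rescales, so the mean of its input survives.

```python
import numpy as np

def cohere_style_layer_norm(x, weight, eps=1e-5):
    # True LayerNorm: remove the mean, normalize by the std, scale by weight. No bias.
    centered = x - x.mean(axis=-1, keepdims=True)
    var = (centered ** 2).mean(axis=-1, keepdims=True)
    return weight * centered / np.sqrt(var + eps)

def rms_norm(x, weight, eps=1e-5):
    # RMSNorm for comparison: rescale by the root-mean-square; the mean is kept.
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return weight * x / rms

x = np.array([[1.0, 2.0, 3.0, 6.0]])
w = np.ones(4)
ln_out = cohere_style_layer_norm(x, w)
rms_out = rms_norm(x, w)
print(ln_out.mean())   # approximately 0: the mean was subtracted
print(rms_out.mean())  # clearly positive: RMSNorm leaves the mean in place
```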

Optional parameters (absent from state_dict by default):

  • blocks.{i}.attn.b_Q/b_K/b_V/b_O: no bias on the attention projections (attention_bias=False)

  • blocks.{i}.mlp.b_gate/b_in/b_out: no bias on the MLP projections

  • blocks.{i}.ln1.b: CohereLayerNorm has no bias

  • ln_final.b: CohereLayerNorm has no bias
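As an illustration of what "optional" means in practice, the following hypothetical helper (not a TransformerLens API) lists which of the keys above are missing from a given set of state_dict keys. The toy key set is illustrative only. For a default Cohere config with no biases, every per-block bias key and ln_final.b should be reported as absent.

```python
# Hypothetical helper, not part of TransformerLens: list the optional bias keys
# described above that are absent from a converted state_dict.
OPTIONAL_PER_BLOCK = [
    "blocks.{i}.attn.b_Q", "blocks.{i}.attn.b_K", "blocks.{i}.attn.b_V", "blocks.{i}.attn.b_O",
    "blocks.{i}.mlp.b_gate", "blocks.{i}.mlp.b_in", "blocks.{i}.mlp.b_out",
    "blocks.{i}.ln1.b",
]

def absent_optional_keys(state_dict_keys, n_layers):
    expected = [tmpl.format(i=i) for i in range(n_layers) for tmpl in OPTIONAL_PER_BLOCK]
    expected.append("ln_final.b")
    present = set(state_dict_keys)
    return [key for key in expected if key not in present]

# Illustrative 2-layer key set containing no biases at all.
keys = {"embed.weight", "unembed.weight"}
missing = absent_optional_keys(keys, n_layers=2)
print(len(missing))  # 17 = 2 layers * 8 per-block keys + ln_final.b
```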

__init__(cfg: Any) → None

Initialize the Cohere architecture adapter.

preprocess_weights(state_dict: dict[str, Tensor]) → dict[str, Tensor]

Fold logit_scale into unembed weights before ProcessWeights runs.

In bridge.py (lines 726-732), unembed.weight is cloned before this method is called, so the scaling does not affect the tied embed.weight. A logit_scale of 1.0 is a no-op and is skipped for efficiency.
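Folding works because scaling the logits equals scaling the unembedding weights: (h @ W.T) * s == h @ (s * W).T. The NumPy sketch below checks this and shows why the clone matters for tied weights. It is an illustration, not the adapter's code; the [d_vocab, d_model] layout (as in HF's lm_head.weight) and the toy sizes are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, d_model = 6, 4
logit_scale = 0.0625  # Cohere's documented default (1/16)

# Tied weights: both entries point at the SAME array, assuming a
# [d_vocab, d_model] layout like HF's lm_head.weight.
embed = rng.standard_normal((d_vocab, d_model))
embed_before = embed.copy()
state_dict = {"embed.weight": embed, "unembed.weight": embed}

# Clone first, then scale; the in-place multiply touches only the clone.
state_dict["unembed.weight"] = state_dict["unembed.weight"].copy()
if logit_scale != 1.0:  # 1.0 would be a no-op, so skip it
    state_dict["unembed.weight"] *= logit_scale

h = rng.standard_normal((3, d_model))  # final (normed) hidden states
logits_scaled_at_output = (h @ embed.T) * logit_scale
logits_folded = h @ state_dict["unembed.weight"].T

print(np.allclose(logits_scaled_at_output, logits_folded))  # True
print(np.array_equal(state_dict["embed.weight"], embed_before))  # True
```

Removing the .copy() line makes the in-place `*=` rescale embed.weight as well, since both keys share one array; that is the failure the clone guards against.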

setup_component_testing(hf_model: Any, bridge_model: Any = None) → None

Set rotary embedding reference on attention bridges for component testing.

CohereRotaryEmbedding lives at hf_model.model.rotary_emb. The bridge delegates to it directly, preserving the repeat_interleave RoPE convention without re-implementing it in TL.

This follows the same pattern as the llama.py and qwen2.py adapters.
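Delegating to CohereRotaryEmbedding matters because the interleaved and cat-split RoPE conventions are not interchangeable. The NumPy sketch below implements both under stated assumptions (base=10000, toy sizes, illustrative function names; it is not the HF source). The cat-split variant pairs dimension j with j + dim/2, while the interleaved variant pairs adjacent dimensions (2j, 2j+1) using repeat-interleaved angles. On the same input they disagree; they agree only after a fixed permutation of the head dimension.

```python
import numpy as np

def rope_angles(seq_len, dim, base=10000.0):
    # One frequency per rotated pair: theta_j = base^(-2j/dim), angle = pos * theta_j.
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    return np.arange(seq_len)[:, None] * inv_freq[None, :]  # [seq, dim/2]

def rope_cat_split(x, angles):
    # Llama-style: dimension j is paired with j + dim/2; angles are concatenated.
    emb = np.concatenate([angles, angles], axis=-1)
    half = x.shape[-1] // 2
    rotated = np.concatenate([-x[..., half:], x[..., :half]], axis=-1)
    return x * np.cos(emb) + rotated * np.sin(emb)

def rope_interleaved(x, angles):
    # Cohere-style: adjacent dimensions (2j, 2j+1) form a pair; angles are repeat-interleaved.
    emb = np.repeat(angles, 2, axis=-1)
    x_even, x_odd = x[..., ::2], x[..., 1::2]
    rotated = np.stack([-x_odd, x_even], axis=-1).reshape(x.shape)
    return x * np.cos(emb) + rotated * np.sin(emb)

rng = np.random.default_rng(0)
seq, dim = 5, 8
x = rng.standard_normal((seq, dim))
angles = rope_angles(seq, dim)

# The two conventions disagree on the same input...
print(np.allclose(rope_interleaved(x, angles), rope_cat_split(x, angles)))  # False
# ...but agree up to a fixed permutation of the head dimension.
perm = np.concatenate([np.arange(0, dim, 2), np.arange(1, dim, 2)])
same = rope_cat_split(x[..., perm], angles)[..., np.argsort(perm)]
print(np.allclose(rope_interleaved(x, angles), same))  # True
```

Reusing a cat-split RoPE implementation on Cohere weights would silently rotate the wrong dimension pairs, which is why the bridge calls the HF module instead of re-implementing the convention.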