transformer_lens.model_bridge.supported_architectures.cohere module¶
Cohere architecture adapter.
Supports CohereForCausalLM models (Command-R family) with:

- Parallel attention+MLP sharing a single input_layernorm (no post_attention_layernorm)
- True LayerNorm (CohereLayerNorm) with weight but no bias
- GQA (grouped-query attention) with separate Q/K/V/O projections
- Gated SwiGLU MLP (gate_proj, up_proj, down_proj)
- Logit scaling: output logits multiplied by config.logit_scale (default 1/16)
- Tied embed/unembed weights by default (tie_word_embeddings=True)
- Interleaved RoPE via CohereRotaryEmbedding (delegated to HF module)
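The parallel attention+MLP layout can be sketched as follows. This is a minimal illustration of the dataflow only: `attn` and `mlp` are hypothetical stand-in callables, not the real GQA attention or gated SwiGLU MLP.

```python
import torch


def parallel_block(x, ln_w, attn, mlp):
    """Sketch of a Cohere block: one LayerNorm feeds BOTH branches.

    `attn` and `mlp` are illustrative stand-ins for the real
    GQA attention and gated SwiGLU MLP sub-modules.
    """
    # True LayerNorm (mean-subtracting), weight only, no bias.
    normed = torch.nn.functional.layer_norm(x, x.shape[-1:]) * ln_w
    # Parallel: both branches read the SAME normed input, and their
    # outputs are summed back into the residual stream together.
    return x + attn(normed) + mlp(normed)


x = torch.randn(2, 4, 8)
ln_w = torch.ones(8)
out = parallel_block(x, ln_w, lambda h: h * 0.5, lambda h: h * 0.25)
```

Contrast this with a standard sequential block, where the MLP reads the residual stream *after* the attention output has been added and re-normed by a post_attention_layernorm.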
- class transformer_lens.model_bridge.supported_architectures.cohere.CohereArchitectureAdapter(cfg: Any)¶
Bases: ArchitectureAdapter

Architecture adapter for Cohere models (CohereForCausalLM).
Architectural quirks vs. standard decoder-only models:

- Single input_layernorm per block; NO post_attention_layernorm. Attention and MLP both read the SAME normed hidden states (parallel).
- CohereLayerNorm is true LayerNorm (mean-subtracting), NOT RMSNorm. It has a weight parameter but NO bias parameter.
- Logit scale: CohereForCausalLM.forward multiplies logits by logit_scale (default 0.0625 = 1/16). Folded into unembed.weight via preprocess_weights.
- Rotary embeddings use repeat_interleave instead of cat-split (delegated to HF).

Optional parameters (absent from state_dict by default):

- blocks.{i}.attn.b_Q/b_K/b_V/b_O — no bias on projections (attention_bias=False)
- blocks.{i}.mlp.b_gate/b_in/b_out — no bias on MLP projections
- blocks.{i}.ln1.b — CohereLayerNorm has no bias
- ln_final.b — CohereLayerNorm has no bias
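The LayerNorm quirk above is worth making concrete. Below is a hedged sketch (not the actual CohereLayerNorm implementation) contrasting a mean-subtracting, bias-free LayerNorm with RMSNorm, which skips mean subtraction:

```python
import torch


def cohere_style_layer_norm(x, weight, eps=1e-5):
    # True LayerNorm: subtract the mean, divide by the std, scale by weight.
    # No bias term -- matching the absent ln1.b / ln_final.b entries above.
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps) * weight


def rms_norm(x, weight, eps=1e-5):
    # RMSNorm (used by e.g. Llama-style models): no mean subtraction.
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


x = torch.randn(3, 8) + 5.0  # a nonzero mean makes the difference visible
w = torch.ones(8)
```

On zero-mean inputs the two coincide; on inputs with a nonzero mean they diverge, which is why treating CohereLayerNorm as RMSNorm would silently corrupt activations.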
- __init__(cfg: Any) → None¶
Initialize the Cohere architecture adapter.
- preprocess_weights(state_dict: dict[str, Tensor]) → dict[str, Tensor]¶
Fold logit_scale into unembed weights before ProcessWeights runs.
The caller in bridge.py (lines 726-732) clones unembed.weight before invoking this method, so the scaling does not affect the tied embed.weight. A logit_scale of 1.0 is a no-op and is skipped for efficiency.
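The fold relies on a simple identity: scaling the unembed weight once is equivalent to scaling every logit at runtime. A minimal sketch (illustrative shapes and names, not the adapter's actual code):

```python
import torch

logit_scale = 0.0625  # Cohere default, 1/16
W_U = torch.randn(16, 32)   # (d_model, d_vocab), illustrative shapes
resid = torch.randn(2, 16)  # final residual stream, illustrative

# HF forward: compute logits, then scale them.
logits_runtime = (resid @ W_U) * logit_scale

# Folded: clone first (so tied embedding weights are untouched),
# then scale the unembed weight once ahead of time.
W_U_folded = W_U.clone() * logit_scale
logits_folded = resid @ W_U_folded
```

Because matrix multiplication is linear, the two paths produce identical logits, so downstream weight processing never needs to know about logit_scale.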
- setup_component_testing(hf_model: Any, bridge_model: Any = None) None¶
Set rotary embedding reference on attention bridges for component testing.
CohereRotaryEmbedding lives at hf_model.model.rotary_emb. The bridge delegates to it directly, preserving the repeat_interleave RoPE convention without re-implementing it in TL.
Pattern matches llama.py and qwen2.py.
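The repeat_interleave convention mentioned above differs from the cat-split convention in how rotary dimensions are paired. A sketch of the two `rotate_half` variants (simplified illustrations of the conventions, not HF's exact code):

```python
import torch


def rotate_half_split(x):
    # "cat-split" convention (Llama/NeoX style): pair dim i with dim i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_interleaved(x):
    # Interleaved convention (Cohere style): pair ADJACENT dims (2i, 2i+1),
    # matching cos/sin tables built with repeat_interleave.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
```

Because the pairings differ, applying the cat-split rotation to a model trained with interleaved RoPE scrambles the head dimensions; delegating to hf_model.model.rotary_emb sidesteps this entirely.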