Coverage for transformer_lens/model_bridge/supported_architectures/cohere.py: 75%

43 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Cohere architecture adapter. 

2 

3Supports CohereForCausalLM models (Command-R family) with: 

4- Parallel attention+MLP sharing a single input_layernorm (no post_attention_layernorm) 

5- True LayerNorm (CohereLayerNorm) with weight but no bias 

6- GQA (grouped-query attention) with separate Q/K/V/O projections 

7- Gated SwiGLU MLP (gate_proj, up_proj, down_proj) 

8- Logit scaling: output logits multiplied by config.logit_scale (default 1/16) 

9- Tied embed/unembed weights by default (tie_word_embeddings=True) 

10- Interleaved RoPE via CohereRotaryEmbedding (delegated to HF module) 

11""" 

12 

13from typing import Any 

14 

15import torch 

16 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.generalized_components import ( 

19 EmbeddingBridge, 

20 GatedMLPBridge, 

21 LinearBridge, 

22 NormalizationBridge, 

23 ParallelBlockBridge, 

24 PositionEmbeddingsAttentionBridge, 

25 RotaryEmbeddingBridge, 

26 UnembeddingBridge, 

27) 

28 

29 

class CohereArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Cohere models (CohereForCausalLM).

    Architectural quirks vs. standard decoder-only models:
    - Single input_layernorm per block; NO post_attention_layernorm.
      Attention and MLP both read the SAME normed hidden states (parallel).
    - CohereLayerNorm is true LayerNorm (mean-subtracting), NOT RMSNorm.
      It has a weight parameter but NO bias parameter.
    - Logit scale: CohereForCausalLM.forward multiplies logits by logit_scale
      (default 0.0625 = 1/16). Folded into unembed.weight via preprocess_weights.
    - Rotary embeddings use repeat_interleave instead of cat-split (delegated to HF).

    Optional parameters (absent from state_dict by default):
    - blocks.{i}.attn.b_Q/b_K/b_V/b_O — no bias on projections (attention_bias=False)
    - blocks.{i}.mlp.b_gate/b_in/b_out — no bias on MLP projections
    - blocks.{i}.ln1.b — CohereLayerNorm has no bias
    - ln_final.b — CohereLayerNorm has no bias
    """

    # Fallback used when the incoming config carries no logit_scale (1/16).
    _DEFAULT_LOGIT_SCALE: float = 0.0625
    # Fallback RoPE base used when neither rope_parameters nor default_theta
    # provides a usable value.
    _DEFAULT_ROPE_THETA: float = 10000.0

    def __init__(self, cfg: Any) -> None:
        """Initialize the Cohere architecture adapter.

        Args:
            cfg: Bridge configuration. The optional attributes
                ``n_key_value_heads``, ``logit_scale``, ``rope_parameters`` and
                ``default_theta`` are read from it when present.
        """
        super().__init__(cfg)

        # --- Normalization ---
        # CohereLayerNorm is true LayerNorm (subtracts mean), NOT RMSNorm.
        # uses_rms_norm=False tells NormalizationBridge to subtract the mean.
        self.cfg.normalization_type = "LN"
        self.cfg.uses_rms_norm = False
        self.cfg.final_rms = False

        # --- Position embeddings and MLP ---
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        # --- Parallel block: single norm, no post_attention_layernorm ---
        self.cfg.parallel_attn_mlp = True

        # --- Tokenizer: BOS is prepended by default ---
        # CohereTokenizerFast has add_bos_token=False but HF's __call__ with
        # add_special_tokens=True (the default) prepends BOS. Verified against
        # trl-internal-testing/tiny-CohereForCausalLM.
        self.cfg.default_prepend_bos = True

        # --- GQA: n_key_value_heads ---
        # sources/transformers.py copies num_key_value_heads generically.
        # Re-read here to ensure it's set on cfg for _qkvo_weight_conversions.
        n_kv = getattr(cfg, "n_key_value_heads", None)
        if n_kv is not None:
            self.cfg.n_key_value_heads = n_kv

        # --- Weight processing conversions ---
        # Standard GQA-aware Q/K/V/O rearrangements (same as Llama/Qwen2).
        # n_kv is already set on self.cfg; _qkvo_weight_conversions reads it via
        # getattr(self.cfg, "n_key_value_heads", None) when called with no args.
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        # --- Logit scale ---
        # CohereConfig.logit_scale is typed float | None; apply explicit None-check
        # so cfg.logit_scale is always a plain float (never None).
        # logit_scale is not a declared field on TransformerBridgeConfig; it is a
        # Cohere-specific dynamic attribute accessed later in preprocess_weights.
        logit_scale = getattr(cfg, "logit_scale", None)
        self.cfg.logit_scale = (  # type: ignore[attr-defined]
            float(logit_scale) if logit_scale is not None else self._DEFAULT_LOGIT_SCALE
        )

        # --- RoPE theta (informational metadata) ---
        # CohereRotaryEmbedding reads config.rope_parameters["rope_theta"] directly;
        # store it in cfg.rotary_base so TL config accurately reflects the model.
        # TransformerBridgeConfig stores rotary_base as int, matching its declared type.
        self.cfg.rotary_base = int(self._resolve_rope_theta(cfg))

        # --- Component mapping ---
        # Block structure follows Falcon's parallel_attn=True, num_ln_in_parallel_attn=1
        # mode: single ln1 feeds both attn and MLP; NO ln2.
        # Submodule shapes follow Llama: separate q/k/v/o projections and SwiGLU MLP.
        # Rotary and attention both delegate to HF modules, preserving Cohere's
        # repeat_interleave RoPE convention without re-implementing it in TL.
        self.component_mapping = {
            # Embedding: model.embed_tokens (same root as Llama, not transformer.* like Falcon)
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            # Rotary embedding: top-level, delegates to CohereRotaryEmbedding.
            # Same placement as the Llama and Falcon adapters — NOT inside blocks.
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": ParallelBlockBridge(
                name="model.layers",
                submodules={
                    # Single pre-norm only — Cohere has no post_attention_layernorm.
                    # NormalizationBridge handles weight-only CohereLayerNorm correctly:
                    # it checks `hasattr(original_component, "bias") and bias is not None`
                    # before adding bias, so the missing bias attribute is silently skipped.
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    # No "ln2" — parallel block, same normed input goes to attn AND mlp.
                    # Optional use_qk_norm is handled transparently by HF's
                    # CohereAttention.forward delegation (no extra submodules needed).
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # GatedMLPBridge: gate/in/out matches Llama's gate_proj/up_proj/down_proj.
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            # Final LayerNorm (CohereLayerNorm, weight-only) at model.norm
            "ln_final": NormalizationBridge(name="model.norm", config=self.cfg),
            # Unembed: lm_head. logit_scale is folded into weight in preprocess_weights.
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    @classmethod
    def _resolve_rope_theta(cls, cfg: Any) -> float:
        """Return the RoPE base for ``cfg``, never ``None``.

        Lookup order: ``cfg.rope_parameters["rope_theta"]`` (only when
        ``rope_parameters`` is a dict), then ``cfg.default_theta``, then the
        class-level default. A value of ``None`` at any step is treated as
        "not provided" so the caller can safely pass the result to ``int()``.
        """
        rope_params = getattr(cfg, "rope_parameters", None)
        theta = rope_params.get("rope_theta") if isinstance(rope_params, dict) else None
        if theta is None:
            theta = getattr(cfg, "default_theta", None)
        if theta is None:
            theta = cls._DEFAULT_ROPE_THETA
        return theta

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold logit_scale into unembed weights before ProcessWeights runs.

        According to the original note on this method, bridge.py clones
        unembed.weight before calling this, so scaling does not affect the
        tied embed.weight. logit_scale=1.0 is a no-op (skipped for efficiency).

        Args:
            state_dict: Mapping of parameter names to tensors. The entries
                ``unembed.weight`` and ``unembed.bias`` are replaced in place
                when present.

        Returns:
            The same ``state_dict`` object, with the unembed entries scaled.
        """
        scale: float = getattr(self.cfg, "logit_scale")  # always set by __init__
        if scale == 1.0:
            return state_dict

        for key in ("unembed.weight", "unembed.bias"):
            tensor = state_dict.get(key)
            if tensor is None:
                continue
            orig_dtype = tensor.dtype
            # Multiply in at least float32 so half/bfloat16 weights are not
            # rounded twice, but never *down*-cast: float64 stays float64.
            compute_dtype = torch.promote_types(orig_dtype, torch.float32)
            state_dict[key] = (tensor.to(compute_dtype) * scale).to(orig_dtype)
        return state_dict

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set rotary embedding reference on attention bridges for component testing.

        CohereRotaryEmbedding lives at hf_model.model.rotary_emb. The bridge
        delegates to it directly, preserving the repeat_interleave RoPE convention
        without re-implementing it in TL.

        Pattern matches llama.py and qwen2.py.

        Args:
            hf_model: HF model exposing ``model.rotary_emb``.
            bridge_model: Optional live bridge; when it has ``blocks``, every
                block with an ``attn`` attribute receives the rotary module.
        """
        rotary_emb = hf_model.model.rotary_emb

        # Set on actual bridge instances in the live model (if available)
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template so get_generalized_component() calls work
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)