Coverage for transformer_lens/model_bridge/supported_architectures/cohere.py: 75%

44 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Cohere architecture adapter. 

2 

3Supports CohereForCausalLM models (Command-R family) with: 

4- Parallel attention+MLP sharing a single input_layernorm (no post_attention_layernorm) 

5- True LayerNorm (CohereLayerNorm) with weight but no bias 

6- GQA (grouped-query attention) with separate Q/K/V/O projections 

7- Gated SwiGLU MLP (gate_proj, up_proj, down_proj) 

8- Logit scaling: output logits multiplied by config.logit_scale (default 1/16) 

9- Tied embed/unembed weights by default (tie_word_embeddings=True) 

10- Interleaved RoPE via CohereRotaryEmbedding (delegated to HF module) 

11""" 

12 

13from typing import Any 

14 

15import torch 

16 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.generalized_components import ( 

19 EmbeddingBridge, 

20 GatedMLPBridge, 

21 LinearBridge, 

22 NormalizationBridge, 

23 ParallelBlockBridge, 

24 PositionEmbeddingsAttentionBridge, 

25 RotaryEmbeddingBridge, 

26 UnembeddingBridge, 

27) 
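
# Illustrative sketch of the parallel block layout described in the module docstring:
# attention and MLP both read the SAME output of the single input_layernorm, and their
# outputs are summed back into the residual stream. The callables `ln`, `attn`, and
# `mlp` are hypothetical stand-ins, not TL bridge components; this helper is
# documentation only and is never called by the adapter.
def _parallel_block_sketch(resid, ln, attn, mlp):
    normed = ln(resid)  # single pre-norm; there is no post_attention_layernorm
    return resid + attn(normed) + mlp(normed)
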


class CohereArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Cohere models (CohereForCausalLM).

    Architectural quirks vs. standard decoder-only models:
    - Single input_layernorm per block; NO post_attention_layernorm.
      Attention and MLP both read the SAME normed hidden states (parallel).
    - CohereLayerNorm is true LayerNorm (mean-subtracting), NOT RMSNorm.
      It has a weight parameter but NO bias parameter.
    - Logit scale: CohereForCausalLM.forward multiplies logits by logit_scale
      (default 0.0625 = 1/16). Folded into unembed.weight via preprocess_weights.
    - Rotary embeddings use repeat_interleave instead of cat-split (delegated to HF).

    Optional parameters (absent from state_dict by default):
    - blocks.{i}.attn.b_Q/b_K/b_V/b_O — no bias on projections (attention_bias=False)
    - blocks.{i}.mlp.b_gate/b_in/b_out — no bias on MLP projections
    - blocks.{i}.ln1.b — CohereLayerNorm has no bias
    - ln_final.b — CohereLayerNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Cohere architecture adapter."""
        super().__init__(cfg)

        # --- Normalization ---
        # CohereLayerNorm is true LayerNorm (subtracts mean), NOT RMSNorm.
        # uses_rms_norm=False tells NormalizationBridge to subtract the mean.
        # eps_attr="variance_epsilon": CohereLayerNorm stores eps as self.variance_epsilon.
        self.cfg.normalization_type = "LN"
        self.cfg.uses_rms_norm = False
        self.cfg.eps_attr = "variance_epsilon"
        self.cfg.final_rms = False
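        # For intuition (descriptive note, not adapter-specific behavior): true LayerNorm
        # computes y = (x - mean(x)) / sqrt(var(x) + eps) * weight, with no bias term
        # here, whereas RMSNorm would skip the mean subtraction and divide by
        # sqrt(mean(x**2) + eps).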

61 

62 # --- Position embeddings and MLP --- 

63 self.cfg.positional_embedding_type = "rotary" 

64 self.cfg.gated_mlp = True 

65 self.cfg.attn_only = False 

66 

67 # --- Parallel block: single norm, no post_attention_layernorm --- 

68 self.cfg.parallel_attn_mlp = True 

69 

70 # --- Tokenizer: BOS is prepended by default --- 

71 # CohereTokenizerFast has add_bos_token=False but HF's __call__ with 

72 # add_special_tokens=True (the default) prepends BOS. Verified against 

73 # trl-internal-testing/tiny-CohereForCausalLM. 

74 self.cfg.default_prepend_bos = True 

75 

76 # --- GQA: n_key_value_heads --- 

77 # sources/transformers.py copies num_key_value_heads generically. 

78 # Re-read here to ensure it's set on cfg for _qkvo_weight_conversions. 

79 n_kv = getattr(cfg, "n_key_value_heads", None) 

80 if n_kv is not None: 80 ↛ 87line 80 didn't jump to line 87 because the condition on line 80 was always true

81 self.cfg.n_key_value_heads = n_kv 
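        # Worked example (hypothetical sizes, for intuition only): with n_heads=8 and
        # n_key_value_heads=2, q_proj carries 8 head slices while k_proj/v_proj carry
        # only 2, so each KV head is shared by 8 // 2 = 4 query heads. The conversions
        # below rearrange the Q/K/V/O weights with that grouping in mind.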

        # --- Weight processing conversions ---
        # Standard GQA-aware Q/K/V/O rearrangements (same as Llama/Qwen2).
        # n_kv is already set on self.cfg; _qkvo_weight_conversions reads it via
        # getattr(self.cfg, "n_key_value_heads", None) when called with no args.
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        # --- Logit scale ---
        # CohereConfig.logit_scale is typed float | None; apply explicit None-check
        # so cfg.logit_scale is always a plain float (never None).
        # logit_scale is not a declared field on TransformerBridgeConfig; it is a
        # Cohere-specific dynamic attribute accessed later in preprocess_weights.
        _ls = getattr(cfg, "logit_scale", None)
        self.cfg.logit_scale = float(_ls) if _ls is not None else 0.0625  # type: ignore[attr-defined]

        # --- RoPE theta (informational metadata) ---
        # CohereRotaryEmbedding reads config.rope_parameters["rope_theta"] directly;
        # store it in cfg.rotary_base so TL config accurately reflects the model.
        # TransformerBridgeConfig stores rotary_base as int, matching its declared type.
        _rope_params = getattr(cfg, "rope_parameters", None) or {}
        if isinstance(_rope_params, dict):
            _theta = _rope_params.get("rope_theta", getattr(cfg, "default_theta", 10000.0))
        else:
            _theta = getattr(cfg, "default_theta", 10000.0)
        self.cfg.rotary_base = int(_theta)

        # --- Component mapping ---
        # Block structure follows Falcon's parallel_attn=True, num_ln_in_parallel_attn=1
        # mode: single ln1 feeds both attn and MLP; NO ln2.
        # Submodule shapes follow Llama: separate q/k/v/o projections and SwiGLU MLP.
        # Rotary and attention both delegate to HF modules, preserving Cohere's
        # repeat_interleave RoPE convention without re-implementing it in TL.
        self.component_mapping = {
            # Embedding: model.embed_tokens (same root as Llama, not transformer.* like Falcon)
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            # Rotary embedding: top-level, delegates to CohereRotaryEmbedding.
            # Pattern matches llama.py:75 and falcon.py:154 — NOT inside blocks.
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": ParallelBlockBridge(
                name="model.layers",
                submodules={
                    # Single pre-norm only — Cohere has no post_attention_layernorm.
                    # NormalizationBridge handles weight-only CohereLayerNorm correctly:
                    # it checks `hasattr(original_component, "bias") and bias is not None`
                    # before adding bias, so the missing bias attribute is silently skipped.
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    # No "ln2" — parallel block, same normed input goes to attn AND mlp.
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # Optional use_qk_norm is handled transparently by HF's
                    # CohereAttention.forward delegation (no extra submodules needed).
                    # GatedMLPBridge: gate/in/out matches Llama's gate_proj/up_proj/down_proj.
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            # Final LayerNorm (CohereLayerNorm, weight-only) at model.norm
            "ln_final": NormalizationBridge(name="model.norm", config=self.cfg),
            # Unembed: lm_head. logit_scale is folded into weight in preprocess_weights.
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold logit_scale into unembed weights before ProcessWeights runs.

        bridge.py lines 726-732 clone unembed.weight before calling this, so
        scaling does not affect the tied embed.weight.
        logit_scale=1.0 is a no-op (skipped for efficiency).
        """
        scale: float = getattr(self.cfg, "logit_scale")  # always set by __init__
        if scale != 1.0:
            for key in ("unembed.weight", "unembed.bias"):
                if key in state_dict:
                    orig_dtype = state_dict[key].dtype
                    state_dict[key] = (state_dict[key].float() * scale).to(orig_dtype)
        return state_dict

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set rotary embedding reference on attention bridges for component testing.

        CohereRotaryEmbedding lives at hf_model.model.rotary_emb. The bridge
        delegates to it directly, preserving the repeat_interleave RoPE convention
        without re-implementing it in TL.

        Pattern matches llama.py and qwen2.py.
        """
        rotary_emb = hf_model.model.rotary_emb

        # Set on actual bridge instances in the live model (if available)
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template so get_generalized_component() calls work
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
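
# Illustrative check of the logit-scale fold performed in preprocess_weights above
# (documentation only, never executed by the adapter). Because the unembedding is
# linear in its weight, scaling the weight once is equivalent to multiplying the
# logits by logit_scale at runtime, which is what CohereForCausalLM.forward does.
# The sizes and the (d_vocab, d_model) weight layout below are arbitrary toy
# assumptions for the sketch, not values read from any real config.
def _logit_scale_fold_sketch() -> bool:
    d_model, d_vocab, scale = 8, 16, 0.0625
    resid = torch.randn(3, d_model)
    w_u = torch.randn(d_vocab, d_model)
    scaled_at_runtime = (resid @ w_u.T) * scale   # multiply logits after the matmul
    folded_into_weight = resid @ (w_u * scale).T  # fold the scale into the weight
    return bool(torch.allclose(scaled_at_runtime, folded_into_weight))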