Coverage for transformer_lens/model_bridge/supported_architectures/xglm.py: 100% (27 statements)
1"""XGLM architecture adapter.
3Supports XGLMForCausalLM (facebook/xglm-*).
4Assumes add_cross_attention=False (all published XGLM checkpoints).
5"""
7from typing import Any
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 BlockBridge,
13 EmbeddingBridge,
14 LinearBridge,
15 NormalizationBridge,
16 SymbolicBridge,
17 UnembeddingBridge,
18)


class XGLMArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for XGLM models.

    XGLM uses pre-norm LayerNorm, sinusoidal positional embeddings (no
    learnable weights), standard MHA with separate q/k/v/out_proj, and a
    2-layer MLP (fc1/fc2) that lives directly on the decoder block rather
    than inside an mlp sub-module.

    All attention projections and fc1/fc2 carry biases. lm_head has no bias.
    Embeddings are scaled by sqrt(d_model) at runtime in XGLMScaledWordEmbedding.

    Optional Parameters (may not exist in state_dict):
    --------------------------------------------------
    None — all published XGLM checkpoints include all parameters listed above.
    """
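
    # Illustrative dataflow for one decoder block, per the docstring above
    # (a comment-form sketch, not executed; the real forward pass lives in
    # HF's XGLMDecoderLayer):
    #
    #   h = x + self_attn(self_attn_layer_norm(x))    # pre-norm attention
    #   h = h + fc2(gelu(fc1(final_layer_norm(h))))   # pre-norm 2-layer MLP
    #
    # TL-style paths then resolve to HF modules per component_mapping below,
    # e.g. blocks.0.attn.q -> model.layers.0.self_attn.q_proj and
    # ln_final -> model.layer_norm.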

    def __init__(self, cfg: Any) -> None:
        """Initialize the XGLM architecture adapter."""
        super().__init__(cfg)

        # LayerNorm throughout (not RMSNorm)
        self.cfg.normalization_type = "LN"
        # Sinusoidal positional embeddings — added to token embeddings before blocks,
        # no learnable weights, no RoPE
        self.cfg.positional_embedding_type = "standard"
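        # For reference, the usual sinusoidal construction (a sketch only;
        # HF's XGLMSinusoidalPositionalEmbedding additionally applies a fixed
        # position offset, so treat the exact indexing as illustrative):
        #   PE[pos, 2i]   = sin(pos / 10000 ** (2i / d_model))
        #   PE[pos, 2i+1] = cos(pos / 10000 ** (2i / d_model))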
        self.cfg.final_rms = False
        # Standard 2-layer MLP (fc1 -> gelu -> fc2), no gate projection
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False

        # Sinusoidal positional embeddings have no weights in the state_dict, so
        # center_writing_weights cannot center pos_embed. Disable it for XGLM.
        self.supports_center_writing_weights = False

        # Standard MHA: n_heads == n_kv_heads for all XGLM sizes
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            # No "pos_embed": sinusoidal embeddings are a non-persistent buffer with
            # no learnable weights — embed_positions does not appear in state_dict.
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="self_attn_layer_norm",  # pre-attn norm on XGLMDecoderLayer
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        requires_attention_mask=True,
                        attention_mask_4d=True,  # (batch, 1, tgt_len, src_len)
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),  # out_proj, not o_proj
                        },
                    ),
                    "ln2": NormalizationBridge(
                        name="final_layer_norm",  # pre-MLP norm on XGLMDecoderLayer
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    # fc1/fc2 live directly on XGLMDecoderLayer — no "mlp" container.
                    # SymbolicBridge preserves TL structure without a real HF submodule.
                    "mlp": SymbolicBridge(
                        submodules={
                            "in": LinearBridge(name="fc1"),
                            "out": LinearBridge(name="fc2"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="model.layer_norm",  # note: layer_norm, not norm
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Scale hook_embed by sqrt(d_model) to match XGLMScaledWordEmbedding.forward().

        XGLMScaledWordEmbedding multiplies the embedding lookup by embed_scale =
        sqrt(d_model) at runtime. Without this override, hook_embed would capture
        the raw (unscaled) table output, diverging from actual model activations.
        """
        from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
            BaseTensorConversion,
        )

        class EmbeddingScaleConversion(BaseTensorConversion):
            """Scale embeddings by sqrt(d_model) for XGLM models."""

            def __init__(self, scale: float) -> None:
                super().__init__()
                self.scale = scale

            def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
                return input_value * self.scale

            def revert(self, input_value: Any, *full_context: Any) -> Any:
                return input_value / self.scale
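
        # handle_conversion followed by revert recovers the input (up to
        # floating-point rounding), so a cached hook_embed activation can be
        # mapped back to the raw, unscaled table output if needed.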
        if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
            bridge.embed.hook_out.hook_conversion = EmbeddingScaleConversion(
                self.cfg.d_model**0.5
            )
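

if __name__ == "__main__":
    # Throwaway sanity check of the scale/revert round trip (illustrative only,
    # not part of the adapter). d_model=1024 is facebook/xglm-564M's hidden
    # size, giving embed_scale = sqrt(1024) = 32.0.
    import torch

    d_model = 1024
    scale = d_model**0.5  # 32.0
    x = torch.randn(2, 5, d_model)
    assert torch.allclose(x * scale / scale, x)
    print(f"embed_scale for d_model={d_model}: {scale}")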