Coverage for transformer_lens/model_bridge/supported_architectures/xglm.py: 100%

27 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""XGLM architecture adapter. 

2 

3Supports XGLMForCausalLM (facebook/xglm-*). 

4Assumes add_cross_attention=False (all published XGLM checkpoints). 

5""" 

6 

7from typing import Any 

8 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 AttentionBridge, 

12 BlockBridge, 

13 EmbeddingBridge, 

14 LinearBridge, 

15 NormalizationBridge, 

16 SymbolicBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

21class XGLMArchitectureAdapter(ArchitectureAdapter): 

22 """Architecture adapter for XGLM models. 

23 

24 XGLM uses pre-norm LayerNorm, sinusoidal positional embeddings (no 

25 learnable weights), standard MHA with separate q/k/v/out_proj, and a 

26 2-layer MLP (fc1/fc2) that lives directly on the decoder block rather 

27 than inside an mlp sub-module. 

28 

29 All attention projections and fc1/fc2 carry biases. lm_head has no bias. 

30 Embeddings are scaled by sqrt(d_model) at runtime in XGLMScaledWordEmbedding. 

31 

32 Optional Parameters (may not exist in state_dict): 

33 -------------------------------------------------- 

34 None — all published XGLM checkpoints include all parameters listed above. 

35 """ 

36 

37 def __init__(self, cfg: Any) -> None: 

38 """Initialize the XGLM architecture adapter.""" 

39 super().__init__(cfg) 

40 

41 # LayerNorm throughout (not RMSNorm) 

42 self.cfg.normalization_type = "LN" 

43 # Sinusoidal positional embeddings — added to token embeddings before blocks, 

44 # no learnable weights, no RoPE 

45 self.cfg.positional_embedding_type = "standard" 

46 self.cfg.final_rms = False 

47 # Standard 2-layer MLP (fc1 -> gelu -> fc2), no gate projection 

48 self.cfg.gated_mlp = False 

49 self.cfg.attn_only = False 

50 self.cfg.uses_rms_norm = False 

51 

52 # Sinusoidal positional embeddings have no weights in the state_dict, so 

53 # center_writing_weights cannot center pos_embed. Disable it for XGLM. 

54 self.supports_center_writing_weights = False 

55 

56 # Standard MHA: n_heads == n_kv_heads for all XGLM sizes 

57 self.weight_processing_conversions = { 

58 **self._qkvo_weight_conversions(), 

59 } 

60 

61 self.component_mapping = { 

62 "embed": EmbeddingBridge(name="model.embed_tokens"), 

63 # No "pos_embed": sinusoidal embeddings are a non-persistent buffer with 

64 # no learnable weights — embed_positions does not appear in state_dict. 

65 "blocks": BlockBridge( 

66 name="model.layers", 

67 submodules={ 

68 "ln1": NormalizationBridge( 

69 name="self_attn_layer_norm", # pre-attn norm on XGLMDecoderLayer 

70 config=self.cfg, 

71 use_native_layernorm_autograd=True, 

72 ), 

73 "attn": AttentionBridge( 

74 name="self_attn", 

75 config=self.cfg, 

76 requires_attention_mask=True, 

77 attention_mask_4d=True, # (batch, 1, tgt_len, src_len) 

78 submodules={ 

79 "q": LinearBridge(name="q_proj"), 

80 "k": LinearBridge(name="k_proj"), 

81 "v": LinearBridge(name="v_proj"), 

82 "o": LinearBridge(name="out_proj"), # out_proj, not o_proj 

83 }, 

84 ), 

85 "ln2": NormalizationBridge( 

86 name="final_layer_norm", # pre-MLP norm on XGLMDecoderLayer 

87 config=self.cfg, 

88 use_native_layernorm_autograd=True, 

89 ), 

90 # fc1/fc2 live directly on XGLMDecoderLayer — no "mlp" container. 

91 # SymbolicBridge preserves TL structure without a real HF submodule. 

92 "mlp": SymbolicBridge( 

93 submodules={ 

94 "in": LinearBridge(name="fc1"), 

95 "out": LinearBridge(name="fc2"), 

96 }, 

97 ), 

98 }, 

99 ), 

100 "ln_final": NormalizationBridge( 

101 name="model.layer_norm", # note: layer_norm, not norm 

102 config=self.cfg, 

103 use_native_layernorm_autograd=True, 

104 ), 

105 "unembed": UnembeddingBridge(name="lm_head"), 

106 } 

107 

108 def setup_hook_compatibility(self, bridge: Any) -> None: 

109 """Scale hook_embed by sqrt(d_model) to match XGLMScaledWordEmbedding.forward(). 

110 

111 XGLMScaledWordEmbedding multiplies the embedding lookup by embed_scale = 

112 sqrt(d_model) at runtime. Without this override, hook_embed would capture 

113 the raw (unscaled) table output, diverging from actual model activations. 

114 """ 

115 from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

116 BaseTensorConversion, 

117 ) 

118 

119 class EmbeddingScaleConversion(BaseTensorConversion): 

120 """Scale embeddings by sqrt(d_model) for XGLM models.""" 

121 

122 def __init__(self, scale: float) -> None: 

123 super().__init__() 

124 self.scale = scale 

125 

126 def handle_conversion(self, input_value: Any, *full_context: Any) -> Any: 

127 return input_value * self.scale 

128 

129 def revert(self, input_value: Any, *full_context: Any) -> Any: 

130 return input_value / self.scale 

131 

132 if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"): 

133 bridge.embed.hook_out.hook_conversion = EmbeddingScaleConversion( 

134 self.cfg.d_model**0.5 

135 )
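
Not part of xglm.py: a minimal sanity-check sketch for the two HF-side claims the adapter relies on (fc1/fc2 sit directly on the decoder layer, and token embeddings are scaled by sqrt(d_model) in XGLMScaledWordEmbedding). It assumes the facebook/xglm-564M checkpoint and uses only standard transformers/PyTorch calls; the expected child-module layout is an assumption about HF's modeling_xglm, not something this file guarantees.

import math

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/xglm-564M")

# Decoder-block layout: fc1/fc2 are expected to appear as direct children of
# layer 0, with no intermediate "mlp" container (assumed HF XGLMDecoderLayer layout).
print([name for name, _ in model.model.layers[0].named_children()])

# Embedding scale: XGLMScaledWordEmbedding multiplies the table lookup by
# embed_scale = sqrt(d_model), which EmbeddingScaleConversion mirrors on hook_embed.
tokens = torch.tensor([[2, 100, 200]])  # arbitrary token ids
scaled = model.model.embed_tokens(tokens)  # scaled output, as the model sees it
raw = model.model.embed_tokens.weight[tokens[0]]  # plain lookup, no scaling
assert torch.allclose(scaled[0], raw * math.sqrt(model.config.d_model))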