Coverage for transformer_lens/model_bridge/supported_architectures/llama.py: 63%

27 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Llama architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 BlockBridge, 

8 EmbeddingBridge, 

9 GatedMLPBridge, 

10 LinearBridge, 

11 PositionEmbeddingsAttentionBridge, 

12 RMSNormalizationBridge, 

13 RotaryEmbeddingBridge, 

14 UnembeddingBridge, 

15) 

16 

17 

18class LlamaArchitectureAdapter(ArchitectureAdapter): 

19 """Architecture adapter for Llama models. 

20 

21 Optional Parameters (may not exist in state_dict): 

22 ------------------------------------------------- 

23 LLaMA models do NOT have biases on attention and MLP projections: 

24 

25 - blocks.{i}.attn.b_Q - No bias on query projection 

26 - blocks.{i}.attn.b_K - No bias on key projection 

27 - blocks.{i}.attn.b_V - No bias on value projection 

28 - blocks.{i}.attn.b_O - No bias on output projection 

29 - blocks.{i}.mlp.b_in - No bias on MLP input (up_proj) 

30 - blocks.{i}.mlp.b_gate - No bias on MLP gate projection 

31 - blocks.{i}.mlp.b_out - No bias on MLP output (down_proj) 

32 - blocks.{i}.ln1.b - RMSNorm has no bias 

33 - blocks.{i}.ln2.b - RMSNorm has no bias 

34 - ln_final.b - RMSNorm has no bias 

35 

36 Weight processing must handle these missing biases gracefully using 

37 ProcessWeights._safe_get_tensor() or by checking for None values. 

38 """ 


    def __init__(self, cfg: Any) -> None:
        """Initialize the Llama architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }
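        # For reference (values assumed from the published Llama-2-7B config):
        # d_model=4096, n_heads=32, d_head=4096 // 32 = 128, n_layers=32, d_vocab=32000.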


        # Add GQA support for Llama 3.1, 3.2, and later models
        # Must set directly on cfg, not just in default_config
        # (coverage note: the False branch of this check was never taken; the condition
        # was always true in the measured runs)
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
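        # Example (values assumed from the published Llama 3.1 8B config): n_heads=32
        # with n_key_value_heads=8, so each key/value head is shared by 4 query heads.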


        self.cfg.uses_rms_norm = True
        # Llama uses 'variance_epsilon' instead of 'eps' for RMSNorm
        self.cfg.eps_attr = "variance_epsilon"

        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
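        # Path-resolution sketch (behaviour assumed from the bridge names above): the
        # TransformerLens path "blocks.0.attn.q" should resolve to the HuggingFace
        # module "model.layers.0.self_attn.q_proj", and "blocks.0.mlp.in" to
        # "model.layers.0.mlp.up_proj".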


    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Llama component testing.

        Llama uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Llama model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
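
# Usage sketch (a hypothetical call sequence; only the names defined in this file are
# taken from the source, everything else is an assumption):
#
#     from transformers import AutoModelForCausalLM
#
#     hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
#     adapter = LlamaArchitectureAdapter(cfg)  # cfg constructed elsewhere
#     adapter.setup_component_testing(hf_model, bridge_model=None)
#     # With bridge_model=None only the "blocks.0.attn" template receives the
#     # rotary_emb reference; pass a TransformerBridge to update every layer's attn.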