Coverage for transformer_lens/model_bridge/supported_architectures/phi.py: 41%

35 statements  


1"""Phi architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 EmbeddingBridge, 

12 LinearBridge, 

13 MLPBridge, 

14 NormalizationBridge, 

15 ParallelBlockBridge, 

16 PositionEmbeddingsAttentionBridge, 

17 RotaryEmbeddingBridge, 

18 UnembeddingBridge, 

19) 

20 

21 

22class PhiArchitectureAdapter(ArchitectureAdapter): 

23 """Architecture adapter for Phi models.""" 

24 

25 default_cfg = {"use_fast": False} 

26 

27 def __init__(self, cfg: Any) -> None: 

28 """Initialize the Phi architecture adapter. 

29 

30 Args: 

31 cfg: The configuration object. 

32 """ 

33 super().__init__(cfg) 

34 

35 # Set config variables for weight processing 

36 self.cfg.normalization_type = "LN" 

37 self.cfg.positional_embedding_type = "rotary" 

38 self.cfg.final_rms = False 

39 self.cfg.gated_mlp = False 

40 self.cfg.attn_only = False 

41 self.cfg.parallel_attn_mlp = True 

42 

43 self.cfg.default_prepend_bos = False 
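        # Note: Phi decoder layers apply attention and the MLP to the same LayerNormed
        # input and sum both outputs into the residual stream, hence parallel_attn_mlp
        # and the ParallelBlockBridge used in the component mapping below.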

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
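        # The rearrange patterns above split the fused projection weights into per-head
        # form, assuming the standard Hugging Face (out_features, in_features) layout:
        # "(n h) m -> n m h" reads a (n_heads * d_head, d_model) q/k/v weight as
        # (n_heads, d_model, d_head), "(n h) -> n h" does the same for the biases, and
        # "m (n h) -> n h m" splits the output projection along its input dimension.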

        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": ParallelBlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="dense"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="fc1"),
                            "out": LinearBridge(name="fc2"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="model.final_layernorm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
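        # Example: under this mapping the bridge path "blocks.0.attn.q" should correspond
        # to the Hugging Face submodule "model.layers.0.self_attn.q_proj", and "ln_final"
        # to "model.final_layernorm".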

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Phi component testing.

        Phi uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Phi model instance.
            bridge_model: The TransformerBridge model (if provided, rotary_emb is set on its attention bridge instances).
        """
        # Get rotary embedding instance from the model
        # Phi models have rotary_emb at model.model.rotary_emb
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "rotary_emb"):
            rotary_emb = hf_model.model.rotary_emb
        else:
            # Fallback: try to get from first layer
            if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
                if len(hf_model.model.layers) > 0:
                    first_layer = hf_model.model.layers[0]
                    if hasattr(first_layer, "self_attn") and hasattr(
                        first_layer.self_attn, "rotary_emb"
                    ):
                        rotary_emb = first_layer.self_attn.rotary_emb
                    else:
                        return  # Can't find rotary_emb
                else:
                    return
            else:
                return

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
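
# Illustrative sketch (not part of this module): the effect of the "(n h) m -> n m h"
# conversion used above on a dummy weight. The sizes below (4 heads, d_head=8,
# d_model=16) are made-up example values; requires torch and einops.
if __name__ == "__main__":
    import einops
    import torch

    n_heads, d_head, d_model = 4, 8, 16
    w_q = torch.randn(n_heads * d_head, d_model)  # HF q_proj.weight layout: (out, in)
    w_q_split = einops.rearrange(w_q, "(n h) m -> n m h", n=n_heads)
    print(w_q_split.shape)  # torch.Size([4, 16, 8]) == (n_heads, d_model, d_head)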