Coverage for transformer_lens/model_bridge/supported_architectures/phi.py: 41%
35 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Phi architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 EmbeddingBridge,
12 LinearBridge,
13 MLPBridge,
14 NormalizationBridge,
15 ParallelBlockBridge,
16 PositionEmbeddingsAttentionBridge,
17 RotaryEmbeddingBridge,
18 UnembeddingBridge,
19)
22class PhiArchitectureAdapter(ArchitectureAdapter):
23 """Architecture adapter for Phi models."""
25 default_cfg = {"use_fast": False}
27 def __init__(self, cfg: Any) -> None:
28 """Initialize the Phi architecture adapter.
30 Args:
31 cfg: The configuration object.
32 """
33 super().__init__(cfg)
35 # Set config variables for weight processing
36 self.cfg.normalization_type = "LN"
37 self.cfg.positional_embedding_type = "rotary"
38 self.cfg.final_rms = False
39 self.cfg.gated_mlp = False
40 self.cfg.attn_only = False
41 self.cfg.parallel_attn_mlp = True
43 self.cfg.default_prepend_bos = False
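
        # The conversions below split the fused head dimension of the HuggingFace
        # Phi projection parameters into TransformerLens's per-head layout:
        #   q/k/v weights: (n_heads * d_head, d_model) -> (n_heads, d_model, d_head)
        #   q/k/v biases:  (n_heads * d_head,)         -> (n_heads, d_head)
        #   o weight:      (d_model, n_heads * d_head) -> (n_heads, d_head, d_model)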
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
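
        # Phi decoder layers run attention and the MLP in parallel off the same
        # input LayerNorm (hence parallel_attn_mlp above), so each layer is
        # wrapped in a ParallelBlockBridge rather than a sequential block.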

        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": ParallelBlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="dense"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="fc1"),
                            "out": LinearBridge(name="fc2"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="model.final_layernorm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Phi component testing.

        Phi uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Phi model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get rotary embedding instance from the model
        # Phi models have rotary_emb at model.model.rotary_emb
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "rotary_emb"):
            rotary_emb = hf_model.model.rotary_emb
        else:
            # Fallback: try to get from first layer
            if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
                if len(hf_model.model.layers) > 0:
                    first_layer = hf_model.model.layers[0]
                    if hasattr(first_layer, "self_attn") and hasattr(
                        first_layer.self_attn, "rotary_emb"
                    ):
                        rotary_emb = first_layer.self_attn.rotary_emb
                    else:
                        return  # Can't find rotary_emb
                else:
                    return
            else:
                return

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
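
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the adapter above: what the
# "(n h) m -> n m h" weight conversion corresponds to in plain einops,
# assuming a HuggingFace-style projection weight of shape
# (n_heads * d_head, d_model). The sizes below are hypothetical,
# Phi-2-like values.
#
#     import einops
#     import torch
#
#     n_heads, d_head, d_model = 32, 80, 2560
#     w_hf = torch.randn(n_heads * d_head, d_model)
#     w_tl = einops.rearrange(w_hf, "(n h) m -> n m h", n=n_heads)
#     assert w_tl.shape == (n_heads, d_model, d_head)
# ---------------------------------------------------------------------------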