Coverage for transformer_lens/model_bridge/supported_architectures/opt.py: 81%
26 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""OPT architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 BlockBridge,
13 EmbeddingBridge,
14 LinearBridge,
15 NormalizationBridge,
16 PosEmbedBridge,
17 SymbolicBridge,
18 UnembeddingBridge,
19)


class OptArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for OPT models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the OPT architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # OPT models were trained with BOS tokens (inherits default_prepend_bos = True)

        # Post-norm model: disable fold_ln and center_writing_weights (these only apply to pre-norm models).
        is_post_norm = not getattr(self.cfg, "do_layer_norm_before", True)
        if is_post_norm:  # coverage: branch never taken (condition was never true in the measured run)
            self.supports_fold_ln = False
            self.supports_center_writing_weights = False

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
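        # Illustrative note (added for clarity; the concrete numbers are an
        # assumption, not taken from this file): for a config like OPT-125m
        # with n_heads=12, d_model=768, d_head=64, q_proj.weight has shape
        # [(n h), m] = [768, 768], and "(n h) m -> n m h" rearranges it to
        # [12, 768, 64], i.e. one [d_model, d_head] matrix per head. The
        # out_proj pattern "m (n h) -> n h m" instead yields [12, 64, 768],
        # the per-head [d_head, d_model] layout.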

        # OPT-350m is the only OPT size where word_embed_proj_dim (512)
        # differs from hidden_size (1024). It uses project_in/project_out
        # linear layers instead of a final_layer_norm. Detect this and
        # include ln_final only when the model actually has one.
        word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", self.cfg.d_model)
        has_final_layer_norm = word_embed_proj_dim == self.cfg.d_model

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.decoder.embed_tokens"),
            "pos_embed": PosEmbedBridge(name="model.decoder.embed_positions"),
            "blocks": BlockBridge(
                name="model.decoder.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="self_attn_layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": AttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        requires_attention_mask=True,  # OPT requires attention_mask
                        attention_mask_4d=True,  # OPT expects 4D mask [batch, 1, tgt_len, src_len]
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="out_proj"),
                        },
                    ),
                    "ln2": NormalizationBridge(
                        name="final_layer_norm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    # OPT has fc1/fc2 directly on the block, not in an MLP container.
                    # Use SymbolicBridge to maintain TransformerLens structure while
                    # correctly mapping to the underlying architecture.
                    "mlp": SymbolicBridge(
                        submodules={
                            "in": LinearBridge(name="fc1"),
                            "out": LinearBridge(name="fc2"),
                        },
                    ),
                },
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
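        # Illustrative note (an assumption about how the bridge resolves these
        # names, not something stated in this file): with this mapping, a
        # TransformerLens-style path such as blocks.0.mlp.in is expected to
        # point at the HF module model.decoder.layers.0.fc1, and
        # blocks.0.attn.q at model.decoder.layers.0.self_attn.q_proj.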
        if has_final_layer_norm:  # coverage: branch always taken (condition was always true in the measured run)
            self.component_mapping["ln_final"] = NormalizationBridge(
                name="model.decoder.final_layer_norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            )
        # project_in/project_out bridge word_embed_proj_dim <-> hidden_size.
        if not has_final_layer_norm:  # coverage: branch never taken (condition was never true in the measured run)
            self.component_mapping["project_in"] = LinearBridge(name="model.decoder.project_in")
            self.component_mapping["project_out"] = LinearBridge(name="model.decoder.project_out")
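

# Usage sketch (added for illustration, not part of the original module): the
# bridge machinery normally constructs this adapter internally, but given a
# TransformerLens-style config object `cfg` (with n_heads, d_model, etc.) it
# could be exercised directly, e.g.:
#
#     adapter = OptArchitectureAdapter(cfg)
#     print(sorted(adapter.component_mapping))  # blocks, embed, pos_embed, ...
#
# Exactly how `cfg` is built is an assumption here and depends on the
# TransformerLens config utilities in use.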