Coverage for transformer_lens/utilities/tracr.py: 88%
60 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Utilities for loading Tracr-assembled models into TransformerBridge.
3These helpers are intentionally duck-typed so importing TransformerLens does not
4pull in Tracr, JAX, or Haiku. Pass in a Tracr ``AssembledTransformerModel`` from
5``tracr.compiler.compiling.compile_rasp_to_model``.
6"""
8from __future__ import annotations
10from collections.abc import Mapping
11from typing import Any
13import numpy as np
14import torch
16from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
def infer_tracr_output_label(model: Any) -> str:
    """Work out which residual-label prefix carries the categorical output.

    Every entry of ``model.residual_labels`` is split at its last ``":"`` into
    a prefix and a value (for example ``"reverse:3"`` becomes ``"reverse"`` and
    ``"3"``); entries that contain no ``":"`` are ignored. Values are grouped
    per prefix, and the prefix whose value set equals the stringified keys of
    ``model.output_encoder.encoding_map`` is the output label.

    Args:
        model: Duck-typed object exposing ``output_encoder.encoding_map`` and
            ``residual_labels``.

    Returns:
        The single prefix whose value set matches the output encoder's keys.

    Raises:
        ValueError: If zero prefixes or more than one prefix match.
    """
    expected_values = {str(key) for key in _categorical_output_encoding_map(model)}

    grouped: dict[str, set[str]] = {}
    for entry in _residual_labels(model):
        prefix, colon, suffix = entry.rpartition(":")
        if colon:
            # Only "prefix:value" entries can belong to a categorical basis.
            bucket = grouped.get(prefix)
            if bucket is None:
                bucket = grouped[prefix] = set()
            bucket.add(suffix)

    candidates = [
        prefix for prefix, seen_values in grouped.items() if seen_values == expected_values
    ]
    if len(candidates) == 1:
        return candidates[0]

    # Either nothing matched or the match is ambiguous; the caller must choose.
    raise ValueError(
        "Could not infer Tracr output label from residual labels. Pass "
        f"output_label explicitly. Candidates: {candidates!r}."
    )
def make_tracr_categorical_unembed(
    model: Any,
    output_label: str | None = None,
    *,
    dtype: np.dtype = np.dtype("float32"),
) -> np.ndarray:
    """Return Tracr's categorical unembed matrix in TL format.

    The returned matrix has one row per entry of ``model.residual_labels`` and
    one column per entry of ``model.output_encoder.encoding_map`` (i.e. shape
    ``[d_model, d_vocab_out]``). For every ``value -> column`` pair in the
    encoding map, the row of the residual label ``f"{output_label}:{value}"``
    gets a ``1`` in that column; everything else is zero. This avoids assuming
    the output coordinates are the first residual dimensions.

    Args:
        model: Duck-typed object exposing ``output_encoder.encoding_map`` and
            ``residual_labels``.
        output_label: Residual-label prefix of the output basis. When ``None``
            it is inferred with ``infer_tracr_output_label``.
        dtype: NumPy dtype of the returned matrix.

    Returns:
        A ``numpy.ndarray`` of zeros and ones with the shape described above.

    Raises:
        ValueError: If an expected ``"<output_label>:<value>"`` label is not
            present in ``model.residual_labels``.
    """
    encoding_map = _categorical_output_encoding_map(model)
    if output_label is None:
        output_label = infer_tracr_output_label(model)

    residual_labels = _residual_labels(model)
    label_to_residual_index = {label: index for index, label in enumerate(residual_labels)}

    # Size the rows from the label list, not from the lookup dict: the dict
    # collapses repeated labels and would under-count the residual dimensions.
    unembed = np.zeros((len(residual_labels), len(encoding_map)), dtype=dtype)

    for output_value, output_index in encoding_map.items():
        residual_label = f"{output_label}:{output_value}"
        residual_index = label_to_residual_index.get(residual_label)
        if residual_index is None:
            raise ValueError(
                f"Could not find output basis label {residual_label!r} in "
                "model.residual_labels. Pass the final RASP expression label "
                "as output_label."
            )
        unembed[residual_index, int(output_index)] = 1

    return unembed
def make_tracr_transformer_bridge_config(model: Any) -> TransformerBridgeConfig:
    """Build a ``TransformerBridgeConfig`` matching a categorical Tracr model.

    Sizes are read from the model's parameters and ``model.model_config``:

    * ``d_vocab`` / ``d_model``: first / second axis of the ``token_embed``
      ``embeddings`` array.
    * ``n_ctx``: first axis of the ``pos_embed`` ``embeddings`` array.
    * ``d_vocab_out``: number of entries in the categorical output encoding map.
    * ``n_layers``, ``d_head``, ``d_mlp``, ``n_heads``: ``num_layers``,
      ``key_size``, ``mlp_hidden_size`` and ``num_heads`` of ``model_config``.

    The activation is fixed to ``"relu"``; attention is ``"causal"`` when
    ``model_config.causal`` is truthy and ``"bidirectional"`` otherwise;
    normalization is ``"LN"`` when ``model_config.layer_norm`` is truthy and
    ``None`` otherwise.

    Args:
        model: Duck-typed object exposing ``model_config``, ``params`` and
            ``output_encoder.encoding_map``.

    Returns:
        The populated ``TransformerBridgeConfig``.
    """
    model_config = model.model_config

    # Convert the token embedding once and read both dimensions from it
    # instead of materialising the same array twice.
    token_embed_shape = _param_array(model, "token_embed", "embeddings").shape
    d_model = token_embed_shape[1]
    d_vocab = token_embed_shape[0]
    n_ctx = _param_array(model, "pos_embed", "embeddings").shape[0]

    return TransformerBridgeConfig(
        n_layers=model_config.num_layers,
        d_model=d_model,
        d_head=model_config.key_size,
        n_ctx=n_ctx,
        d_vocab=d_vocab,
        d_vocab_out=len(_categorical_output_encoding_map(model)),
        d_mlp=model_config.mlp_hidden_size,
        n_heads=model_config.num_heads,
        act_fn="relu",
        attention_dir="causal" if model_config.causal else "bidirectional",
        normalization_type="LN" if model_config.layer_norm else None,
    )
def make_tracr_transformer_bridge_state_dict(
    model: Any,
    output_label: str | None = None,
    *,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Build a ``TransformerBridge.boot_native`` state dict from Tracr weights.

    The state-dict keys use the native PyTorch module names accepted by
    ``TransformerBridge.load_state_dict`` after ``boot_native``. Embedding
    tables are copied as they are, ``head.weight`` is the transposed
    categorical unembed matrix, and every per-layer ``w`` parameter is
    transposed while every ``b`` parameter is copied unchanged.

    Args:
        model: Duck-typed object exposing ``model_config``, ``params``,
            ``residual_labels`` and ``output_encoder.encoding_map``.
        output_label: Forwarded to ``make_tracr_categorical_unembed``.
        dtype: Torch dtype of every tensor in the result.

    Returns:
        A mapping from state-dict key to tensor.

    Raises:
        NotImplementedError: If ``model.model_config.layer_norm`` is truthy.
    """
    model_config = model.model_config
    if model_config.layer_norm:
        raise NotImplementedError(
            "Tracr layer_norm=True models are not supported by this converter yet."
        )

    state_dict: dict[str, torch.Tensor] = {}
    state_dict["tok_embed.weight"] = _tensor(model, "token_embed", "embeddings", dtype=dtype)
    state_dict["pos.weight"] = _tensor(model, "pos_embed", "embeddings", dtype=dtype)
    unembed = make_tracr_categorical_unembed(model, output_label)
    state_dict["head.weight"] = torch.tensor(unembed.T, dtype=dtype)

    # (state-dict sub-module, Tracr parameter sub-path) for each layer, in the
    # order the keys are inserted.
    submodules = (
        ("attn.k", "attn/key"),
        ("attn.q", "attn/query"),
        ("attn.v", "attn/value"),
        ("attn.o", "attn/linear"),
        ("mlp.fc_in", "mlp/linear_1"),
        ("mlp.fc_out", "mlp/linear_2"),
    )
    for layer in range(model_config.num_layers):
        for target, source in submodules:
            module_name = f"transformer/layer_{layer}/{source}"
            weight = _tensor(model, module_name, "w", dtype=dtype)
            state_dict[f"layers.{layer}.{target}.weight"] = weight.T
            state_dict[f"layers.{layer}.{target}.bias"] = _tensor(
                model, module_name, "b", dtype=dtype
            )

    return state_dict
def _categorical_output_encoding_map(model: Any) -> Mapping[Any, int]:
    """Return ``model.output_encoder.encoding_map`` if it is a mapping.

    Both attribute lookups fall back to ``None`` when the attribute is absent,
    so a model without an output encoder, or an encoder without an
    ``encoding_map``, is rejected the same way as a non-mapping value.

    Raises:
        NotImplementedError: If the looked-up value is not a ``Mapping``.
    """
    encoder = getattr(model, "output_encoder", None)
    candidate = getattr(encoder, "encoding_map", None)
    if isinstance(candidate, Mapping):
        return candidate
    raise NotImplementedError(
        "Only categorical Tracr outputs are supported; expected "
        "model.output_encoder.encoding_map."
    )
def _residual_labels(model: Any) -> list[str]:
    """Return a fresh list copy of ``model.residual_labels``.

    Raises:
        ValueError: If the attribute is missing or is ``None``.
    """
    labels = getattr(model, "residual_labels", None)
    if labels is not None:
        return [*labels]
    raise ValueError("Expected Tracr model to expose residual_labels.")
def _param_array(model: Any, module_name: str, param_name: str) -> np.ndarray:
    """Look up ``model.params[module_name][param_name]`` as a NumPy array."""
    module_params = model.params[module_name]
    return np.asarray(module_params[param_name])
def _tensor(
    model: Any,
    module_name: str,
    param_name: str,
    *,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Fetch one Tracr parameter and copy it into a torch tensor of ``dtype``."""
    array = _param_array(model, module_name, param_name)
    return torch.tensor(array, dtype=dtype)