transformer_lens.utilities.tracr module
Utilities for loading Tracr-assembled models into TransformerBridge.
These helpers are intentionally duck-typed so importing TransformerLens does not
pull in Tracr, JAX, or Haiku. Pass in the Tracr AssembledTransformerModel returned by
tracr.compiler.compiling.compile_rasp_to_model.
- transformer_lens.utilities.tracr.infer_tracr_output_label(model: Any) → str
Infer the RASP output label used in model.residual_labels.
Tracr stores residual basis labels as strings like "reverse:3", while its categorical output encoder stores only the values and their output-column ids. The output label is the unique residual-label prefix whose value set exactly matches the output encoder’s categorical value set.
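A minimal sketch of the prefix-matching rule described above, run against a stand-in object rather than a real Tracr model. It assumes the model exposes residual_labels as a list of "name:value" strings and output_encoder.encoding_map as a dict from output value to column id. The function name infer_output_label_sketch and the stub model are hypothetical; this is not the library's implementation.

```python
from types import SimpleNamespace
from typing import Any


def infer_output_label_sketch(model: Any) -> str:
    """Hypothetical re-implementation of the prefix-matching rule (not library code)."""
    # Values the categorical output encoder can emit, compared as strings
    # because residual labels are strings such as "reverse:3".
    output_values = {str(v) for v in model.output_encoder.encoding_map}

    # Group residual labels by prefix: "reverse:3" -> prefix "reverse", value "3".
    values_by_prefix: dict[str, set[str]] = {}
    for label in model.residual_labels:
        prefix, sep, value = label.rpartition(":")
        if not sep:
            continue  # labels without ":" (e.g. "one") carry no value
        values_by_prefix.setdefault(prefix, set()).add(value)

    # Only a prefix whose value set equals the encoder's value set qualifies.
    matches = [p for p, vals in values_by_prefix.items() if vals == output_values]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one matching label, found {matches}")
    return matches[0]


# Hypothetical stand-in for a compiled "reverse" program over the vocabulary {a, b}.
fake_model = SimpleNamespace(
    residual_labels=["one", "tokens:BOS", "tokens:a", "tokens:b",
                     "indices:0", "indices:1", "reverse:a", "reverse:b"],
    output_encoder=SimpleNamespace(encoding_map={"a": 0, "b": 1}),
)
print(infer_output_label_sketch(fake_model))  # reverse
```

Here "tokens" is rejected because its value set also contains "BOS", and "indices" because its values are positions. Requiring a unique exact match, rather than picking the first subset match, is what makes the inference safe to fail loudly when two RASP variables share a value set.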
- transformer_lens.utilities.tracr.make_tracr_categorical_unembed(model: Any, output_label: str | None = None, *, dtype: dtype = dtype('float32')) → ndarray
Return Tracr’s categorical unembed matrix in TransformerLens (TL) format.
The returned matrix has shape [d_model, d_vocab_out] and projects the residual stream onto the output-basis coordinates selected by Tracr’s categorical output encoder. This avoids assuming that those coordinates occupy the first residual dimensions.
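The projection can be sketched as a one-hot matrix built by label lookup. This is a minimal NumPy sketch, assuming the same duck-typed attributes as above (residual_labels and output_encoder.encoding_map mapping value to output column); the function name categorical_unembed_sketch and the stub model are hypothetical.

```python
from types import SimpleNamespace
from typing import Any

import numpy as np


def categorical_unembed_sketch(model: Any, output_label: str,
                               dtype=np.float32) -> np.ndarray:
    """Hypothetical sketch: one-hot map from residual rows to output columns."""
    labels = list(model.residual_labels)
    encoding_map = model.output_encoder.encoding_map  # value -> output column id
    d_model, d_vocab_out = len(labels), len(encoding_map)

    w_u = np.zeros((d_model, d_vocab_out), dtype=dtype)
    for value, column in encoding_map.items():
        # Find the residual row by its label instead of assuming the
        # output basis sits at the start of the residual stream.
        row = labels.index(f"{output_label}:{value}")
        w_u[row, column] = 1.0
    return w_u


fake_model = SimpleNamespace(
    residual_labels=["one", "tokens:a", "tokens:b", "reverse:a", "reverse:b"],
    output_encoder=SimpleNamespace(encoding_map={"a": 0, "b": 1}),
)
w_u = categorical_unembed_sketch(fake_model, "reverse")
print(w_u.shape)  # (5, 2)

# Residual vector with "one", "tokens:b" and "reverse:b" active.
resid = np.array([1, 0, 1, 0, 1], dtype=np.float32)
print(resid @ w_u)  # [0. 1.]
```

Only the "reverse:*" rows contribute to the logits, so the active "tokens:b" coordinate is ignored even though it shares the value "b".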
- transformer_lens.utilities.tracr.make_tracr_transformer_bridge_config(model: Any) → TransformerBridgeConfig
Build a TransformerBridgeConfig matching a categorical Tracr model.
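Most of a categorical config is sizes that can be read off the assembled model. The sketch below gathers them into a plain dict; it does not construct a TransformerBridgeConfig. It assumes the model exposes params with Haiku-style keys such as "transformer/layer_0/attn/query", a model_config with num_heads, key_size and mlp_hidden_size, and input/output encoders with an encoding_map. The function name tracr_dimensions_sketch and the dict keys are hypothetical.

```python
from types import SimpleNamespace
from typing import Any


def tracr_dimensions_sketch(model: Any) -> dict[str, int]:
    """Hypothetical helper: collect the sizes a categorical Tracr config needs."""
    # Assumed param-key layout: "transformer/layer_{i}/attn/..." and ".../mlp/...",
    # where attention and MLP sublayers of block i share the index i.
    layer_ids = {
        int(key.split("/")[1].removeprefix("layer_"))
        for key in model.params
        if key.startswith("transformer/layer_")
    }
    return {
        "n_layers": max(layer_ids, default=-1) + 1,
        "d_model": len(model.residual_labels),  # one residual dim per basis label
        "n_heads": model.model_config.num_heads,
        "d_head": model.model_config.key_size,
        "d_mlp": model.model_config.mlp_hidden_size,
        "d_vocab": len(model.input_encoder.encoding_map),
        "d_vocab_out": len(model.output_encoder.encoding_map),
    }


# Hypothetical stub with two blocks (the second has no MLP weights).
fake_model = SimpleNamespace(
    params={
        "token_embed": {}, "pos_embed": {},
        "transformer/layer_0/attn/query": {},
        "transformer/layer_0/mlp/linear_1": {},
        "transformer/layer_1/attn/query": {},
    },
    residual_labels=["one", "tokens:a", "tokens:b", "reverse:a", "reverse:b"],
    model_config=SimpleNamespace(num_heads=1, key_size=4, mlp_hidden_size=8),
    input_encoder=SimpleNamespace(encoding_map={"BOS": 0, "a": 1, "b": 2}),
    output_encoder=SimpleNamespace(encoding_map={"a": 0, "b": 1}),
)
dims = tracr_dimensions_sketch(fake_model)
print(dims["n_layers"], dims["d_model"])  # 2 5
```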
- transformer_lens.utilities.tracr.make_tracr_transformer_bridge_state_dict(model: Any, output_label: str | None = None, *, dtype: dtype = torch.float32) → dict[str, Tensor]
Build a TransformerBridge.boot_native state dict from Tracr weights.
The state-dict keys use the native PyTorch module names accepted by TransformerBridge.load_state_dict after boot_native.
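A recurring step in converting Tracr's JAX/Haiku weights is regrouping fused attention projections. A minimal NumPy sketch, assuming a fused weight of shape [d_model, n_heads * d_head] with heads stored contiguously along the last axis. The [n_heads, d_model, d_head] target is only one possible per-head layout, chosen here for illustration; the layout the native TransformerBridge modules expect is not described above. The function name split_fused_heads is hypothetical.

```python
import numpy as np


def split_fused_heads(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Reshape [d_model, n_heads * d_head] -> [n_heads, d_model, d_head]."""
    d_model, fused = w.shape
    if fused % n_heads:
        raise ValueError("fused dimension is not divisible by n_heads")
    d_head = fused // n_heads
    # Split the last axis into (n_heads, d_head), then move heads to the front.
    return w.reshape(d_model, n_heads, d_head).transpose(1, 0, 2)


# d_model = 2, three heads of size 2, fused along the last axis.
w_q = np.arange(2 * 6, dtype=np.float32).reshape(2, 6)
per_head = split_fused_heads(w_q, n_heads=3)
print(per_head.shape)  # (3, 2, 2)
```

Head h of the result is exactly columns h * d_head through (h + 1) * d_head of the fused matrix. A NumPy result like this can then be wrapped with torch.from_numpy and cast to the requested dtype before being placed in a state dict.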