transformer_lens.utilities.tracr module

Utilities for loading Tracr-assembled models into TransformerBridge.

These helpers are intentionally duck-typed, so importing TransformerLens does not pull in Tracr, JAX, or Haiku. Each helper expects a Tracr AssembledTransformerModel, such as the one returned by tracr.compiler.compiling.compile_rasp_to_model.
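A minimal sketch of the duck-typing idea, for illustration only: a function annotated with Any reads an attribute at call time instead of importing Tracr, so any object exposing that attribute works, including a lightweight stand-in. The function count_residual_dims and the stand-in object are hypothetical and are not part of transformer_lens.utilities.tracr.

```python
from types import SimpleNamespace
from typing import Any


def count_residual_dims(model: Any) -> int:
    # Hypothetical helper: no `import tracr` needed. Any object with a
    # `residual_labels` attribute is accepted at runtime.
    return len(model.residual_labels)


# A stand-in object instead of a real Tracr AssembledTransformerModel.
fake_model = SimpleNamespace(residual_labels=["one", "tokens:a", "reverse:a"])
print(count_residual_dims(fake_model))
```

The same pattern is what lets test code exercise such helpers without installing JAX or Haiku.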

transformer_lens.utilities.tracr.infer_tracr_output_label(model: Any) → str

Infer the RASP output label used in model.residual_labels.

Tracr stores residual-basis labels as strings like "reverse:3", while its categorical output encoder stores only the output values and their output-column ids. The output label is therefore the unique residual-label prefix whose set of values exactly matches the output encoder’s categorical value set.
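The matching rule can be sketched in plain Python as follows. This is a hypothetical re-implementation for illustration, not the library's code: it takes the residual labels and the output encoder's value set as plain arguments, assumes labels have the form "name:value" (labels without a colon, such as "one", are skipped), and assumes values are compared as strings. The "BOS" token value in the example is illustrative.

```python
from collections import defaultdict


def infer_output_label_sketch(residual_labels: list[str], output_values: set) -> str:
    """Hypothetical sketch: return the unique label prefix whose values match."""
    values_by_prefix: dict[str, set[str]] = defaultdict(set)
    for label in residual_labels:
        if ":" not in label:
            continue  # e.g. "one" carries no value component
        prefix, value = label.split(":", 1)
        values_by_prefix[prefix].add(value)

    # Assumption: residual labels are strings, so compare string forms.
    target = {str(v) for v in output_values}
    matches = [p for p, vals in values_by_prefix.items() if vals == target]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one matching prefix, found {matches}")
    return matches[0]


labels = [
    "one", "tokens:BOS", "tokens:a", "tokens:b",
    "indices:0", "indices:1", "length:1", "length:2",
    "reverse:a", "reverse:b",
]
print(infer_output_label_sketch(labels, {"a", "b"}))
```

Requiring an exact set match (rather than a subset) is what separates "reverse" from "tokens" here, since the token prefix also carries the extra BOS value.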

transformer_lens.utilities.tracr.make_tracr_categorical_unembed(model: Any, output_label: str | None = None, *, dtype: dtype = dtype('float32')) → ndarray

Return Tracr’s categorical unembed matrix in TL format.

The returned matrix has shape [d_model, d_vocab_out] and projects the residual stream onto the output-basis coordinates selected by Tracr’s categorical output encoder. This avoids assuming those coordinates are the first residual dimensions.
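A minimal NumPy sketch of such a projection, under stated assumptions: the output encoder's mapping from output value to output-column id is passed in as a plain dict (encoding_map), column ids run from 0 to d_vocab_out - 1, and each output value v lives at the residual coordinate labelled "<output_label>:v". The function name is hypothetical. Each row index is looked up by label, which is how the matrix avoids assuming the output block starts at residual index 0.

```python
import numpy as np


def make_unembed_sketch(
    residual_labels: list[str],
    output_label: str,
    encoding_map: dict,  # assumed: output value -> output-column id (0..n-1)
    dtype=np.float32,
) -> np.ndarray:
    """Hypothetical sketch: a 0/1 matrix of shape [d_model, d_vocab_out]."""
    d_model = len(residual_labels)
    d_vocab_out = len(encoding_map)
    W_U = np.zeros((d_model, d_vocab_out), dtype=dtype)
    row_of = {label: i for i, label in enumerate(residual_labels)}
    for value, col in encoding_map.items():
        # Find the residual coordinate by its label, wherever it sits.
        W_U[row_of[f"{output_label}:{value}"], col] = 1.0
    return W_U


labels = ["one", "tokens:BOS", "tokens:a", "tokens:b",
          "indices:0", "indices:1", "length:1", "length:2",
          "reverse:a", "reverse:b"]
W_U = make_unembed_sketch(labels, "reverse", {"a": 0, "b": 1})

# A residual vector that is one-hot on "reverse:b" decodes to column 1.
resid = np.zeros(len(labels), dtype=np.float32)
resid[labels.index("reverse:b")] = 1.0
print(int((resid @ W_U).argmax()))
```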

transformer_lens.utilities.tracr.make_tracr_transformer_bridge_config(model: Any) → TransformerBridgeConfig

Build a TransformerBridgeConfig matching a categorical Tracr model.

transformer_lens.utilities.tracr.make_tracr_transformer_bridge_state_dict(model: Any, output_label: str | None = None, *, dtype: dtype = torch.float32) → dict[str, Tensor]

Build a TransformerBridge.boot_native state dict from Tracr weights.

The state-dict keys use the native PyTorch module names accepted by TransformerBridge.load_state_dict after boot_native.
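One step such a conversion typically involves is a layout change: Haiku's hk.Linear stores its weight "w" as [in_features, out_features], while PyTorch's nn.Linear stores "weight" as [out_features, in_features]. The sketch below shows only that transpose-and-cast step for a single dense layer. It is a hypothetical helper, and the key prefix "proj" is illustrative; it does not reproduce the actual parameter names or key names used by make_tracr_transformer_bridge_state_dict or boot_native.

```python
import numpy as np
import torch


def linear_params_to_torch(
    prefix: str, hk_params: dict, dtype: torch.dtype = torch.float32
) -> dict[str, torch.Tensor]:
    """Hypothetical sketch: Haiku-style {"w", "b"} -> nn.Linear state-dict entries."""
    w = np.ascontiguousarray(np.asarray(hk_params["w"]).T)  # [in, out] -> [out, in]
    out = {f"{prefix}.weight": torch.as_tensor(w, dtype=dtype)}
    if "b" in hk_params:  # some layers may be built without a bias
        out[f"{prefix}.bias"] = torch.as_tensor(np.asarray(hk_params["b"]), dtype=dtype)
    return out


w = np.arange(6, dtype=np.float32).reshape(3, 2)  # Haiku layout: [in=3, out=2]
b = np.array([0.5, -1.0], dtype=np.float32)
sd = linear_params_to_torch("proj", {"w": w, "b": b})

module = torch.nn.ModuleDict({"proj": torch.nn.Linear(3, 2)})
module.load_state_dict(sd)  # key names match ModuleDict's "proj.weight"/"proj.bias"
```

Because the transpose is applied once at conversion time, the loaded module computes x @ w + b exactly as the Haiku layer would.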