transformer_lens.tools.analysis package

Submodules

Module contents

Analysis tools for TransformerLens.

This subpackage collects high-level, single-call interpretability analyses that sit on top of the hook/cache system. They work with both HookedTransformer and the newer TransformerBridge (the two share the ActivationCache API).

Tools:
  • direct_logit_attribution: Direct Logit Attribution (DLA) over components, layers, or attention heads.

  • direct_path_patching: Direct path patching for head-to-head circuit analysis.

class transformer_lens.tools.analysis.DirectLogitAttribution(attribution: Float[Tensor, 'component *batch_and_pos'], labels: List[str], unit: str)

Bases: object

Result of a direct_logit_attribution() call.

attribution

Tensor of logit (or logit-difference) attributions with shape [component, *batch_and_pos]. The leading axis is aligned with labels. When pos selects a single position (the default) the position axis is dropped, leaving [component, batch] — or [component] if the cache had its batch dimension removed.

Type:

jaxtyping.Float[Tensor, 'component *batch_and_pos']

labels

Human-readable name for each component, aligned with the leading axis of attribution (e.g. "embed", "0_attn_out", "L3H7").

Type:

List[str]

unit

The decomposition unit used ("component", "layer", or "head").

Type:

str

attribution: Float[Tensor, 'component *batch_and_pos']
labels: List[str]
top(k: int = 5) → List[tuple]

Return the k highest-attribution (label, value) pairs.

Attribution is reduced to a scalar per component by averaging over any remaining batch/position dimensions, so the ranking is most meaningful when a single position was selected.
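The reduction described above can be sketched in plain Python. This is an illustration of "average each component, then rank", not the library's implementation; the helper names flatten and top_k_by_mean and the toy numbers are assumptions made for this example.

```python
from statistics import fmean
from typing import Any, List, Tuple


def flatten(values: Any) -> List[float]:
    """Flatten arbitrarily nested lists of numbers into one flat list."""
    if isinstance(values, (int, float)):
        return [float(values)]
    flat: List[float] = []
    for item in values:
        flat.extend(flatten(item))
    return flat


def top_k_by_mean(attribution: list, labels: List[str], k: int = 5) -> List[Tuple[str, float]]:
    """Average each component's values, then return the k largest (label, mean) pairs."""
    means = [fmean(flatten(row)) for row in attribution]
    ranked = sorted(zip(labels, means), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy attribution of shape [component=3, batch=2]; the numbers are made up.
labels = ["embed", "0_attn_out", "0_mlp_out"]
attribution = [[0.5, 1.5], [4.0, 2.0], [-1.0, 1.0]]

result = top_k_by_mean(attribution, labels, k=2)
print(result)
```

Because the mean can hide large values of opposite sign (the last row above averages to zero), inspecting the full attribution tensor is safer when several positions or batch items are kept.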

unit: str
transformer_lens.tools.analysis.direct_logit_attribution(model, input: str | List[str] | Tensor | None = None, answer_tokens: str | int | Tensor | None = None, incorrect_tokens: str | int | Tensor | None = None, *, unit: str = 'component', pos: int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = -1, cache: ActivationCache | None = None) → DirectLogitAttribution

Compute Direct Logit Attribution for a prompt.

Decomposes the contribution of model components to the logit of answer_tokens (or, if incorrect_tokens is given, to the logit difference answer - incorrect along the W_U direction, which is usually what you want for circuit analysis).
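As a minimal sketch of the logit-difference case, each component's output vector can be projected onto the difference of the two unembedding columns. The vectors below are made-up toy values with d_model = 3, LayerNorm is ignored, and the variable names are assumptions for illustration only.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Toy unembedding columns W_U[:, token] for d_model = 3 (made-up numbers).
w_u_answer = [1.0, 0.0, 2.0]
w_u_incorrect = [0.0, 1.0, 1.0]

# Logit-difference direction: answer column minus incorrect column.
direction = [a - b for a, b in zip(w_u_answer, w_u_incorrect)]

# Toy per-component contributions to the final residual stream.
components = {
    "embed": [0.5, 0.5, 0.0],
    "0_attn_out": [1.0, -1.0, 0.5],
    "0_mlp_out": [0.0, 2.0, 1.0],
}

# Attribution of each component to logit[answer] - logit[incorrect].
logit_diff_attr = {name: dot(vec, direction) for name, vec in components.items()}
print(logit_diff_attr)
```

Projecting onto the difference direction gives the same numbers as attributing each token separately and subtracting, because the projection is linear.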

The model is run once with caching unless a precomputed cache is passed. Works with both HookedTransformer and TransformerBridge.

Note that DLA attributes only the part of a logit that comes from the residual stream through the unembedding direction; the unembedding bias b_U is a per-token constant that no component produces. So a complete decomposition reconstructs logit[token] - b_U[token] rather than the raw logit.
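The bias caveat can be checked with a toy calculation. The sketch below assumes the final residual stream is exactly the sum of the listed components and that no LayerNorm is applied before unembedding (or that it has already been folded into W_U); all numbers are invented for illustration.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Toy unembedding column and bias entry for a single token (made-up numbers).
w_u_token = [1.0, 0.0, 2.0]
b_u_token = 0.25

# Toy per-component contributions to the final residual stream, d_model = 3.
components = {
    "embed": [0.5, 0.5, 0.0],
    "0_attn_out": [1.0, -1.0, 0.5],
    "0_mlp_out": [0.0, 2.0, 1.0],
}

# The final residual stream is the element-wise sum of all component outputs.
resid_final = [sum(values) for values in zip(*components.values())]

# Raw logit as the model would compute it: projection plus bias.
raw_logit = dot(resid_final, w_u_token) + b_u_token

# Per-component attributions never see the bias.
attributions = {name: dot(vec, w_u_token) for name, vec in components.items()}
total_attribution = sum(attributions.values())

print(raw_logit, total_attribution, raw_logit - b_u_token)
```

The attributions sum to the raw logit minus the bias entry, so comparing a DLA total against the raw logit should account for b_U[token].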

On a TransformerBridge, compatibility mode must be enabled so that the final LayerNorm is folded into W_U; without it the projection direction would be wrong and DLA would yield silently incorrect numbers. Hybrid architectures (Mamba/SSM/Mixer/LinearAttention) are not yet supported because decompose_resid only understands the attn_out + mlp_out block layout. Both conditions are checked at call time and raise an explicit error.

Parameters:
  • model – A HookedTransformer or TransformerBridge (the latter with enable_compatibility_mode() already called).

  • input – Prompt to run — a string, list of strings, or token tensor. Optional only when a precomputed cache is supplied.

  • answer_tokens – The correct token(s) to attribute, as a string, id, or tensor. A string is converted with model.to_single_token.

  • incorrect_tokens – Optional baseline token(s). When given, attribution is computed for the answer - incorrect residual direction. Must broadcast to the same shape as answer_tokens.

  • unit

    Decomposition granularity:

    • "component" (default): embedding + each layer’s attention and MLP output (via decompose_resid).

    • "layer": cumulative residual stream after each sublayer, i.e. logit-lens trajectory (via accumulated_resid).

    • "head": each attention head individually, plus a remainder term for everything else (via stack_head_results).

  • pos – Sequence position(s) to attribute. Defaults to -1 (the final token, the usual choice for next-token DLA). Pass None to keep every position (the result then has a trailing position axis).

  • cache – Optional precomputed ActivationCache to reuse instead of running the model again.
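To make the difference between the "component" and "layer" units concrete, the sketch below treats the "layer" trajectory as a running sum of per-component attributions. This assumes both views are projected with the same final scaling, which is a simplification; the labels follow the naming shown above and the values are made up.

```python
from itertools import accumulate

# Toy per-component attributions for a 2-layer model (made-up numbers).
component_labels = ["embed", "0_attn_out", "0_mlp_out", "1_attn_out", "1_mlp_out"]
component_attr = [0.5, 2.0, -1.0, 1.5, 0.25]

# Cumulative view: attribution of the residual stream after each sublayer.
cumulative_attr = list(accumulate(component_attr))

# Differencing the cumulative view recovers the per-component view.
recovered = [cumulative_attr[0]] + [
    later - earlier for earlier, later in zip(cumulative_attr, cumulative_attr[1:])
]

for label, value in zip(component_labels, cumulative_attr):
    print(f"after {label}: {value}")
```

The "component" view answers which sublayer wrote the answer direction, while the cumulative view shows when the prediction becomes readable from the residual stream.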

Returns:

A DirectLogitAttribution with attribution (shape [component, *batch_and_pos]) and aligned labels.

Raises:
  • ValueError – If unit is invalid, answer_tokens is None, neither input nor cache is provided, or a TransformerBridge is passed without compatibility mode enabled.

  • NotImplementedError – If a TransformerBridge reports a hybrid block layout (Mamba/SSM/Mixer/LinearAttention).
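A hedged usage sketch follows. The import path, keyword names, and the top() method are taken from the signatures on this page; the choice of "gpt2", the prompt, and the assumption that " Mary" and " John" are single tokens for that model are illustrative and have not been verified here. The imports are placed inside the function so that defining it does not require the package or a model download.

```python
def dla_usage_sketch(model_name: str = "gpt2"):
    """Run head-level DLA on one prompt and return the five top-ranked heads.

    Illustrative only: requires transformer_lens with the tools.analysis
    subpackage documented on this page, and downloads model weights.
    """
    from transformer_lens import HookedTransformer
    from transformer_lens.tools.analysis import direct_logit_attribution

    model = HookedTransformer.from_pretrained(model_name)
    prompt = "When John and Mary went to the shops, John gave the bag to"

    result = direct_logit_attribution(
        model,
        prompt,
        answer_tokens=" Mary",     # assumed to be a single token
        incorrect_tokens=" John",  # baseline for the logit difference
        unit="head",
        # pos defaults to -1, i.e. the final token
    )
    # Each entry is a (label, value) pair, highest attribution first.
    return result.top(k=5)
```

Calling dla_usage_sketch() with a TransformerBridge instead would, per the notes above, additionally require enable_compatibility_mode() to have been called on the model.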

transformer_lens.tools.analysis.get_act_patch_direct_path(model: HookedTransformer | TransformerBridge, corrupted_tokens: Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, patching_metric: Callable[[Tensor], Tensor], src_layer: int, src_head: int, component: Literal['q', 'k', 'v'] = 'q', verbose: bool = True) → Float[Tensor, 'n_layers n_heads']

Sweep direct path patches from one source head to all downstream heads.

For every destination head B = (dst_layer, dst_head) where dst_layer > src_layer, patch the contribution of source head A = (src_layer, src_head) into B’s query (or key / value) input, and record the patching metric.

The patch is a linear approximation:

    delta_resid  = clean_A_result - corrupted_A_result    [batch, pos, d_model]
    delta_B_comp = (delta_resid / ln1_scale) @ W_comp[dst_head]

where W_comp is W_Q, W_K, or W_V according to component.
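The two lines of the approximation can be worked through by hand for a single batch item and position. The sketch below uses d_model = 3 and d_head = 2 with invented numbers, and a scalar ln1_scale standing in for the destination layer's LayerNorm scale from the corrupted run; the helper vec_matmul is an assumption for this example, not part of the library.

```python
def vec_matmul(vec, matrix):
    """Compute vec @ matrix for vec of length d_model and matrix [d_model][d_head]."""
    n_cols = len(matrix[0])
    return [sum(v * row[j] for v, row in zip(vec, matrix)) for j in range(n_cols)]


# Source head A's output at one position in the clean and corrupted runs.
clean_a_result = [1.0, 2.0, 3.0]
corrupted_a_result = [0.0, 2.0, 1.0]

# LayerNorm scale of the destination layer, taken from the corrupted run.
ln1_scale = 2.0

# W_comp[dst_head]: the destination head's query (or key / value) weights.
w_comp_dst_head = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
]

# Step 1: change in the residual stream caused by swapping in clean A.
delta_resid = [c - k for c, k in zip(clean_a_result, corrupted_a_result)]

# Step 2: apply the frozen LayerNorm scale, then project into head B's input.
scaled = [d / ln1_scale for d in delta_resid]
delta_b_comp = vec_matmul(scaled, w_comp_dst_head)

print(delta_resid, delta_b_comp)
```

Holding ln1_scale fixed at its corrupted-run value is what makes the patch linear: the effect of A's change on LayerNorm statistics is deliberately ignored.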

Parameters:
  • model – A HookedTransformer or TransformerBridge instance.

  • corrupted_tokens – Token IDs for the corrupted input, shape [batch, seq_len].

  • clean_cache – Cached activations from the clean (unpatched) run.

  • corrupted_cache – Cached activations from the corrupted run (needed for ln1 scale).

  • patching_metric – A function mapping the model’s logits tensor to a scalar.

  • src_layer – Layer index of the source attention head.

  • src_head – Head index of the source attention head.

  • component – Which input to patch at the destination head: "q" (default), "k", or "v".

  • verbose – Whether to show a tqdm progress bar.

Returns:

results – results[dst_layer, dst_head] is the patching metric when the direct path A → B is patched in. Entries for dst_layer <= src_layer are left as 0.0 (no causal path from A to those layers).
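The layout of the returned grid can be sketched with a stand-in for the actual patching step. In the example below, sweep_direct_paths and patch_and_measure are hypothetical names; the callback replaces the real "patch A into B, run the model, evaluate the metric" step, so only the indexing and the zero-filled region are illustrated.

```python
from typing import Callable, List


def sweep_direct_paths(
    n_layers: int,
    n_heads: int,
    src_layer: int,
    patch_and_measure: Callable[[int, int], float],
) -> List[List[float]]:
    """Fill results[dst_layer][dst_head] for every head strictly after src_layer."""
    # Start from zeros so that non-causal entries stay at 0.0.
    results = [[0.0] * n_heads for _ in range(n_layers)]
    for dst_layer in range(src_layer + 1, n_layers):
        for dst_head in range(n_heads):
            results[dst_layer][dst_head] = patch_and_measure(dst_layer, dst_head)
    return results


# Stand-in metric so the layout is easy to read: 10 * layer + head.
results = sweep_direct_paths(
    n_layers=3,
    n_heads=2,
    src_layer=1,
    patch_and_measure=lambda layer, head: 10.0 * layer + head,
)
print(results)
```

With src_layer = 1 in a 3-layer model only the last row is populated, which matches the rule that entries for dst_layer <= src_layer remain 0.0.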

Return type:

Float[Tensor, 'n_layers n_heads']

transformer_lens.tools.analysis.get_act_patch_direct_path_all_sources(model: HookedTransformer | TransformerBridge, corrupted_tokens: Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, patching_metric: Callable[[Tensor], Tensor], component: Literal['q', 'k', 'v'] = 'q', verbose: bool = True) → Float[Tensor, 'n_layers n_heads n_layers n_heads']

Full sweep: all (src_layer, src_head) → (dst_layer, dst_head) direct paths.

Returns a 4-D tensor of shape [n_layers, n_heads, n_layers, n_heads], where result[sl, sh, dl, dh] is the patching metric when the output of head (sl, sh) is patched directly into the query/key/value input of head (dl, dh).

Entries where dl <= sl are 0 (no causal path).

This runs O((n_layers * n_heads)^2) forward passes and is intended for small models or targeted sub-sweeps. For large models, prefer calling get_act_patch_direct_path once per source head of interest.
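A quick count gives a feel for the cost. The sketch below assumes one forward pass per (source, destination) pair with dl > sl, which is an assumption about the implementation rather than something stated above; the function names are hypothetical, and the 12-layer, 12-head configuration is just an example size.

```python
def count_causal_paths(n_layers: int, n_heads: int) -> int:
    """Number of (source head, destination head) pairs with dst_layer > src_layer."""
    layer_pairs = n_layers * (n_layers - 1) // 2
    return layer_pairs * n_heads * n_heads


def count_all_entries(n_layers: int, n_heads: int) -> int:
    """Total number of entries in the 4-D result tensor."""
    return (n_layers * n_heads) ** 2


n_layers, n_heads = 12, 12
causal = count_causal_paths(n_layers, n_heads)
total = count_all_entries(n_layers, n_heads)
print(causal, total)  # 9504 20736
```

Even under this assumption a full sweep needs thousands of forward passes for a 12-layer model, which is why restricting the sweep to a handful of source heads is usually the practical choice.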