transformer_lens.tools.analysis package
Module contents
Analysis tools for TransformerLens.
This subpackage collects high-level, single-call interpretability analyses that
sit on top of the hook/cache system. They work with both HookedTransformer
and the newer TransformerBridge (the two share the ActivationCache API).
- Tools:
  - direct_logit_attribution: Direct Logit Attribution (DLA) over components, layers, or attention heads.
  - direct_path_patching: Direct path patching for head-to-head circuit analysis.
- class transformer_lens.tools.analysis.DirectLogitAttribution(attribution: Float[Tensor, 'component *batch_and_pos'], labels: List[str], unit: str)
Bases: object
Result of a direct_logit_attribution() call.
- attribution
Tensor of logit (or logit-difference) attributions with shape [component, *batch_and_pos]. The leading axis is aligned with labels. When pos selects a single position (the default), the position axis is dropped, leaving [component, batch], or [component] if the cache had its batch dimension removed.
- Type:
jaxtyping.Float[Tensor, 'component *batch_and_pos']
- labels
Human-readable name for each component, aligned with the leading axis of attribution (e.g. "embed", "0_attn_out", "L3H7").
- Type:
List[str]
- unit
The decomposition unit used (“component”, “layer”, or “head”).
- Type:
str
- top(k: int = 5) → List[tuple]
Return the k highest-attribution (label, value) pairs. Attribution is reduced to a scalar per component by averaging over any remaining batch/position dimensions, so this is most meaningful when a single position was selected.
- transformer_lens.tools.analysis.direct_logit_attribution(model, input: str | List[str] | Tensor | None = None, answer_tokens: str | int | Tensor | None = None, incorrect_tokens: str | int | Tensor | None = None, *, unit: str = 'component', pos: int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = -1, cache: ActivationCache | None = None) → DirectLogitAttribution
Compute Direct Logit Attribution for a prompt.
Decomposes the contribution of model components to the logit of answer_tokens (or, if incorrect_tokens is given, to the logit difference answer - incorrect, i.e. along the direction W_U[:, answer] - W_U[:, incorrect], which is usually what you want for circuit analysis).
The model is run once with caching unless a precomputed cache is passed. Works with both HookedTransformer and TransformerBridge.
Note that DLA attributes only the part of a logit that comes from the residual stream through the unembedding direction; the unembedding bias b_U is a per-token constant that no component produces. A complete decomposition therefore reconstructs logit[token] - b_U[token] rather than the raw logit.
On a TransformerBridge, compatibility mode must be enabled (so the final LayerNorm is folded into W_U); otherwise the projection direction would be wrong and DLA would silently return incorrect numbers. Hybrid architectures (Mamba/SSM/Mixer/LinearAttention) are not yet supported because decompose_resid only understands the attn_out + mlp_out block layout. Both conditions raise an explicit error at call time.
- Parameters:
model – A HookedTransformer or TransformerBridge (the latter with enable_compatibility_mode() already called).
input – Prompt to run: a string, list of strings, or token tensor. Optional only when a precomputed cache is supplied.
answer_tokens – The correct token(s) to attribute, as a string, id, or tensor. A string is converted with model.to_single_token.
incorrect_tokens – Optional baseline token(s). When given, attribution is computed for the answer - incorrect residual direction. Must broadcast to the same shape as answer_tokens.
unit – Decomposition granularity:
  - "component" (default): embedding + each layer's attention and MLP output (via decompose_resid).
  - "layer": cumulative residual stream after each sublayer, i.e. the logit-lens trajectory (via accumulated_resid).
  - "head": each attention head individually, plus a remainder term for everything else (via stack_head_results).
pos – Sequence position(s) to attribute. Defaults to -1 (the final token, the usual choice for next-token DLA). Pass None to keep every position (the result then has a trailing position axis).
cache – Optional precomputed ActivationCache to reuse instead of running the model again.
- Returns:
A DirectLogitAttribution with attribution (shape [component, *batch_and_pos]) and aligned labels.
- Raises:
ValueError – If unit is invalid, answer_tokens is None, neither input nor cache is provided, or a TransformerBridge is passed without compatibility mode enabled.
NotImplementedError – If a TransformerBridge reports a hybrid block layout (Mamba/SSM/Mixer/LinearAttention).
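The b_U caveat and the three unit options can be checked on a toy residual stream. The plain-Python sketch below ignores the final LayerNorm (treats it as the identity), stores W_U as [d_model][d_vocab], and uses hypothetical helper names (dot, unembed_direction, dla); it mimics the arithmetic of DLA, not the library's implementation.

```python
from itertools import accumulate
from typing import Dict, List, Optional

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def unembed_direction(W_U: List[Vector], answer: int, incorrect: Optional[int] = None) -> Vector:
    """W_U[:, answer], or W_U[:, answer] - W_U[:, incorrect] for a logit difference."""
    direction = [row[answer] for row in W_U]  # W_U stored as [d_model][d_vocab]
    if incorrect is not None:
        direction = [d - row[incorrect] for d, row in zip(direction, W_U)]
    return direction


def dla(pieces: Dict[str, Vector], W_U: List[Vector], answer: int,
        incorrect: Optional[int] = None) -> Dict[str, float]:
    """Project each residual-stream piece onto the unembedding direction."""
    direction = unembed_direction(W_U, answer, incorrect)
    return {name: dot(vec, direction) for name, vec in pieces.items()}


# Toy final-position residual stream: d_model = 3, d_vocab = 3, no final LayerNorm.
components = {
    "embed": [1.0, 0.0, 2.0],
    "0_attn_out": [0.0, 1.0, -1.0],
    "0_mlp_out": [2.0, 1.0, 0.0],
}
W_U = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
b_U = [0.5, -0.5, 0.0]

resid = [sum(vals) for vals in zip(*components.values())]  # [3.0, 2.0, 1.0]
logits = [dot(resid, [row[t] for row in W_U]) + b_U[t] for t in range(3)]  # [4.5, 2.5, 8.0]

# unit="component": the pieces sum to logit - b_U (4.0), not the raw logit (4.5).
per_component = dla(components, W_U, answer=0)  # {'embed': 3.0, '0_attn_out': -1.0, '0_mlp_out': 2.0}

# Logit difference: sums to (logit[0] - logit[1]) - (b_U[0] - b_U[1]) = 1.0.
per_component_diff = dla(components, W_U, answer=0, incorrect=1)

# unit="layer": running totals over the stream, a logit-lens-style trajectory.
trajectory = list(accumulate(per_component.values()))  # [3.0, 2.0, 4.0]

# unit="head": split attention into heads; the remainder holds everything else
# (here the embedding and MLP output), so the total is unchanged.
heads = {"L0H0": [0.0, 0.5, -1.0], "L0H1": [0.0, 0.5, 0.0]}
remainder = [r - sum(hs) for r, *hs in zip(resid, *heads.values())]
per_head = dla({**heads, "remainder": remainder}, W_U, answer=0)  # {'L0H0': -1.0, 'L0H1': 0.0, 'remainder': 5.0}
```

Because every unit is a different split of the same residual vector, all three totals agree (4.0 here); only the granularity of the labels changes.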
- transformer_lens.tools.analysis.get_act_patch_direct_path(model: HookedTransformer | TransformerBridge, corrupted_tokens: Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, patching_metric: Callable[[Tensor], Tensor], src_layer: int, src_head: int, component: Literal['q', 'k', 'v'] = 'q', verbose: bool = True) → Float[Tensor, 'n_layers n_heads']
Sweep direct path patches from one source head to all downstream heads.
For every destination head B = (dst_layer, dst_head) where dst_layer > src_layer, patch the contribution of source head A = (src_layer, src_head) into B’s query (or key / value) input, and record the patching metric.
The patch is a linear approximation:
delta_resid = clean_A_result - corrupted_A_result    (shape [batch, pos, d_model])
delta_B_comp = (delta_resid / ln1_scale) @ W_comp[dst_head]
where W_comp is W_Q, W_K, or W_V according to component.
- Parameters:
model – A HookedTransformer or TransformerBridge instance.
corrupted_tokens – Token IDs for the corrupted input, shape [batch, seq_len].
clean_cache – Cached activations from the clean (unpatched) run.
corrupted_cache – Cached activations from the corrupted run (needed for ln1 scale).
patching_metric – A function mapping the model’s logits tensor to a scalar.
src_layer – Layer index of the source attention head.
src_head – Head index of the source attention head.
component – Which input to patch at the destination head — “q” (default), “k”, or “v”.
verbose – Whether to show a tqdm progress bar.
- Returns:
results – results[dst_layer, dst_head] is the patching metric when the direct path A → B is patched in. Entries for dst_layer <= src_layer are left as 0.0 (no causal path from A to those layers).
- Return type:
Float[Tensor, “n_layers n_heads”]
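The linear approximation and the sweep structure can be sketched in plain Python for a single batch element and position. Assumptions: ln1_scale is a scalar for that position, W_comp_head is the [d_model][d_head] slice of W_Q, W_K or W_V for the destination head, and evaluate is a hypothetical callback standing in for the patched forward pass followed by patching_metric. None of these helper names are part of the library.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[Vector]


def direct_path_delta(clean_A: Vector, corrupted_A: Vector, ln1_scale: float,
                      W_comp_head: Matrix) -> Vector:
    """delta_B_comp for one position: (delta_resid / ln1_scale) @ W_comp_head."""
    delta_resid = [c - r for c, r in zip(clean_A, corrupted_A)]  # [d_model]
    scaled = [d / ln1_scale for d in delta_resid]
    d_head = len(W_comp_head[0])
    # Row vector times matrix -> [d_head]
    return [sum(s * W_comp_head[i][j] for i, s in enumerate(scaled)) for j in range(d_head)]


def sweep_destinations(n_layers: int, n_heads: int, src_layer: int,
                       evaluate: Callable[[int, int], float]) -> List[List[float]]:
    """Fill results[dst_layer][dst_head]; layers at or below src_layer stay 0.0."""
    results = [[0.0] * n_heads for _ in range(n_layers)]
    for dst_layer in range(src_layer + 1, n_layers):
        for dst_head in range(n_heads):
            # Hypothetical: evaluate() stands in for a forward pass with delta_B_comp
            # added to head B's input, followed by patching_metric on the logits.
            results[dst_layer][dst_head] = evaluate(dst_layer, dst_head)
    return results


delta = direct_path_delta([2.0, 1.0], [0.0, 1.0], 2.0, [[1.0, 0.0, 3.0], [5.0, 1.0, 1.0]])
print(delta)  # [1.0, 0.0, 3.0]
print(sweep_destinations(3, 2, 1, lambda l, h: float(10 * l + h)))
# [[0.0, 0.0], [0.0, 0.0], [20.0, 21.0]]
```

Dividing by the destination's ln1 scale before projecting treats that scale as fixed, which is what makes the patch linear in head A's output rather than an exact recomputation of the LayerNorm.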
- transformer_lens.tools.analysis.get_act_patch_direct_path_all_sources(model: HookedTransformer | TransformerBridge, corrupted_tokens: Tensor, clean_cache: ActivationCache, corrupted_cache: ActivationCache, patching_metric: Callable[[Tensor], Tensor], component: Literal['q', 'k', 'v'] = 'q', verbose: bool = True) → Float[Tensor, 'n_layers n_heads n_layers n_heads']
Full sweep: all (src_layer, src_head) → (dst_layer, dst_head) direct paths.
Returns a 4-D tensor of shape [n_layers, n_heads, n_layers, n_heads], where result[sl, sh, dl, dh] is the patching metric when head (sl, sh)'s output is patched directly into head (dl, dh)'s query/key/value input.
Entries where dl <= sl are 0 (no causal path).
This runs O(n_layers * n_heads * n_layers * n_heads) forward passes and is intended for small models or targeted sub-sweeps. For large models, prefer calling get_act_patch_direct_path for individual source heads.
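To put the cost in numbers, the sketch below counts the causal (source, destination) head pairs, i.e. those with dl > sl. It assumes one patched forward pass per such pair, which is an assumption about the implementation rather than something stated above; the helper names are hypothetical. For GPT-2 small (12 layers, 12 heads per layer) the full sweep covers 9,504 pairs, while a single source head in layer 9 needs only 24.

```python
from itertools import product


def count_causal_pairs(n_layers: int, n_heads: int) -> int:
    """Brute force: count (src, dst) head pairs with dst_layer > src_layer."""
    heads = list(product(range(n_layers), range(n_heads)))
    return sum(1 for (sl, _), (dl, _) in product(heads, heads) if dl > sl)


def causal_pairs_closed_form(n_layers: int, n_heads: int) -> int:
    """n_heads**2 head combinations for each of C(n_layers, 2) layer pairs."""
    return n_heads * n_heads * n_layers * (n_layers - 1) // 2


def per_source_cost(n_layers: int, n_heads: int, src_layer: int) -> int:
    """Destination heads swept by one single-source call (dst_layer > src_layer)."""
    return (n_layers - src_layer - 1) * n_heads


# GPT-2 small: 12 layers x 12 heads.
print(count_causal_pairs(12, 12), causal_pairs_closed_form(12, 12))  # 9504 9504
print(per_source_cost(12, 12, 9))  # 24
```

The count grows quadratically in the total number of heads, so restricting the sweep to a few candidate source heads is the practical way to scale it.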