Coverage for transformer_lens/tools/analysis/direct_logit_attribution.py: 97%
52 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""Direct Logit Attribution (DLA).

Direct Logit Attribution decomposes a model's output logit (or a logit
*difference* between a correct and an incorrect token) into the additive
contributions of upstream components — the embedding, each attention and MLP
sublayer, or each individual attention head. Because the unembedding is linear
and the residual stream is a sum of component outputs, the final logit is
(up to the final LayerNorm) a sum of per-component dot products with the
unembedding direction of the token of interest. DLA reads off those dot
products. See the `logit lens
<https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_
and `Interpretability in the Wild <https://arxiv.org/abs/2211.00593>`_ for the
canonical uses.

This module exposes a single entry point, :func:`direct_logit_attribution`,
that wraps the lower-level ``ActivationCache`` primitives
(:meth:`~transformer_lens.ActivationCache.ActivationCache.decompose_resid`,
:meth:`~transformer_lens.ActivationCache.ActivationCache.accumulated_resid`,
:meth:`~transformer_lens.ActivationCache.ActivationCache.stack_head_results`
and :meth:`~transformer_lens.ActivationCache.ActivationCache.logit_attrs`) into
one call. It works unchanged with both ``HookedTransformer`` and
``TransformerBridge`` because they share the cache API.

Example::

    from transformer_lens import HookedTransformer
    from transformer_lens.tools.analysis import direct_logit_attribution

    model = HookedTransformer.from_pretrained("gpt2", device="cpu")
    result = direct_logit_attribution(
        model,
        "The Eiffel Tower is in the city of",
        answer_tokens=" Paris",
        incorrect_tokens=" London",
        unit="component",
    )
    for label, value in zip(result.labels, result.attribution.squeeze()):
        print(f"{label:>12}: {value.item():+.3f}")
"""
41from dataclasses import dataclass
42from typing import List, Optional, Union
44import torch
45from jaxtyping import Float
47from transformer_lens.ActivationCache import ActivationCache
48from transformer_lens.utilities import SliceInput
# Accepted ways of naming the correct/incorrect answer tokens; kept in step
# with what ActivationCache.logit_attrs takes.
TokenInput = Union[str, int, torch.Tensor]

# Granularity at which the residual stream is split up.
Unit = str  # one of: "component", "layer", "head"

_VALID_UNITS = ("component", "layer", "head")

# Block variants without the attn_out + mlp_out structure that decompose_resid
# expects. If TransformerBridge.layer_types() reports any of them we refuse
# early, because decompose_resid would otherwise fail downstream with a
# confusing KeyError.
_HYBRID_VARIANT_NAMES = ("mamba", "ssm", "mixer", "linear_attn")
@dataclass
class DirectLogitAttribution:
    """Result of a :func:`direct_logit_attribution` call.

    Attributes:
        attribution:
            Tensor of logit (or logit-difference) attributions with shape
            ``[component, *batch_and_pos]``. The leading axis is aligned with
            ``labels``. When ``pos`` selects a single position (the default) the
            position axis is dropped, leaving ``[component, batch]`` — or
            ``[component]`` if the cache had its batch dimension removed.
        labels:
            Human-readable name for each component, aligned with the leading
            axis of ``attribution`` (e.g. ``"embed"``, ``"0_attn_out"``,
            ``"L3H7"``).
        unit:
            The decomposition unit used ("component", "layer", or "head").
    """

    attribution: Float[torch.Tensor, "component *batch_and_pos"]
    labels: List[str]
    unit: Unit

    def top(self, k: int = 5) -> List[tuple]:
        """Return the ``k`` highest-attribution ``(label, value)`` pairs.

        Attribution is reduced to one scalar per component by averaging over
        any remaining batch/position dimensions, so this is most meaningful
        when a single position was selected.

        Args:
            k: Maximum number of pairs to return. A value larger than the
                number of components is clamped to that number; zero or a
                negative value yields an empty list.

        Returns:
            ``(label, value)`` tuples ordered from the highest value to the
            lowest, with ``value`` as a plain Python number.
        """
        scores = self.attribution
        if scores.ndim > 1:
            # Collapse every trailing (batch/position) axis into one mean per component.
            scores = scores.flatten(start_dim=1).mean(dim=-1)
        # torch.topk raises for k outside [0, n_components], so clamp both ends.
        count = max(0, min(k, scores.shape[0]))
        values, indices = torch.topk(scores, count)
        return [
            (self.labels[index], value)
            for index, value in zip(indices.tolist(), values.tolist())
        ]
def _residual_stack_and_labels(
    cache: ActivationCache,
    unit: Unit,
    pos_slice: SliceInput,
):
    """Split the residual stream into ``unit``-sized pieces, with labels.

    Every branch passes ``apply_ln=False`` on purpose: ``logit_attrs`` applies
    the final-layer scaling itself, so scaling here as well would double-count.

    Args:
        cache: Activation cache whose decomposition method is called.
        unit: ``"component"``, ``"layer"`` or ``"head"``.
        pos_slice: Position selection forwarded unchanged to the cache method.

    Returns:
        Whatever the selected cache method returns when called with
        ``return_labels=True``.

    Raises:
        ValueError: If ``unit`` is not one of the three supported values.
    """
    if unit == "component":
        # embed (+ pos_embed) plus the attn_out / mlp_out of every layer.
        stack_with_labels = cache.decompose_resid(
            apply_ln=False, pos_slice=pos_slice, return_labels=True
        )
    elif unit == "layer":
        # Running residual stream after each sublayer, logit-lens style.
        stack_with_labels = cache.accumulated_resid(
            apply_ln=False, incl_mid=True, pos_slice=pos_slice, return_labels=True
        )
    elif unit == "head":
        # One entry per attention head, plus the MLP/embedding remainder.
        stack_with_labels = cache.stack_head_results(
            apply_ln=False, pos_slice=pos_slice, incl_remainder=True, return_labels=True
        )
    else:
        raise ValueError(f"unit must be one of {_VALID_UNITS}, got {unit!r}")
    return stack_with_labels
def _validate_bridge_compatibility(model) -> None:
    """Refuse Bridge models for which DLA cannot give correct numbers.

    Anything that is not a ``TransformerBridge`` passes straight through;
    HookedTransformer always has LN folded into W_U, so the checks only matter
    for a Bridge. Two conditions are enforced:

    * Compatibility mode must be on. Without folded LN the projection direction
      used by ``logit_attrs`` is wrong on a Bridge, which would be a silent
      correctness bug.
    * No hybrid (Mamba/SSM-style) blocks. Failing here gives a clear error
      instead of a confusing KeyError from ``decompose_resid`` later on.

    Raises:
        ValueError: If the bridge does not have compatibility mode enabled.
        NotImplementedError: If any reported layer type names a hybrid variant.
    """
    # Imported lazily so that importing this module does not drag in the bridge.
    from transformer_lens.model_bridge import TransformerBridge

    if not isinstance(model, TransformerBridge):
        return

    compat_enabled = getattr(model, "compatibility_mode", False)
    if not compat_enabled:
        raise ValueError(
            "DLA on a TransformerBridge requires compatibility mode so that LayerNorm "
            "weights are folded into W_U. Call `model.enable_compatibility_mode()` "
            "after loading the bridge, then re-run DLA."
        )

    # A layer type is split on "+" and rejected if any part is a hybrid variant name.
    hybrid = []
    for layer_type in model.layer_types():
        parts = layer_type.split("+")
        if any(part in _HYBRID_VARIANT_NAMES for part in parts):
            hybrid.append(layer_type)

    if hybrid:
        raise NotImplementedError(
            f"DLA does not yet support hybrid architectures (found block types {hybrid}). "
            "Only standard attention + MLP transformers (e.g. GPT-2, LLaMA, Pythia) are "
            "supported; hybrid support requires extending ActivationCache.decompose_resid."
        )
def direct_logit_attribution(
    model,
    input: Union[str, List[str], torch.Tensor, None] = None,
    answer_tokens: Optional[TokenInput] = None,
    incorrect_tokens: Optional[TokenInput] = None,
    *,
    unit: Unit = "component",
    pos: SliceInput = -1,
    cache: Optional[ActivationCache] = None,
) -> DirectLogitAttribution:
    """Run Direct Logit Attribution on a prompt.

    Breaks the logit of ``answer_tokens`` down into per-component
    contributions. When ``incorrect_tokens`` is also given, the quantity
    decomposed is the logit *difference* ``answer - incorrect`` along the
    ``W_U`` direction, which is usually the one wanted for circuit analysis.

    Unless a precomputed ``cache`` is supplied, the model is run once with
    caching. Both ``HookedTransformer`` and ``TransformerBridge`` are accepted.

    DLA only attributes the part of a logit that reaches it from the residual
    stream through the unembedding direction. The unembedding bias ``b_U`` is a
    per-token constant that no component produces, so a complete decomposition
    reconstructs ``logit[token] - b_U[token]`` rather than the raw logit.

    For a ``TransformerBridge``, compatibility mode must be enabled so that the
    final LayerNorm is folded into ``W_U``; otherwise the projection direction
    is wrong and the numbers would be silently incorrect. Hybrid architectures
    (Mamba/SSM/Mixer/LinearAttention) are not yet supported because
    ``decompose_resid`` only understands the ``attn_out + mlp_out`` block
    layout. Both situations raise an explicit error at call time.

    Args:
        model:
            A ``HookedTransformer`` or ``TransformerBridge`` (the latter with
            ``enable_compatibility_mode()`` already called).
        input:
            Prompt to run — a string, list of strings, or token tensor. May be
            omitted only when a precomputed ``cache`` is supplied.
        answer_tokens:
            The correct token(s) to attribute, as a string, id, or tensor. A
            string is converted with ``model.to_single_token``.
        incorrect_tokens:
            Optional baseline token(s). When given, attribution is computed for
            the ``answer - incorrect`` residual direction. Must broadcast to
            the same shape as ``answer_tokens``.
        unit:
            Decomposition granularity:

            - ``"component"`` (default): embedding + each layer's attention and
              MLP output (via ``decompose_resid``).
            - ``"layer"``: cumulative residual stream after each sublayer, i.e.
              logit-lens trajectory (via ``accumulated_resid``).
            - ``"head"``: each attention head individually, plus a remainder
              term for everything else (via ``stack_head_results``).
        pos:
            Sequence position(s) to attribute. Defaults to ``-1`` (the final
            token, the usual choice for next-token DLA). Pass ``None`` to keep
            every position (the result then has a trailing position axis).
        cache:
            Optional precomputed ``ActivationCache`` to reuse instead of
            running the model again.

    Returns:
        A :class:`DirectLogitAttribution` with ``attribution`` (shape
        ``[component, *batch_and_pos]``) and aligned ``labels``.

    Raises:
        ValueError: If ``unit`` is invalid, ``answer_tokens`` is ``None``,
            neither ``input`` nor ``cache`` is provided, or a
            ``TransformerBridge`` is passed without compatibility mode enabled.
        NotImplementedError: If a ``TransformerBridge`` reports a hybrid block
            layout (Mamba/SSM/Mixer/LinearAttention).
    """
    # Cheap argument checks come first, before the model is inspected or run.
    if unit not in _VALID_UNITS:
        raise ValueError(f"unit must be one of {_VALID_UNITS}, got {unit!r}")
    if answer_tokens is None:
        raise ValueError("answer_tokens is required")

    _validate_bridge_compatibility(model)

    if cache is None and input is None:
        raise ValueError("provide either `input` to run the model, or a precomputed `cache`")
    if cache is None:
        # Only the cache is needed; the model output itself is discarded.
        _unused_output, cache = model.run_with_cache(input)

    stack, component_labels = _residual_stack_and_labels(cache, unit, pos)

    # logit_attrs applies the final LayerNorm scaling (using the same pos slice)
    # and dots each component against the (correct - incorrect) unembed direction.
    scores = cache.logit_attrs(
        stack,
        tokens=answer_tokens,
        incorrect_tokens=incorrect_tokens,
        pos_slice=pos,
        has_batch_dim=cache.has_batch_dim,
    )

    return DirectLogitAttribution(attribution=scores, labels=component_labels, unit=unit)