Coverage for transformer_lens/tools/analysis/direct_logit_attribution.py: 97%

52 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Direct Logit Attribution (DLA). 

2 

3Direct Logit Attribution decomposes a model's output logit (or a logit 

4*difference* between a correct and an incorrect token) into the additive 

5contributions of upstream components — the embedding, each attention and MLP 

6sublayer, or each individual attention head. Because the unembedding is linear 

7and the residual stream is a sum of component outputs, the final logit is 

8(up to the final LayerNorm) a sum of per-component dot products with the 

9unembedding direction of the token of interest. DLA reads off those dot 

10products. See the `logit lens 

11<https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_ 

12and `Interpretability in the Wild <https://arxiv.org/abs/2211.00593>`_ for the 

13canonical uses. 

14 

15This module exposes a single entry point, :func:`direct_logit_attribution`, 

16that wraps the lower-level ``ActivationCache`` primitives 

17(:meth:`~transformer_lens.ActivationCache.ActivationCache.decompose_resid`, 

18:meth:`~transformer_lens.ActivationCache.ActivationCache.accumulated_resid`, 

19:meth:`~transformer_lens.ActivationCache.ActivationCache.stack_head_results` 

20and :meth:`~transformer_lens.ActivationCache.ActivationCache.logit_attrs`) into 

21one call. It works unchanged with both ``HookedTransformer`` and 

22``TransformerBridge`` because they share the cache API. 

23 

24Example:: 

25 

26 from transformer_lens import HookedTransformer 

27 from transformer_lens.tools.analysis import direct_logit_attribution 

28 

29 model = HookedTransformer.from_pretrained("gpt2", device="cpu") 

30 result = direct_logit_attribution( 

31 model, 

32 "The Eiffel Tower is in the city of", 

33 answer_tokens=" Paris", 

34 incorrect_tokens=" London", 

35 unit="component", 

36 ) 

37 for label, value in zip(result.labels, result.attribution.squeeze()): 

38 print(f"{label:>12}: {value.item():+.3f}") 

39""" 

40 

41from dataclasses import dataclass 

42from typing import List, Optional, Union 

43 

44import torch 

45from jaxtyping import Float 

46 

47from transformer_lens.ActivationCache import ActivationCache 

48from transformer_lens.utilities import SliceInput 

49 

# Forms in which the correct/incorrect answer tokens may be supplied; kept in
# step with what ActivationCache.logit_attrs accepts.
TokenInput = Union[str, int, torch.Tensor]

# Structural unit used when splitting the residual stream.
Unit = str  # expected to be one of the entries of _VALID_UNITS below

_VALID_UNITS = ("component", "layer", "head")

# Names of block variants that do not follow the attn_out + mlp_out layout
# decompose_resid expects. If TransformerBridge.layer_types() reports one of
# them, DLA refuses up front instead of letting decompose_resid fail later
# with a confusing KeyError.
_HYBRID_VARIANT_NAMES = ("mamba", "ssm", "mixer", "linear_attn")

67 

68 

@dataclass
class DirectLogitAttribution:
    """Result of a :func:`direct_logit_attribution` call.

    Attributes:
        attribution:
            Tensor of logit (or logit-difference) attributions with shape
            ``[component, *batch_and_pos]``. The leading axis is aligned with
            ``labels``. When ``pos`` selects a single position (the default) the
            position axis is dropped, leaving ``[component, batch]`` — or
            ``[component]`` if the cache had its batch dimension removed.
        labels:
            Human-readable name for each component, aligned with the leading
            axis of ``attribution`` (e.g. ``"embed"``, ``"0_attn_out"``,
            ``"L3H7"``).
        unit:
            The decomposition unit used ("component", "layer", or "head").
    """

    attribution: Float[torch.Tensor, "component *batch_and_pos"]
    labels: List[str]
    unit: Unit

    def top(self, k: int = 5) -> List[tuple]:
        """Return the ``k`` highest-attribution ``(label, value)`` pairs.

        Attribution is reduced to a scalar per component by averaging over any
        remaining batch/position dimensions, so this is most meaningful when a
        single position was selected.

        Args:
            k: Maximum number of pairs to return. Values larger than the number
                of components are clamped to that number; ``0`` yields an
                empty list.

        Returns:
            A list of ``(label, value)`` tuples in the order produced by
            ``torch.topk`` (largest value first), where ``value`` is a Python
            float.

        Raises:
            ValueError: If ``k`` is negative.
        """
        # Fail with a clear message instead of torch.topk's range error.
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")

        per_component = self.attribution
        if per_component.ndim > 1:
            # Collapse every non-component axis into one and average it away.
            per_component = per_component.flatten(start_dim=1).mean(dim=-1)

        values, indices = torch.topk(per_component, min(k, per_component.shape[0]))
        # One tolist() per tensor rather than one .item() call per element.
        return [
            (self.labels[index], value)
            for index, value in zip(indices.tolist(), values.tolist())
        ]

104 

105 

def _residual_stack_and_labels(
    cache: ActivationCache,
    unit: Unit,
    pos_slice: SliceInput,
):
    """Split the cached residual stream into ``unit`` pieces and return them with labels.

    Every branch asks the cache for ``return_labels=True`` and hands back the
    cache method's result unchanged.

    LayerNorm is deliberately left off (``apply_ln=False``) at this stage:
    ``logit_attrs`` performs the final-layer scaling itself, and applying it
    here as well would double-count.

    Args:
        cache: Activation cache to decompose.
        unit: ``"component"``, ``"layer"`` or ``"head"``.
        pos_slice: Position selection forwarded to the cache method.

    Raises:
        ValueError: If ``unit`` is none of the supported values.
    """
    if unit == "component":
        # Embedding term(s) followed by each layer's attn_out / mlp_out.
        decomposition = cache.decompose_resid(
            apply_ln=False,
            pos_slice=pos_slice,
            return_labels=True,
        )
    elif unit == "layer":
        # Running residual stream after every sublayer (logit-lens style).
        decomposition = cache.accumulated_resid(
            apply_ln=False,
            incl_mid=True,
            pos_slice=pos_slice,
            return_labels=True,
        )
    elif unit == "head":
        # One entry per attention head, plus a remainder for everything else.
        decomposition = cache.stack_head_results(
            apply_ln=False,
            pos_slice=pos_slice,
            incl_remainder=True,
            return_labels=True,
        )
    else:
        raise ValueError(f"unit must be one of {_VALID_UNITS}, got {unit!r}")
    return decomposition

130 

131 

def _validate_bridge_compatibility(model) -> None:
    """Refuse TransformerBridge inputs for which DLA would give wrong numbers.

    Anything that is not a ``TransformerBridge`` passes through untouched. For
    a bridge, two conditions are enforced in order:

    1. Compatibility mode must be on. Without LayerNorm folded into ``W_U`` the
       projection direction used by ``logit_attrs`` is wrong on a bridge, and
       nothing else would flag it.
    2. No reported layer type may contain a hybrid block variant (Mamba/SSM and
       similar). Failing here gives a clear message instead of the confusing
       KeyError ``decompose_resid`` would raise later.

    Raises:
        ValueError: If the bridge does not have compatibility mode enabled.
        NotImplementedError: If the bridge reports a hybrid block type.
    """
    # Imported inside the function so that importing this module does not pull
    # in the bridge.
    from transformer_lens.model_bridge import TransformerBridge

    if isinstance(model, TransformerBridge):
        compat_enabled = getattr(model, "compatibility_mode", False)
        if not compat_enabled:
            raise ValueError(
                "DLA on a TransformerBridge requires compatibility mode so that LayerNorm "
                "weights are folded into W_U. Call `model.enable_compatibility_mode()` "
                "after loading the bridge, then re-run DLA."
            )

        unsupported = set(_HYBRID_VARIANT_NAMES)
        # A layer type is a "+"-joined list of parts; flag it if any part is unsupported.
        hybrid = [
            layer_type
            for layer_type in model.layer_types()
            if not unsupported.isdisjoint(layer_type.split("+"))
        ]
        if hybrid:
            raise NotImplementedError(
                f"DLA does not yet support hybrid architectures (found block types {hybrid}). "
                f"Only standard attention + MLP transformers (e.g. GPT-2, LLaMA, Pythia) are "
                f"supported; hybrid support requires extending ActivationCache.decompose_resid."
            )

163 

164 

def direct_logit_attribution(
    model,
    input: Union[str, List[str], torch.Tensor, None] = None,
    answer_tokens: Optional[TokenInput] = None,
    incorrect_tokens: Optional[TokenInput] = None,
    *,
    unit: Unit = "component",
    pos: SliceInput = -1,
    cache: Optional[ActivationCache] = None,
) -> DirectLogitAttribution:
    """Run Direct Logit Attribution on a prompt or a precomputed cache.

    Splits the logit of ``answer_tokens`` into per-component contributions. If
    ``incorrect_tokens`` is also given, the quantity attributed is the logit
    *difference* ``answer - incorrect`` along the ``W_U`` direction, which is
    usually the more useful view for circuit analysis.

    The model is run once with caching, unless ``cache`` is supplied, in which
    case ``input`` is not used. Both ``HookedTransformer`` and
    ``TransformerBridge`` are accepted.

    Only the part of a logit that flows from the residual stream through the
    unembedding direction is attributed. The unembedding bias ``b_U`` is a
    per-token constant produced by no component, so a complete decomposition
    sums to ``logit[token] - b_U[token]``, not to the raw logit.

    For a ``TransformerBridge``, compatibility mode must already be enabled so
    the final LayerNorm is folded into ``W_U``; otherwise the projection
    direction is wrong and the numbers would be silently incorrect. Hybrid
    architectures (Mamba/SSM/Mixer/LinearAttention) are not supported yet
    because ``decompose_resid`` only understands the ``attn_out + mlp_out``
    block layout. Both situations raise an explicit error at call time.

    Args:
        model:
            A ``HookedTransformer``, or a ``TransformerBridge`` on which
            ``enable_compatibility_mode()`` has been called.
        input:
            Prompt to run: a string, a list of strings, or a token tensor. May
            be omitted only when ``cache`` is given.
        answer_tokens:
            Correct token(s) to attribute, as a string, id, or tensor. A string
            is converted with ``model.to_single_token``.
        incorrect_tokens:
            Optional baseline token(s). When present, attribution is computed
            for the ``answer - incorrect`` residual direction. Must broadcast
            to the same shape as ``answer_tokens``.
        unit:
            Granularity of the decomposition:

            - ``"component"`` (default): embedding plus each layer's attention
              and MLP output (``decompose_resid``).
            - ``"layer"``: cumulative residual stream after each sublayer, i.e.
              the logit-lens trajectory (``accumulated_resid``).
            - ``"head"``: every attention head separately, plus one remainder
              term for everything else (``stack_head_results``).
        pos:
            Sequence position(s) to attribute. Defaults to ``-1``, the final
            token, which is the usual choice for next-token DLA. ``None`` keeps
            every position, in which case the result has a trailing position
            axis.
        cache:
            Optional precomputed ``ActivationCache`` to reuse instead of
            running the model.

    Returns:
        A :class:`DirectLogitAttribution` whose ``attribution`` has shape
        ``[component, *batch_and_pos]`` and whose ``labels`` are aligned with
        its leading axis.

    Raises:
        ValueError: If ``unit`` is invalid, ``answer_tokens`` is ``None``,
            neither ``input`` nor ``cache`` is provided, or a
            ``TransformerBridge`` is passed without compatibility mode enabled.
        NotImplementedError: If a ``TransformerBridge`` reports a hybrid block
            layout (Mamba/SSM/Mixer/LinearAttention).
    """
    # Cheap argument checks first, so a bad call fails before any model work.
    if unit not in _VALID_UNITS:
        raise ValueError(f"unit must be one of {_VALID_UNITS}, got {unit!r}")
    if answer_tokens is None:
        raise ValueError("answer_tokens is required")

    _validate_bridge_compatibility(model)

    if cache is None:
        if input is None:
            raise ValueError("provide either `input` to run the model, or a precomputed `cache`")
        # Only the cache is needed; the logits from the forward pass are discarded.
        _logits, cache = model.run_with_cache(input)

    stack, component_labels = _residual_stack_and_labels(cache, unit, pos)

    # logit_attrs handles the final LayerNorm scaling (using the same position
    # slice) and dots each entry of the stack with the (correct - incorrect)
    # unembed direction.
    per_component = cache.logit_attrs(
        stack,
        tokens=answer_tokens,
        incorrect_tokens=incorrect_tokens,
        pos_slice=pos,
        has_batch_dim=cache.has_batch_dim,
    )

    return DirectLogitAttribution(
        attribution=per_component,
        labels=component_labels,
        unit=unit,
    )