Coverage for transformer_lens/utilities/tracr.py: 88%

60 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Utilities for loading Tracr-assembled models into TransformerBridge. 

2 

3These helpers are intentionally duck-typed so importing TransformerLens does not 

4pull in Tracr, JAX, or Haiku. Pass in a Tracr ``AssembledTransformerModel`` from 

5``tracr.compiler.compiling.compile_rasp_to_model``. 

6""" 

7 

8from __future__ import annotations 

9 

10from collections.abc import Mapping 

11from typing import Any 

12 

13import numpy as np 

14import torch 

15 

16from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig 

17 

18 

def infer_tracr_output_label(model: Any) -> str:
    """Infer the RASP output label used in ``model.residual_labels``.

    Tracr stores residual basis labels as strings like ``"reverse:3"`` while
    its categorical output encoder stores only values and output-column ids.
    The output label is the unique residual-label prefix whose value set
    exactly matches the output encoder's categorical value set (values are
    compared as strings).

    Residual labels without any ``":"`` carry no value and are ignored.
    Because a value (e.g. a ``":"`` token) or a user-chosen label may itself
    contain ``":"``, every colon in a residual label is treated as a possible
    label/value boundary. A spurious split only matters if its value set also
    equals the output value set; the result is then ambiguous and a
    ``ValueError`` asks for an explicit label.

    Args:
        model: A Tracr ``AssembledTransformerModel`` (duck-typed) exposing
            ``residual_labels`` and ``output_encoder.encoding_map``.

    Returns:
        The inferred output label, suitable for ``output_label=`` arguments.

    Raises:
        NotImplementedError: If the model has no categorical output encoding map.
        ValueError: If ``residual_labels`` is missing, or if zero or more than
            one label matches the output value set.
    """
    encoding_map = _categorical_output_encoding_map(model)
    output_values = {str(value) for value in encoding_map}
    values_by_label: dict[str, set[str]] = {}

    for residual_label in _residual_labels(model):
        # Consider every ":" as the separator, not just the last one, so a
        # value containing ":" (e.g. "out::") is still recorded under "out".
        for split_at, char in enumerate(residual_label):
            if char == ":":
                label = residual_label[:split_at]
                value = residual_label[split_at + 1 :]
                values_by_label.setdefault(label, set()).add(value)

    candidates = [label for label, values in values_by_label.items() if values == output_values]
    if len(candidates) != 1:
        raise ValueError(
            "Could not infer Tracr output label from residual labels. Pass "
            f"output_label explicitly. Candidates: {candidates!r}."
        )
    return candidates[0]

44 

45 

def make_tracr_categorical_unembed(
    model: Any,
    output_label: str | None = None,
    *,
    dtype: np.dtype = np.dtype("float32"),
) -> np.ndarray:
    """Build Tracr's categorical unembedding as a TransformerLens-style matrix.

    The result has one row per residual label and one column per categorical
    output value, i.e. shape ``[d_model, d_vocab_out]``. Column ``j`` holds a
    single ``1`` at the row labelled ``f"{output_label}:{value}"`` where
    ``encoding_map[value] == j``; every other entry is zero. Locating rows by
    label means the output coordinates need not be the first residual
    dimensions.

    Args:
        model: A Tracr ``AssembledTransformerModel`` (duck-typed).
        output_label: Label of the final RASP expression. Inferred with
            ``infer_tracr_output_label`` when ``None``.
        dtype: NumPy dtype of the returned matrix.

    Returns:
        The one-hot projection matrix.

    Raises:
        NotImplementedError: If the model's output encoder is not categorical.
        ValueError: If the output label cannot be inferred, or an expected
            ``"<label>:<value>"`` basis label is absent from
            ``model.residual_labels``.
    """
    encoding_map = _categorical_output_encoding_map(model)
    if output_label is None:
        output_label = infer_tracr_output_label(model)

    # Map each residual label to its row; a repeated label keeps its last row.
    row_of: dict[str, int] = {}
    for row, name in enumerate(_residual_labels(model)):
        row_of[name] = row

    unembed = np.zeros((len(row_of), len(encoding_map)), dtype=dtype)
    for value, column in encoding_map.items():
        basis_label = f"{output_label}:{value}"
        row = row_of.get(basis_label)
        if row is None:
            raise ValueError(
                f"Could not find output basis label {basis_label!r} in "
                "model.residual_labels. Pass the final RASP expression label "
                "as output_label."
            )
        unembed[row, int(column)] = 1
    return unembed

75 

76 

def make_tracr_transformer_bridge_config(model: Any) -> TransformerBridgeConfig:
    """Build a ``TransformerBridgeConfig`` matching a categorical Tracr model.

    Sizes are read from the Tracr parameters:
    ``d_vocab``/``d_model`` are the row/column counts of the token embedding,
    ``n_ctx`` is the row count of the positional embedding, and
    ``d_vocab_out`` is the number of categorical output values.
    Layer, head, key and MLP sizes, causality and layer-norm usage come from
    ``model.model_config``. The activation is always ReLU.

    Args:
        model: A Tracr ``AssembledTransformerModel`` (duck-typed).

    Returns:
        The matching bridge config.

    Raises:
        NotImplementedError: If the model's output encoder is not categorical.
    """
    model_config = model.model_config
    # The token embedding supplies two dimensions; fetch and convert it once.
    token_embed_shape = _param_array(model, "token_embed", "embeddings").shape
    d_vocab = token_embed_shape[0]
    d_model = token_embed_shape[1]
    n_ctx = _param_array(model, "pos_embed", "embeddings").shape[0]

    return TransformerBridgeConfig(
        n_layers=model_config.num_layers,
        d_model=d_model,
        d_head=model_config.key_size,
        n_ctx=n_ctx,
        d_vocab=d_vocab,
        d_vocab_out=len(_categorical_output_encoding_map(model)),
        d_mlp=model_config.mlp_hidden_size,
        n_heads=model_config.num_heads,
        act_fn="relu",
        attention_dir="causal" if model_config.causal else "bidirectional",
        normalization_type="LN" if model_config.layer_norm else None,
    )

97 

98 

def make_tracr_transformer_bridge_state_dict(
    model: Any,
    output_label: str | None = None,
    *,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert Tracr weights into a ``TransformerBridge.boot_native`` state dict.

    Keys use the native PyTorch module names accepted by
    ``TransformerBridge.load_state_dict`` after ``boot_native``:
    ``tok_embed.weight``, ``pos.weight`` and ``head.weight``, then for each
    layer ``i`` the weight and bias of ``layers.{i}.attn.{k,q,v,o}`` and
    ``layers.{i}.mlp.{fc_in,fc_out}``.

    Args:
        model: A Tracr ``AssembledTransformerModel`` (duck-typed).
        output_label: Final RASP expression label, forwarded to
            ``make_tracr_categorical_unembed`` (inferred when ``None``).
        dtype: Torch dtype of every produced tensor.

    Returns:
        Mapping from native parameter name to tensor.

    Raises:
        NotImplementedError: For ``layer_norm=True`` models or non-categorical
            outputs.
        ValueError: If the output label cannot be resolved.
    """
    if model.model_config.layer_norm:
        raise NotImplementedError(
            "Tracr layer_norm=True models are not supported by this converter yet."
        )

    num_layers = model.model_config.num_layers

    state_dict: dict[str, torch.Tensor] = {}
    state_dict["tok_embed.weight"] = _tensor(model, "token_embed", "embeddings", dtype=dtype)
    state_dict["pos.weight"] = _tensor(model, "pos_embed", "embeddings", dtype=dtype)
    # The unembed is [d_model, d_vocab_out]; store its transpose.
    state_dict["head.weight"] = torch.tensor(
        make_tracr_categorical_unembed(model, output_label).T,
        dtype=dtype,
    )

    # (native submodule, Tracr submodule) pairs in emission order; each pair
    # contributes a ".weight" key followed by a ".bias" key.
    layer_modules = (
        ("attn.k", "attn/key"),
        ("attn.q", "attn/query"),
        ("attn.v", "attn/value"),
        ("attn.o", "attn/linear"),
        ("mlp.fc_in", "mlp/linear_1"),
        ("mlp.fc_out", "mlp/linear_2"),
    )
    for layer in range(num_layers):
        for native_name, tracr_name in layer_modules:
            tracr_module = f"transformer/layer_{layer}/{tracr_name}"
            native_prefix = f"layers.{layer}.{native_name}"
            # NOTE(review): the transpose assumes Tracr's "w" is stored
            # [in, out] and the native module expects [out, in] — confirm.
            state_dict[f"{native_prefix}.weight"] = _tensor(
                model, tracr_module, "w", dtype=dtype
            ).T
            state_dict[f"{native_prefix}.bias"] = _tensor(
                model, tracr_module, "b", dtype=dtype
            )

    return state_dict

169 

170 

def _categorical_output_encoding_map(model: Any) -> Mapping[Any, int]:
    """Return ``model.output_encoder.encoding_map`` if it is a ``Mapping``.

    A missing encoder, a missing ``encoding_map`` attribute, or a
    non-mapping value all mean the output is not categorical.

    Raises:
        NotImplementedError: If no mapping is found at that location.
    """
    encoder = getattr(model, "output_encoder", None)
    mapping = getattr(encoder, "encoding_map", None)
    if isinstance(mapping, Mapping):
        return mapping
    raise NotImplementedError(
        "Only categorical Tracr outputs are supported; expected "
        "model.output_encoder.encoding_map."
    )

180 

181 

def _residual_labels(model: Any) -> list[str]:
    """Return ``model.residual_labels`` copied into a new list.

    Raises:
        ValueError: If the attribute is absent or ``None``.
    """
    labels = getattr(model, "residual_labels", None)
    if labels is not None:
        return list(labels)
    raise ValueError("Expected Tracr model to expose residual_labels.")

187 

188 

def _param_array(model: Any, module_name: str, param_name: str) -> np.ndarray:
    """Fetch ``model.params[module_name][param_name]`` as a NumPy array."""
    module_params = model.params[module_name]
    return np.asarray(module_params[param_name])

191 

192 

def _tensor(
    model: Any,
    module_name: str,
    param_name: str,
    *,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Copy ``model.params[module_name][param_name]`` into a torch tensor of ``dtype``."""
    array = _param_array(model, module_name, param_name)
    return torch.tensor(array, dtype=dtype)