transformer_lens.weight_processing module

Weight Processing Functions for Transformer Models.

This module contains all the weight processing functions extracted from HookedTransformer, organized into a single ProcessWeights class with static methods. These functions are used to modify transformer model weights for better interpretability and analysis.

class transformer_lens.weight_processing.ProcessWeights

Bases: object

A collection of static methods for processing transformer model weights.

These methods are extracted from HookedTransformer and provide various weight transformations for improved model interpretability:

  • LayerNorm folding: merges LayerNorm parameters into subsequent linear layers

  • Weight centering: centers weights that write to the residual stream

  • Unembed centering: centers unembedding weights (softmax is translation invariant)

  • Value bias folding: consolidates value biases into output biases

  • Attention matrix refactoring: experimental QK/OV matrix factorization

When an architecture adapter is provided, the methods will translate TransformerLens parameter names to the target format (e.g., HuggingFace) for processing.

static center_attention_weights(wq_tensor: Tensor, wk_tensor: Tensor, wv_tensor: Tensor) tuple[Tensor, Tensor, Tensor]

Center attention weights by subtracting the mean.

Parameters:
  • wq_tensor – Query weight tensor [n_heads, d_model, d_head]

  • wk_tensor – Key weight tensor [n_heads, d_model, d_head]

  • wv_tensor – Value weight tensor [n_heads, d_model, d_head]

Returns:

Tuple of (centered_wq, centered_wk, centered_wv)
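A minimal sketch of what this centering looks like, assuming (as in LayerNorm folding) that the mean is taken over the d_model axis, so that a constant direction in the residual stream contributes nothing to the projection. This is an illustrative re-implementation, not the library function itself:

```python
import torch

def center_attention_weights(wq, wk, wv):
    # Subtract each head's per-column mean over the d_model axis (dim=1).
    center = lambda w: w - w.mean(dim=1, keepdim=True)
    return center(wq), center(wk), center(wv)

n_heads, d_model, d_head = 2, 8, 4
wq, wk, wv = (torch.randn(n_heads, d_model, d_head) for _ in range(3))
cq, ck, cv = center_attention_weights(wq, wk, wv)

# A constant residual-stream vector now maps to (numerically) zero:
const_input = torch.ones(d_model)
proj = torch.einsum("m,hmo->ho", const_input, cq)
assert torch.allclose(proj, torch.zeros_like(proj), atol=1e-5)
```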

static center_unembed(state_dict: Dict[str, Tensor], cfg=None, adapter=None) Dict[str, Tensor]

Center the unembedding weights W_U.

This is done by subtracting the mean of the weights from the weights themselves, in place. Since softmax is translation invariant, this changes the logits but not the log probs, and makes the model logits (slightly) more interpretable: when trying to understand how components contribute to the logits, we’ll be less misled by components that just add something to every logit.

Parameters:
  • state_dict (Dict[str, torch.Tensor]) – State dict of the model.

  • cfg – Model configuration (used to determine d_vocab for correct centering dimension).

  • adapter – Optional architecture adapter for parameter key translation.

Returns:

Modified state dict with centered unembedding weights.

Return type:

Dict[str, torch.Tensor]
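The translation invariance can be checked directly. The sketch below (a standalone illustration, not the library call) centers a toy W_U of shape [d_model, d_vocab] over the vocabulary axis and verifies that log-probabilities are unchanged:

```python
import torch

d_model, d_vocab = 16, 10
W_U = torch.randn(d_model, d_vocab)
# Center over the d_vocab axis: each residual direction's mean logit becomes 0.
W_U_centered = W_U - W_U.mean(dim=-1, keepdim=True)

resid = torch.randn(3, d_model)  # a small batch of residual-stream vectors
logprobs_before = torch.log_softmax(resid @ W_U, dim=-1)
logprobs_after = torch.log_softmax(resid @ W_U_centered, dim=-1)

# Softmax is invariant to adding a constant to every logit:
assert torch.allclose(logprobs_before, logprobs_after, atol=1e-4)
```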

static center_weight_single(w_tensor: Tensor) Tensor

Center a single attention weight by subtracting the mean.

Parameters:

w_tensor – Weight tensor [n_heads, d_model, d_head]

Returns:

Centered weight tensor

static center_writing_weights(state_dict: Dict[str, Tensor], cfg, adapter=None) Dict[str, Tensor]

Center Writing Weights.

Centers the weights of the model that write to the residual stream - W_O, W_E, W_pos and W_out. This is done by subtracting the mean of the weights from the weights themselves. This is done in-place. See fold_layer_norm for more details.

Parameters:
  • state_dict (Dict[str, torch.Tensor]) – State dict of the model.

  • cfg – Model configuration object.

  • adapter – Optional architecture adapter for parameter key translation.

Returns:

Modified state dict with centered writing weights.

Return type:

Dict[str, torch.Tensor]
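A minimal sketch of the centering applied to one writing weight, under the assumption that the mean is removed along the d_model axis (the next LayerNorm centers its input anyway, so the mean component each weight writes to the residual stream is redundant):

```python
import torch

d_vocab, d_model = 10, 16
W_E = torch.randn(d_vocab, d_model)  # embedding writes rows to the residual stream

# Remove each row's mean across d_model:
W_E_centered = W_E - W_E.mean(dim=-1, keepdim=True)

# Every token embedding now has zero mean in the residual stream:
assert torch.allclose(W_E_centered.mean(dim=-1), torch.zeros(d_vocab), atol=1e-5)
```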

static convert_tensor_to_hf_format(param_name: str, tensor: Tensor, cfg: TransformerLensConfig | None, adapter: ArchitectureAdapter | None = None, layer_idx: int | None = None) Tensor
static convert_tensor_to_hf_format(param_name: str, tensor: None, cfg: TransformerLensConfig | None, adapter: ArchitectureAdapter | None = None, layer_idx: int | None = None) None

Convert a tensor from TransformerLens format back to its original format.

Parameters:
  • param_name – The parameter name in TransformerLens format (e.g., “blocks.0.attn.W_Q”)

  • tensor – The tensor to convert (in TransformerLens format), or None if parameter is optional

  • cfg – Model configuration

  • adapter – Optional architecture adapter for component retrieval and key translation. If None, the tensor is returned unchanged.

  • layer_idx – Layer index (required for layer-specific parameters)

Returns:

The tensor converted back to original format, or None if tensor was None. If adapter is None, returns the tensor unchanged.

static convert_tensor_to_tl_format(param_name: str, model_state_dict: Dict[str, Tensor], tensor: Tensor, cfg: TransformerLensConfig | None, adapter: ArchitectureAdapter | None = None, layer_idx: int | None = None) Tensor
static convert_tensor_to_tl_format(param_name: str, model_state_dict: Dict[str, Tensor], tensor: None, cfg: TransformerLensConfig | None, adapter: ArchitectureAdapter | None = None, layer_idx: int | None = None) None

Convert a tensor from its original format to TransformerLens format.

Parameters:
  • param_name – The parameter name in TransformerLens format (e.g., “blocks.0.attn.W_Q”)

  • model_state_dict – The model’s state dictionary containing the actual tensors

  • tensor – The tensor to convert, or None for optional parameters

  • cfg – Model configuration

  • adapter – Optional architecture adapter for component retrieval and key translation. If None, the tensor is returned unchanged.

  • layer_idx – Layer index (required for layer-specific parameters)

Returns:

The tensor converted to TransformerLens format, or None if the parameter doesn’t exist (which is valid for optional parameters like biases in models that don’t use them). If adapter is None, returns the tensor unchanged.

static distribute_weights_to_components(state_dict: Dict[str, Tensor], component_mapping: Dict[str, Any], verbose: bool = False) None

Distribute processed weights from state_dict to generalized components.

This function loops through the component_mapping and extracts relevant weights for each component using filter_dict_by_prefix, then calls set_processed_weights on each component. For list components (like blocks), it determines the number of items and distributes weights to each indexed component.

Parameters:
  • state_dict – Dictionary of processed weights in MODERN TransformerLens format (e.g., blocks.0.attn.q.weight, not transformer.h.0.attn.q.weight)

  • component_mapping – Dictionary (real_components) mapping TL keys to tuples of (remote_path, component_instance), where component_instance can be either a single component or a list of components

  • verbose – If True, print detailed information about weight distribution

Example

For a real_components mapping like:

{
  “embed”: (“transformer.wte”, <EmbeddingBridge instance>),
  “blocks”: (“transformer.h”, [<BlockBridge 0>, <BlockBridge 1>, …]),
  “unembed”: (“lm_head”, <UnembeddingBridge instance>)
}

With modern TL keys in state_dict like “embed.weight”, “blocks.0.attn.q.weight”:

  1. Extract weights starting with “embed” and pass them to the embed component.

  2. For blocks, extract all “blocks.*” weights, determine the number of blocks, then for each block index extract the weights for that specific block.

  3. Extract “unembed” weights and pass them to the unembed component.
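The prefix-extraction step can be sketched as follows. Note that filter_dict_by_prefix is shown here as a hypothetical helper; the real function may differ in signature and behaviour:

```python
# Hypothetical sketch of the prefix filtering described above: keep entries
# under `prefix` and strip the prefix (plus the separating dot) from the key.
def filter_dict_by_prefix(state_dict, prefix):
    plen = len(prefix) + 1
    return {k[plen:]: v for k, v in state_dict.items() if k.startswith(prefix + ".")}

state_dict = {
    "embed.weight": "E",
    "blocks.0.attn.q.weight": "Q0",
    "blocks.1.attn.q.weight": "Q1",
    "unembed.weight": "U",
}
assert filter_dict_by_prefix(state_dict, "embed") == {"weight": "E"}
assert filter_dict_by_prefix(state_dict, "blocks.0") == {"attn.q.weight": "Q0"}
```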

static extract_attention_tensors_for_folding(state_dict: Dict[str, Tensor], cfg, layer: int, adapter) Dict[str, Tensor | None | Dict[str, str]]

Extract attention tensors in TransformerLens format for layer norm folding.

Parameters:
  • state_dict – The state dictionary containing tensors

  • cfg – Model configuration object

  • layer – Layer index

  • adapter – Optional architecture adapter for parameter key translation

Returns:

Dictionary with keys ‘wq’, ‘wk’, ‘wv’, ‘bq’, ‘bk’, ‘bv’, ‘ln1_b’, ‘ln1_w’. All tensors are in TransformerLens format for consistent processing.

Return type:

Dict[str, Tensor | None | Dict[str, str]]

static fold_layer_norm(state_dict: Dict[str, Tensor], cfg, fold_biases: bool = True, center_weights: bool = True, adapter=None) Dict[str, Tensor]

Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False.

Takes in a state dict from a pretrained model, formatted to be consistent with HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring weights. See further_comments.md for more details.

Parameters:
  • state_dict (Dict[str, torch.Tensor]) – State dict of pretrained model.

  • cfg – Model configuration object with n_layers, n_key_value_heads, etc.

  • fold_biases (bool) – Enables folding of LN biases. Should be disabled when RMS Norm is used.

  • center_weights (bool) – Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used.

  • adapter – Optional architecture adapter for parameter key translation.

Returns:

Modified state dict with LayerNorm folded into linear layers.

Return type:

Dict[str, torch.Tensor]
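The core identity behind the fold can be demonstrated in isolation. The sketch below (toy tensors, not the library internals) shows that a LayerNorm followed by a linear map equals a bare normalization followed by a linear map whose weights absorb the LN scale and whose bias absorbs the LN bias:

```python
import torch

torch.manual_seed(0)
d_model, d_out = 8, 4
x = torch.randn(5, d_model)
ln_w, ln_b = torch.randn(d_model), torch.randn(d_model)
W, b = torch.randn(d_model, d_out), torch.randn(d_out)

def normalize(x, eps=1e-5):
    # Center and scale only -- the learnable LN parameters are handled outside.
    x = x - x.mean(-1, keepdim=True)
    return x / (x.var(-1, keepdim=True, unbiased=False) + eps).sqrt()

# Standard LayerNorm followed by the linear layer:
out_unfolded = (normalize(x) * ln_w + ln_b) @ W + b

# Fold: LN scale into the weights, LN bias into the bias.
W_folded = ln_w[:, None] * W
b_folded = b + ln_b @ W
out_folded = normalize(x) @ W_folded + b_folded

assert torch.allclose(out_unfolded, out_folded, atol=1e-5)
```

For RMSNorm the same fold applies with fold_biases and center_weights disabled, since there is no centering step and no LN bias.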

static fold_layer_norm_bias_single(w_tensor: Tensor, b_tensor: Tensor, ln_bias: Tensor) Tensor

Fold LayerNorm bias into a single attention bias.

Parameters:
  • w_tensor – Weight tensor [n_heads, d_model, d_head]

  • b_tensor – Bias tensor [n_heads, d_head]

  • ln_bias – LayerNorm bias [d_model]

Returns:

New bias tensor with folded LayerNorm bias

static fold_layer_norm_biases(wq_tensor: Tensor, wk_tensor: Tensor, wv_tensor: Tensor, bq_tensor: Tensor | None, bk_tensor: Tensor | None, bv_tensor: Tensor | None, ln_bias: Tensor) tuple[Tensor, Tensor, Tensor]

Fold LayerNorm bias into attention biases.

When QKV biases don’t exist (e.g., GPT-Neo), creates zero-initialized biases to absorb the LN bias contribution, similar to how MLP folding handles missing biases.

Parameters:
  • wq_tensor – Query weight tensor [n_heads, d_model, d_head]

  • wk_tensor – Key weight tensor [n_heads, d_model, d_head]

  • wv_tensor – Value weight tensor [n_heads, d_model, d_head]

  • bq_tensor – Query bias tensor [n_heads, d_head], or None if the model has no query bias

  • bk_tensor – Key bias tensor [n_heads, d_head], or None if the model has no key bias

  • bv_tensor – Value bias tensor [n_heads, d_head], or None if the model has no value bias

  • ln_bias – LayerNorm bias [d_model]

Returns:

Tuple of (new_bq, new_bk, new_bv) with folded biases (always non-None)
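A sketch of the per-head bias fold (illustrative, assuming the fold is b_new = b + ln_b @ W for each head), including the zero-initialization path for models without QKV biases:

```python
import torch

n_heads, d_model, d_head = 2, 8, 4
w = torch.randn(n_heads, d_model, d_head)
ln_bias = torch.randn(d_model)
b = None  # e.g. GPT-Neo: no QKV biases in the original checkpoint

if b is None:
    # Create zero biases so the LN bias contribution has somewhere to go.
    b = torch.zeros(n_heads, d_head)

# ln_b @ W_head for every head at once:
b_new = b + torch.einsum("m,hmo->ho", ln_bias, w)
assert b_new.shape == (n_heads, d_head)
```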

static fold_layer_norm_weight_single(w_tensor: Tensor, ln_weight: Tensor) Tensor

Fold LayerNorm weight into a single attention weight.

Parameters:
  • w_tensor – Weight tensor [n_heads, d_model, d_head]

  • ln_weight – LayerNorm weight [d_model]

Returns:

New weight tensor with folded LayerNorm weight

static fold_layer_norm_weights(wq_tensor: Tensor, wk_tensor: Tensor, wv_tensor: Tensor, ln_weight: Tensor) tuple[Tensor, Tensor, Tensor]

Fold LayerNorm weight into attention weights.

Parameters:
  • wq_tensor – Query weight tensor [n_heads, d_model, d_head]

  • wk_tensor – Key weight tensor [n_heads, d_model, d_head]

  • wv_tensor – Value weight tensor [n_heads, d_model, d_head]

  • ln_weight – LayerNorm weight [d_model]

Returns:

Tuple of (new_wq, new_wk, new_wv) with folded weights
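A minimal per-head sketch of the scale fold (toy tensors, not the library code): multiplying the d_model axis of each attention weight by the LN weight is equivalent to scaling the input elementwise before the projection:

```python
import torch

n_heads, d_model, d_head = 2, 8, 4
w = torch.randn(n_heads, d_model, d_head)
ln_weight = torch.randn(d_model)
x = torch.randn(d_model)

# Broadcast the LN weight over the heads and d_head axes:
w_folded = w * ln_weight[None, :, None]

lhs = torch.einsum("m,hmo->ho", x * ln_weight, w)  # scale the input
rhs = torch.einsum("m,hmo->ho", x, w_folded)       # or fold into the weights
assert torch.allclose(lhs, rhs, atol=1e-5)
```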

static fold_value_biases(state_dict: Dict[str, Tensor], cfg, adapter=None) Dict[str, Tensor]

Fold the value biases into the output bias.

Because attention patterns sum to 1, the value biases have a constant effect on a head’s output. Further, as the outputs of all the heads in a layer add together, each head’s value bias has a constant effect on the layer’s output. This can make it harder to interpret the effect of any given head, and it doesn’t matter which head a bias is associated with. We can therefore fold it all into a single output bias for the layer, making each head’s output easier to interpret. Formally, we take b_O_new = b_O_original + sum_head(b_V_head @ W_O_head).

Parameters:
  • state_dict (Dict[str, torch.Tensor]) – State dict of the model.

  • cfg – Model configuration object.

  • adapter – Optional architecture adapter for parameter key translation.

Returns:

Modified state dict with value biases folded into output bias.

Return type:

Dict[str, torch.Tensor]
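The exactness of this fold follows from attention pattern rows summing to 1. The sketch below builds a toy multi-head attention output from scratch (illustrative tensors, not the library code) and checks that zeroing b_V while setting b_O_new = b_O + sum_head(b_V_head @ W_O_head) leaves the output unchanged:

```python
import torch

torch.manual_seed(0)
n_heads, seq, d_model, d_head = 2, 5, 8, 4
x = torch.randn(seq, d_model)
W_V = torch.randn(n_heads, d_model, d_head)
b_V = torch.randn(n_heads, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
b_O = torch.randn(d_model)
# Attention pattern: each query row is a distribution over keys (sums to 1).
pattern = torch.softmax(torch.randn(n_heads, seq, seq), dim=-1)

def attn_out(b_V, b_O):
    v = torch.einsum("sm,hmo->hso", x, W_V) + b_V[:, None, :]
    z = torch.einsum("hqk,hko->hqo", pattern, v)
    return torch.einsum("hso,hom->sm", z, W_O) + b_O

# Fold the value biases into the output bias:
b_O_new = b_O + torch.einsum("ho,hom->m", b_V, W_O)
folded = attn_out(torch.zeros_like(b_V), b_O_new)
assert torch.allclose(attn_out(b_V, b_O), folded, atol=1e-4)
```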

static process_weights(state_dict: Dict[str, Tensor], cfg, fold_ln: bool = True, center_writing_weights: bool = True, center_unembed: bool = True, fold_value_biases: bool = True, refactor_factored_attn_matrices: bool = False, adapter=None) Dict[str, Tensor]

Apply all weight processing transformations in the correct order.

This is a convenience function that applies all the weight processing steps in the same order as HookedTransformer.load_and_process_state_dict().

Parameters:
  • state_dict (Dict[str, torch.Tensor]) – State dict of the model.

  • cfg – Model configuration object.

  • fold_ln (bool) – Whether to fold LayerNorm weights into subsequent layers.

  • center_writing_weights (bool) – Whether to center weights writing to residual stream.

  • center_unembed (bool) – Whether to center unembedding weights.

  • fold_value_biases (bool) – Whether to fold value biases into output bias.

  • refactor_factored_attn_matrices (bool) – Whether to refactor attention matrices.

  • adapter – Optional architecture adapter for parameter key translation.

Returns:

Fully processed state dict.

Return type:

Dict[str, torch.Tensor]

static refactor_factored_attn_matrices(state_dict: Dict[str, Tensor], cfg, adapter=None) Dict[str, Tensor]

Experimental method for managing queries, keys and values.

As argued in [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and values are somewhat arbitrary intermediate terms when computing with the low rank factored matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing determining head behaviour. But there are many ways to find a low rank factorization of a given matrix, and hopefully some of these are more interpretable than others! This method is one attempt: it makes all of the matrices have orthogonal rows or columns, turns W_O into a rotation, and gives the nth columns of W_Q and W_K the same norm. The formula is $W_V = U @ S$, $W_O = Vh.T$, $W_Q = U @ S.sqrt()$, $W_K = Vh @ S.sqrt()$.

More details:

If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and $W_O=Vh.T$ works just as well. I think this is a more interpretable setup, because now $W_O$ is just a rotation, and doesn’t change the norm, so $z$ has the same norm as the result of the head.

For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$, which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the matrices as having the same norm, since there’s not an obvious asymmetry between the keys and queries.

Biases are more fiddly to deal with. For OV it’s pretty easy - we just need (x @ W_V + b_V) @ W_O + b_O to be preserved, so we can set b_V’ = 0 and b_O’ = b_V @ W_O + b_O (note that b_V is in R^{head_index x d_head} while b_O is in R^{d_model}, so we need to sum b_V @ W_O along the head_index dimension too).

For QK it’s messier - we need to preserve the bilinear form (x @ W_Q + b_Q) * (y @ W_K + b_K). To deal with the biases, we concatenate them to W_Q and W_K to simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD factorization on this effective matrix, then separate the result back out into final weights and biases.
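The bias-concatenation trick can be verified in isolation. The sketch below (toy single-head tensors) shows that appending b_Q as an extra row of W_Q, paired with appending a constant 1 to the input, reproduces the affine map exactly:

```python
import torch

d_model, d_head = 8, 4
W_Q = torch.randn(d_model, d_head)
b_Q = torch.randn(d_head)
x = torch.randn(3, d_model)

# Effective (d_model+1)-dimensional weight and input:
W_Q_eff = torch.cat([W_Q, b_Q[None, :]], dim=0)  # [d_model+1, d_head]
x_eff = torch.cat([x, torch.ones(3, 1)], dim=1)  # final coordinate is always 1

assert torch.allclose(x @ W_Q + b_Q, x_eff @ W_Q_eff, atol=1e-5)
```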

Parameters:
  • state_dict (Dict[str, torch.Tensor]) – State dict of the model.

  • cfg – Model configuration object.

  • adapter – Optional architecture adapter for parameter key translation.

Returns:

Modified state dict with refactored attention matrices.

Return type:

Dict[str, torch.Tensor]
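The OV refactor can be sketched with toy tensors (using torch.linalg.svd conventions, where W_OV = U @ diag(S) @ Vh, so the rotation factor is Vh itself). This is an illustration of the technique, not the library implementation:

```python
import torch

torch.manual_seed(0)
d_model, d_head = 8, 4
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)
W_OV = W_V @ W_O  # [d_model, d_model], but rank at most d_head

U, S, Vh = torch.linalg.svd(W_OV)
# Keep only the d_head nonzero singular values (W_OV is low rank):
U, S, Vh = U[:, :d_head], S[:d_head], Vh[:d_head]

W_V_new = U * S   # equals U @ diag(S): columns scaled by singular values
W_O_new = Vh      # rows are orthonormal, so this factor is a partial rotation

# The factorization is equivalent:
assert torch.allclose(W_V_new @ W_O_new, W_OV, atol=1e-4)
# And W_O_new does not change norms:
assert torch.allclose(W_O_new @ W_O_new.T, torch.eye(d_head), atol=1e-4)
```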