Coverage for transformer_lens/model_bridge/generalized_components/linear.py: 64% (47 statements)
1"""Linear bridge component for wrapping linear layers with hook points."""
2from typing import Any, Dict, Mapping
4import einops
5import torch
7from transformer_lens.model_bridge.generalized_components.base import (
8 GeneralizedComponent,
9)


class LinearBridge(GeneralizedComponent):
    """Bridge component for linear layers.

    This component wraps a linear layer (nn.Linear) and provides hook points
    for intercepting the input and output activations.

    Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead.
    """

    def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the linear layer with hooks.

        Args:
            input: Input tensor
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Output tensor after linear transformation
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        input = self.hook_in(input)
        output = self.original_component(input, *args, **kwargs)
        output = self.hook_out(output)
        return output

    def __repr__(self) -> str:
        """String representation of the LinearBridge."""
        if self.original_component is not None:
            try:
                in_features = self.original_component.in_features
                out_features = self.original_component.out_features
                bias = self.original_component.bias is not None
                return f"LinearBridge({in_features} -> {out_features}, bias={bias}, original_component={type(self.original_component).__name__})"
            except AttributeError:
                return f"LinearBridge(name={self.name}, original_component={type(self.original_component).__name__})"
        else:
            return f"LinearBridge(name={self.name}, original_component=None)"

    def set_processed_weights(
        self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False
    ) -> None:
        """Set the processed weights by loading them into the original component.

        This loads the processed weights directly into the original_component's parameters,
        so when forward() delegates to original_component, it uses the processed weights.

        Handles Linear layers (weight shape [out, in]).
        Also handles 3D weights [n_heads, d_model, d_head] by flattening them first.

        Args:
            weights: Dictionary containing:
                - weight: The processed weight tensor. Can be:
                    - 2D [in, out] format (will be transposed to [out, in] for Linear)
                    - 3D [n_heads, d_model, d_head] format (will be flattened to 2D)
                - bias: The processed bias tensor (optional). Can be:
                    - 1D [out] format
                    - 2D [n_heads, d_head] format (will be flattened to 1D)
            verbose: If True, print detailed information about weight setting
        """
        if verbose:
            print(f"\n set_processed_weights: LinearBridge (name={self.name})")
            print(f"  Received {len(weights)} weight keys")

        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        weight = weights.get("weight")
        if weight is None:
            raise ValueError("Processed weights for LinearBridge must include 'weight'.")
        bias = weights.get("bias")

        if verbose:
            print(f"  Found weight key with shape: {weight.shape}")
            if bias is not None:
                print(f"  Found bias key with shape: {bias.shape}")

        # Flatten 3D→2D; contiguous() needed for correct bfloat16 matmul order
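        # Worked example (hypothetical GPT-2-small shapes: n_heads=12,
        # d_model=768, d_head=64): a query weight of shape [12, 768, 64] has
        # dim1 > dim2 and is rearranged to [12 * 64, 768] = [768, 768], the
        # [out, in] layout nn.Linear expects; an output weight of shape
        # [12, 64, 768] takes the else branch and becomes [768, 768] as well.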
        if weight.ndim == 3:
            n_heads, dim1, dim2 = weight.shape
            if dim1 > dim2:
                # [n_heads, d_model, d_head] -> [n_heads * d_head, d_model] (nn.Linear format)
                weight = einops.rearrange(
                    weight, "n_heads d_model d_head -> (n_heads d_head) d_model"
                ).contiguous()
            else:
                # [n_heads, d_head, d_model] -> [d_model, n_heads * d_head]
                weight = einops.rearrange(
                    weight, "n_heads d_head d_model -> d_model (n_heads d_head)"
                ).contiguous()

        # Handle 2D bias by flattening to 1D
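        # e.g. a per-head bias of shape [12, 64] (hypothetical GPT-2-small
        # numbers) flattens to the 1D [768] layout nn.Linear expects.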
        if bias is not None and bias.ndim == 2:
            bias = einops.rearrange(bias, "n_heads d_head -> (n_heads d_head)")

        processed_weights: Dict[str, torch.Tensor] = {
            "weight": weight,
        }

        if bias is not None:
            processed_weights["bias"] = bias

        super().set_processed_weights(processed_weights, verbose=verbose)
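

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: it assumes
    # GeneralizedComponent is an nn.Module subclass whose constructor accepts
    # a ``name`` argument, and that it provides the hook_in/hook_out hook
    # points and set_original_component() referenced in the docstrings above.
    # add_hook() is the standard TransformerLens HookPoint method.
    bridge = LinearBridge(name="example_linear")
    bridge.set_original_component(torch.nn.Linear(4, 8))

    def report_shape(tensor: torch.Tensor, hook: Any) -> torch.Tensor:
        # Runs on hook_out; returning the tensor leaves the activation unchanged.
        print(f"hook_out activation shape: {tuple(tensor.shape)}")
        return tensor

    bridge.hook_out.add_hook(report_shape)
    bridge(torch.randn(2, 4))  # expected to report shape (2, 8)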