Coverage for transformer_lens/model_bridge/generalized_components/linear.py: 64%

47 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Linear bridge component for wrapping linear layers with hook points.""" 

2from typing import Any, Dict, Mapping 

3 

4import einops 

5import torch 

6 

7from transformer_lens.model_bridge.generalized_components.base import ( 

8 GeneralizedComponent, 

9) 

10 

11 

12class LinearBridge(GeneralizedComponent): 

13 """Bridge component for linear layers. 

14 

15 This component wraps a linear layer (nn.Linear) and provides hook points 

16 for intercepting the input and output activations. 

17 

18 Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead. 

19 """ 

20 

21 def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: 

22 """Forward pass through the linear layer with hooks. 

23 

24 Args: 

25 input: Input tensor 

26 *args: Additional positional arguments 

27 **kwargs: Additional keyword arguments 

28 

29 Returns: 

30 Output tensor after linear transformation 

31 """ 

32 if self.original_component is None: 32 ↛ 33line 32 didn't jump to line 33 because the condition on line 32 was never true

33 raise RuntimeError( 

34 f"Original component not set for {self.name}. Call set_original_component() first." 

35 ) 

36 input = self.hook_in(input) 

37 output = self.original_component(input, *args, **kwargs) 

38 output = self.hook_out(output) 

39 return output 

40 

41 def __repr__(self) -> str: 

42 """String representation of the LinearBridge.""" 

43 if self.original_component is not None: 43 ↛ 52line 43 didn't jump to line 52 because the condition on line 43 was always true

44 try: 

45 in_features = self.original_component.in_features 

46 out_features = self.original_component.out_features 

47 bias = self.original_component.bias is not None 

48 return f"LinearBridge({in_features} -> {out_features}, bias={bias}, original_component={type(self.original_component).__name__})" 

49 except AttributeError: 

50 return f"LinearBridge(name={self.name}, original_component={type(self.original_component).__name__})" 

51 else: 

52 return f"LinearBridge(name={self.name}, original_component=None)" 

53 

54 def set_processed_weights( 

55 self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False 

56 ) -> None: 

57 """Set the processed weights by loading them into the original component. 

58 

59 This loads the processed weights directly into the original_component's parameters, 

60 so when forward() delegates to original_component, it uses the processed weights. 

61 

62 Handles Linear layers (shape [out, in]). 

63 Also handles 3D weights [n_heads, d_model, d_head] by flattening them first. 

64 

65 Args: 

66 weights: Dictionary containing: 

67 - weight: The processed weight tensor. Can be: 

68 - 2D [in, out] format (will be transposed to [out, in] for Linear) 

69 - 3D [n_heads, d_model, d_head] format (will be flattened to 2D) 

70 - bias: The processed bias tensor (optional). Can be: 

71 - 1D [out] format 

72 - 2D [n_heads, d_head] format (will be flattened to 1D) 

73 verbose: If True, print detailed information about weight setting 

74 """ 

75 if verbose: 75 ↛ 76line 75 didn't jump to line 76 because the condition on line 75 was never true

76 print(f"\n set_processed_weights: LinearBridge (name={self.name})") 

77 print(f" Received {len(weights)} weight keys") 

78 

79 if self.original_component is None: 79 ↛ 80line 79 didn't jump to line 80 because the condition on line 79 was never true

80 raise RuntimeError(f"Original component not set for {self.name}") 

81 weight = weights.get("weight") 

82 if weight is None: 82 ↛ 83line 82 didn't jump to line 83 because the condition on line 82 was never true

83 raise ValueError("Processed weights for LinearBridge must include 'weight'.") 

84 bias = weights.get("bias") 

85 

86 if verbose: 86 ↛ 87line 86 didn't jump to line 87 because the condition on line 86 was never true

87 print(f" Found weight key with shape: {weight.shape}") 

88 if bias is not None: 

89 print(f" Found bias key with shape: {bias.shape}") 

90 

91 # Flatten 3D→2D; contiguous() needed for correct bfloat16 matmul order 

92 if weight.ndim == 3: 92 ↛ 93line 92 didn't jump to line 93 because the condition on line 92 was never true

93 n_heads, dim1, dim2 = weight.shape 

94 if dim1 > dim2: 

95 # [n_heads, d_model, d_head] -> [n_heads * d_head, d_model] (nn.Linear format) 

96 weight = einops.rearrange( 

97 weight, "n_heads d_model d_head -> (n_heads d_head) d_model" 

98 ).contiguous() 

99 else: 

100 # [n_heads, d_head, d_model] -> [d_model, n_heads * d_head] 

101 weight = einops.rearrange( 

102 weight, "n_heads d_head d_model -> d_model (n_heads d_head)" 

103 ).contiguous() 

104 

105 # Handle 2D bias by flattening to 1D 

106 if bias is not None and bias.ndim == 2: 106 ↛ 107line 106 didn't jump to line 107 because the condition on line 106 was never true

107 bias = einops.rearrange(bias, "n_heads d_head -> (n_heads d_head)") 

108 

109 processed_weights: Dict[str, torch.Tensor] = { 

110 "weight": weight, 

111 } 

112 

113 if bias is not None: 

114 processed_weights["bias"] = bias 

115 

116 super().set_processed_weights(processed_weights, verbose=verbose)
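
A minimal usage sketch of the class above, not part of the original file. It assumes GeneralizedComponent subclasses torch.nn.Module, accepts a name argument in its constructor, and provides set_original_component() taking the wrapped module (the error message in forward() implies that method exists); the hook points are assumed to act as identities when no hooks are registered:

import torch
import torch.nn as nn

from transformer_lens.model_bridge.generalized_components.linear import LinearBridge

# Hypothetical constructor call; the real signature lives in GeneralizedComponent.
bridge = LinearBridge(name="blocks.0.mlp.in")
bridge.set_original_component(nn.Linear(16, 32))

x = torch.randn(2, 16)
y = bridge(x)     # hook_in sees x, nn.Linear runs, hook_out sees y
print(y.shape)    # torch.Size([2, 32])
print(bridge)     # LinearBridge(16 -> 32, bias=True, original_component=Linear)

And a shape check for the 3D weight flattening performed in set_processed_weights(), using the same einops patterns as the method body (the dimension sizes here are made up for illustration):

import einops
import torch

n_heads, d_model, d_head = 4, 32, 16
w = torch.randn(n_heads, d_model, d_head)  # dim1 (32) > dim2 (16) -> first branch
flat = einops.rearrange(w, "n_heads d_model d_head -> (n_heads d_head) d_model")
print(flat.shape)  # torch.Size([64, 32]), i.e. [n_heads * d_head, d_model]

b = torch.randn(n_heads, d_head)  # a 2D bias is flattened the same way
print(einops.rearrange(b, "n_heads d_head -> (n_heads d_head)").shape)  # torch.Size([64])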