Coverage for transformer_lens/model_bridge/generalized_components/mlp.py: 86%

23 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""MLP bridge component. 

2 

3This module contains the bridge component for MLP layers. 

4""" 

5from typing import Any, Dict, Optional 

6 

7import torch 

8 

9from transformer_lens.model_bridge.generalized_components.base import ( 

10 GeneralizedComponent, 

11) 

12 

13 

class MLPBridge(GeneralizedComponent):
    """Bridge component for MLP layers.

    This component wraps an MLP layer from a remote model and provides a consistent interface
    for accessing its weights and performing MLP operations.
    """

    # TransformerLens-style hook names resolved against the wrapped submodules'
    # hook points (e.g. ``hook_pre`` is the output hook of the input projection).
    hook_aliases = {"hook_pre": "in.hook_out", "hook_post": "out.hook_in"}
    # TransformerLens-style weight/bias names resolved against submodule attributes.
    property_aliases = {
        "W_gate": "gate.weight",
        "b_gate": "gate.bias",
        "W_in": "in.weight",
        "b_in": "in.bias",
        "W_out": "out.weight",
        "b_out": "out.bias",
    }

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        optional: bool = False,
    ):
        """Initialize the MLP bridge.

        Args:
            name: The name of the component in the model (None if no container exists)
            config: Optional configuration (unused for MLPBridge)
            submodules: Dictionary of submodules to register (e.g., gate_proj, up_proj,
                down_proj). Defaults to an empty dict when omitted.
            optional: If True, setup skips this bridge when absent (hybrid architectures).
        """
        # The previous signature used a mutable default (``submodules={}``), which is
        # created once and shared by every instantiation — any mutation of it would
        # leak across bridges. Default to None and normalize here instead.
        super().__init__(
            name,
            config,
            submodules={} if submodules is None else submodules,
            optional=optional,
        )

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward pass through the MLP bridge.

        Args:
            *args: Positional arguments for the original component; ``args[0]`` must
                be the input hidden-states tensor.
            **kwargs: Keyword arguments for the original component

        Returns:
            Output hidden states

        Raises:
            RuntimeError: If the original component has not been set via
                ``set_original_component()``.
        """
        hidden_states = args[0]
        hidden_states = self.hook_in(hidden_states)
        # "in" is a Python keyword, so the input-projection submodule can only be
        # reached via getattr; some architectures register it as "input" instead.
        in_module = getattr(self, "in", None) or getattr(self, "input", None)
        if in_module is not None and hasattr(in_module, "hook_in"):
            hidden_states = in_module.hook_in(hidden_states)  # type: ignore[misc]
        new_args = (hidden_states,) + args[1:]
        original_component = self.original_component
        if original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        output = original_component(*new_args, **kwargs)
        output = self.hook_out(output)
        # Also fire the output-projection submodule's hook_out so hooks attached
        # through it (and the aliases above) observe the final output.
        if hasattr(self, "out") and hasattr(self.out, "hook_out"):
            output = self.out.hook_out(output)
        return output