Coverage for transformer_lens/model_bridge/generalized_components/bloom_mlp.py: 38%

18 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""BLOOM-specific MLP bridge component. 

2 

3BLOOM MLP requires a special 'residual' argument that standard MLPBridge doesn't handle. 

4This custom component passes the residual argument through to the original component. 

5""" 

6from typing import Any, Dict, Optional 

7 

8import torch 

9 

10from transformer_lens.model_bridge.generalized_components.base import ( 

11 GeneralizedComponent, 

12) 

13from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge 

14 

15 

class BloomMLPBridge(MLPBridge):
    """MLP bridge for BLOOM models that handles residual connections.

    BLOOM's MLP forward signature differs from the standard one handled by
    ``MLPBridge``:

    - ``hidden_states`` arrives as the first positional argument
    - ``residual`` (the residual-connection tensor) arrives as an extra argument

    This bridge forwards every argument unchanged to the wrapped component so
    the residual tensor reaches the original BLOOM MLP.
    """

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the BLOOM MLP bridge.

        Args:
            name: The name of the component in the model
            config: Optional configuration
            submodules: Dictionary of submodules to register (e.g., dense_h_to_4h, dense_4h_to_h)
        """
        # Base class expects a dict; substitute an empty one when omitted.
        super().__init__(name, config, submodules if submodules is not None else {})

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through BLOOM MLP with hooks.

        BLOOM MLP requires these arguments:

        - hidden_states (first positional arg)
        - residual (second positional arg)

        Args:
            *args: Input arguments (hidden_states, residual)
            **kwargs: Additional keyword arguments (if any)

        Returns:
            Output tensor from BLOOM MLP

        Raises:
            RuntimeError: If the original component has not been attached yet.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        # Run hook_in over hidden_states, wherever the caller put it:
        # either as the leading positional argument or as a keyword.
        if args and isinstance(args[0], torch.Tensor):
            args = (self.hook_in(args[0]), *args[1:])
        else:
            candidate = kwargs.get("hidden_states")
            if isinstance(candidate, torch.Tensor):
                kwargs["hidden_states"] = self.hook_in(candidate)

        # The residual tensor (and anything else the BLOOM block supplies) is
        # passed straight through; the wrapped component validates its own args.
        raw_output = self.original_component(*args, **kwargs)

        # Wrap the result with hook_out before returning it.
        return self.hook_out(raw_output)