Coverage for transformer_lens/model_bridge/generalized_components/embedding.py: 65%

44 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Embedding bridge component. 

2 

3This module contains the bridge component for embedding layers. 

4""" 

5import inspect 

6from typing import Any, Dict, Optional 

7 

8import torch 

9 

10from transformer_lens.model_bridge.generalized_components.base import ( 

11 GeneralizedComponent, 

12) 

13 

14 

class EmbeddingBridge(GeneralizedComponent):
    """Embedding bridge that wraps transformer embedding layers.

    This component provides standardized input/output hooks around the
    wrapped embedding module and exposes its weights via ``W_E``.
    """

    # Aliases mapping TransformerLens-style weight names onto the wrapped
    # module's parameter paths.
    property_aliases = {"W_E": "e.weight", "W_pos": "pos.weight"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the embedding bridge.

        Args:
            name: The name of this component
            config: Optional configuration (unused for EmbeddingBridge)
            submodules: Dictionary of GeneralizedComponent submodules to
                register. Defaults to an empty dict.
        """
        # FIX: the previous signature used a mutable default argument ({}),
        # which is shared across all calls. Use None and substitute a fresh
        # dict per call instead; callers passing nothing see no change.
        super().__init__(name, config, submodules={} if submodules is None else submodules)

    @property
    def W_E(self) -> torch.Tensor:
        """Return the embedding weight matrix.

        Resolution order: a processed weight (when weight processing is
        enabled and one has been stored), then the wrapped module's
        ``weight``; modules that carry only ``inv_freq`` (rotary-style
        embeddings) return that tensor instead.

        Raises:
            RuntimeError: If no original component has been attached.
        """
        if getattr(self, "_use_processed_weights", False) and hasattr(self, "_processed_weight"):
            return self._processed_weight
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        # Rotary-embedding modules expose frequencies rather than a weight
        # matrix; only take this path when no weight attribute exists.
        if hasattr(self.original_component, "inv_freq") and not hasattr(
            self.original_component, "weight"
        ):
            inv_freq = self.original_component.inv_freq
            assert isinstance(inv_freq, torch.Tensor), f"inv_freq is not a tensor for {self.name}"
            return inv_freq
        assert hasattr(
            self.original_component, "weight"
        ), f"Component {self.name} has neither weight nor inv_freq attribute"
        weight = self.original_component.weight
        assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
        return weight

    def forward(
        self, input_ids: torch.Tensor, position_ids: torch.Tensor | None = None, **kwargs: Any
    ) -> torch.Tensor:
        """Forward pass through the embedding bridge.

        Args:
            input_ids: Input token IDs
            position_ids: Optional position IDs (ignored if the wrapped
                module's forward does not accept them)
            **kwargs: Additional arguments to pass to the original component

        Returns:
            Embedded output, cast back to the wrapped module's parameter
            dtype when they differ.

        Raises:
            RuntimeError: If no original component has been attached.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # Capture the module's parameter dtype so the output can be cast back
        # to it in case the wrapped forward promotes dtypes internally.
        target_dtype = None
        try:
            target_dtype = next(self.original_component.parameters()).dtype
        except StopIteration:
            pass  # parameterless module; nothing to enforce
        input_ids = self.hook_in(input_ids)
        # FIX: the original code called inspect.signature(...forward)
        # unconditionally and only afterwards checked hasattr(..., "forward"),
        # so the guard could never fire — a missing forward would already have
        # raised AttributeError. Guard the attribute lookup first.
        forward_fn = getattr(self.original_component, "forward", None)
        supports_position_ids = (
            forward_fn is not None
            and "position_ids" in inspect.signature(forward_fn).parameters
        )
        if supports_position_ids:
            output = self.original_component(input_ids, position_ids=position_ids, **kwargs)
        else:
            kwargs.pop("position_ids", None)  # defensively drop unsupported kwarg
            output = self.original_component(input_ids, **kwargs)
        if isinstance(output, tuple):
            output = output[0]  # some modules return (embeddings, ...) tuples
        if target_dtype is not None and output.dtype != target_dtype:
            output = output.to(dtype=target_dtype)
        output = self.hook_out(output)
        return output