Coverage for transformer_lens/model_bridge/generalized_components/pos_embed.py: 87%

29 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Positional embedding bridge component. 

2 

3This module contains the bridge component for positional embedding layers. 

4""" 

5from typing import Any, Dict, Optional 

6 

7import torch 

8 

9from transformer_lens.model_bridge.generalized_components.base import ( 

10 GeneralizedComponent, 

11) 

12 

13 

class PosEmbedBridge(GeneralizedComponent):
    """Positional embedding bridge that wraps transformer positional embedding layers.

    This component provides standardized input/output hooks for positional embeddings.
    """
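
    # Map TransformerLens-style parameter names onto the wrapped module's
    # attributes, so that W_pos resolves to the underlying `weight` parameter.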

    property_aliases = {"W_pos": "weight"}


    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the positional embedding bridge.

        Args:
            name: The name of this component
            config: Optional configuration (unused for PosEmbedBridge)
            submodules: Optional dictionary of GeneralizedComponent submodules to register
        """
        super().__init__(name, config, submodules=submodules)


    @property
    def W_pos(self) -> torch.Tensor:
        """Return the positional embedding weight matrix."""
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        assert hasattr(
            self.original_component, "weight"
        ), f"Component {self.name} has no weight attribute"
        weight = self.original_component.weight
        assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
        return weight


    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the positional embedding bridge.

        This method accepts variable arguments to support different architectures:
        - Standard models (GPT-2, GPT-Neo): (input_ids, position_ids=None)
        - OPT models: (attention_mask, past_key_values_length=0, position_ids=None)
        - Others may have different signatures

        Args:
            *args: Positional arguments forwarded to the original component
            **kwargs: Keyword arguments forwarded to the original component

        Returns:
            Positional embeddings
        """

        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
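        # Only the first positional argument is routed through hook_in; any
        # remaining positional args and all kwargs pass through unchanged.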

        if args:
            first_arg = self.hook_in(args[0])
            args = (first_arg,) + args[1:]
        output = self.original_component(*args, **kwargs)


        # Expand batch=1 pos embeddings to match actual batch size for hooks.
        batch_size = getattr(self, "_current_batch_size", None)

        # Read-and-clear to avoid stale values during generate() steps.
        if batch_size is not None:
            self._current_batch_size = None
        if (
            batch_size is not None
            and batch_size > 1
            and isinstance(output, torch.Tensor)
            and output.ndim >= 1
            and output.shape[0] == 1
        ):
            output = output.expand(batch_size, *[-1] * (output.ndim - 1))
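            # e.g. a (1, seq_len, d_model) embedding becomes a (batch_size, seq_len,
            # d_model) broadcast view; expand() does not copy the underlying storage.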


        output = self.hook_out(output)
        return output
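
For context, a minimal usage sketch of this bridge. It assumes the GeneralizedComponent base class exposes the set_original_component() setter mentioned in the RuntimeError above, that the bridge is a callable nn.Module, and that hook_in/hook_out act as identity hook points; names outside this file are assumptions rather than documented API:

    import torch
    from torch import nn

    from transformer_lens.model_bridge.generalized_components.pos_embed import PosEmbedBridge

    # Wrap a plain embedding table as the "original" positional embedding.
    bridge = PosEmbedBridge(name="pos_embed")
    bridge.set_original_component(nn.Embedding(1024, 768))  # assumed setter, per the error message

    position_ids = torch.arange(8).unsqueeze(0)  # shape (1, seq_len)
    pos_embeds = bridge(position_ids)            # hook_in -> wrapped embedding -> hook_out
    assert pos_embeds.shape == (1, 8, 768)
    assert bridge.W_pos.shape == (1024, 768)     # alias onto the wrapped module's weight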