Coverage for transformer_lens/model_bridge/generalized_components/pos_embed.py: 87%
29 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Positional embedding bridge component.
3This module contains the bridge component for positional embedding layers.
4"""
5from typing import Any, Dict, Optional
7import torch
9from transformer_lens.model_bridge.generalized_components.base import (
10 GeneralizedComponent,
11)


class PosEmbedBridge(GeneralizedComponent):
    """Positional embedding bridge that wraps transformer positional embedding layers.

    This component provides standardized input/output hooks for positional embeddings.
    """
    property_aliases = {"W_pos": "weight"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = {},
    ):
        """Initialize the positional embedding bridge.

        Args:
            name: The name of this component
            config: Optional configuration (unused for PosEmbedBridge)
            submodules: Dictionary of GeneralizedComponent submodules to register
        """
        super().__init__(name, config, submodules=submodules)

    @property
    def W_pos(self) -> torch.Tensor:
        """Return the positional embedding weight matrix."""
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        assert hasattr(
            self.original_component, "weight"
        ), f"Component {self.name} has no weight attribute"
        weight = self.original_component.weight
        assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
        return weight

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the positional embedding bridge.

        This method accepts variable arguments to support different architectures:
        - Standard models (GPT-2, GPT-Neo): (input_ids, position_ids=None)
        - OPT models: (attention_mask, past_key_values_length=0, position_ids=None)
        - Others may have different signatures

        Args:
            *args: Positional arguments forwarded to the original component
            **kwargs: Keyword arguments forwarded to the original component

        Returns:
            Positional embeddings
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
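        # Fire the standardized input hook on the first positional argument,
        # whatever it is for this architecture (token ids, an attention mask, ...),
        # then forward all arguments to the wrapped module unchanged.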
        if args:
            first_arg = self.hook_in(args[0])
            args = (first_arg,) + args[1:]
        output = self.original_component(*args, **kwargs)

        # Expand batch=1 pos embeddings to match actual batch size for hooks.
        batch_size = getattr(self, "_current_batch_size", None)
        # Read-and-clear to avoid stale values during generate() steps.
        if batch_size is not None:
            self._current_batch_size = None
        if (
            batch_size is not None
            and batch_size > 1
            and isinstance(output, torch.Tensor)
            and output.ndim >= 1
            and output.shape[0] == 1
        ):
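            # ``expand`` broadcasts the singleton batch dimension as a view
            # (no copy), so hooks see a [batch, ...] tensor at no extra memory cost.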
            output = output.expand(batch_size, *[-1] * (output.ndim - 1))

        output = self.hook_out(output)
        return output
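

# A minimal usage sketch (not part of the module): it assumes a GPT-2-style
# learned position embedding ``wpe = torch.nn.Embedding(n_ctx, d_model)`` and
# the ``set_original_component`` setter referenced in the error message above;
# ``wpe`` and ``seq_len`` are illustrative names, not defined here.
#
#     bridge = PosEmbedBridge(name="pos_embed")
#     bridge.set_original_component(wpe)
#     position_ids = torch.arange(seq_len).unsqueeze(0)   # [1, seq_len]
#     pos_embeds = bridge.forward(position_ids)            # [1, seq_len, d_model]
#     W = bridge.W_pos                                      # [n_ctx, d_model]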