Coverage for transformer_lens/model_bridge/generalized_components/embedding.py: 65%
44 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Embedding bridge component.

This module contains the bridge component for embedding layers.
"""
import inspect
from typing import Any, Dict, Optional

import torch

from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
class EmbeddingBridge(GeneralizedComponent):
    """Embedding bridge that wraps transformer embedding layers.

    This component provides standardized input/output hooks (``hook_in`` /
    ``hook_out``) around an underlying embedding module.
    """

    property_aliases = {"W_E": "e.weight", "W_pos": "pos.weight"}

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Initialize the embedding bridge.

        Args:
            name: The name of this component
            config: Optional configuration (unused for EmbeddingBridge)
            submodules: Dictionary of GeneralizedComponent submodules to register.
                Defaults to an empty dict. (``None`` replaces the previous
                mutable ``{}`` default, which would have been shared across
                every instance.)
        """
        # Materialize a fresh dict per instance instead of a shared default.
        super().__init__(name, config, submodules={} if submodules is None else submodules)

    @property
    def W_E(self) -> torch.Tensor:
        """Return the embedding weight matrix.

        Resolution order:
          1. A processed weight, when weight processing is enabled and one
             has been stored on this bridge.
          2. ``inv_freq`` for rotary-style modules that have no ``weight``.
          3. The wrapped module's ``weight``.

        Raises:
            RuntimeError: If the original component has not been set.
            AssertionError: If the wrapped module exposes neither ``weight``
                nor ``inv_freq``, or the attribute is not a tensor.
        """
        # Processed weights take priority; both flag and value must be present
        # (same behavior as the original nested hasattr checks).
        if getattr(self, "_use_processed_weights", False) and hasattr(self, "_processed_weight"):
            return self._processed_weight
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}")
        # Rotary-embedding modules expose `inv_freq` instead of `weight`.
        if hasattr(self.original_component, "inv_freq") and (
            not hasattr(self.original_component, "weight")
        ):
            inv_freq = self.original_component.inv_freq
            assert isinstance(inv_freq, torch.Tensor), f"inv_freq is not a tensor for {self.name}"
            return inv_freq
        assert hasattr(
            self.original_component, "weight"
        ), f"Component {self.name} has neither weight nor inv_freq attribute"
        weight = self.original_component.weight
        assert isinstance(weight, torch.Tensor), f"Weight is not a tensor for {self.name}"
        return weight

    def forward(
        self, input_ids: torch.Tensor, position_ids: torch.Tensor | None = None, **kwargs: Any
    ) -> torch.Tensor:
        """Forward pass through the embedding bridge.

        Args:
            input_ids: Input token IDs
            position_ids: Optional position IDs (ignored if not supported)
            **kwargs: Additional arguments to pass to the original component

        Returns:
            Embedded output

        Raises:
            RuntimeError: If the original component has not been set.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # Remember the wrapped module's parameter dtype so the output can be
        # cast back if the module produced a different dtype.
        target_dtype = None
        try:
            target_dtype = next(self.original_component.parameters()).dtype
        except StopIteration:
            pass  # module has no parameters; skip dtype normalization
        input_ids = self.hook_in(input_ids)
        # BUG FIX: the original called inspect.signature(...forward) *before*
        # checking hasattr(..., "forward"), so the guard was dead code and a
        # module without a `forward` attribute raised AttributeError. Guard
        # first, then inspect.
        forward_fn = getattr(self.original_component, "forward", None)
        supports_position_ids = (
            forward_fn is not None and "position_ids" in inspect.signature(forward_fn).parameters
        )
        if supports_position_ids:
            output = self.original_component(input_ids, position_ids=position_ids, **kwargs)
        else:
            # Drop any caller-supplied position_ids the module can't accept.
            kwargs.pop("position_ids", None)
            output = self.original_component(input_ids, **kwargs)
        # Some modules return (embeddings, extras); keep only the embeddings.
        if isinstance(output, tuple):
            output = output[0]
        if target_dtype is not None and output.dtype != target_dtype:
            output = output.to(dtype=target_dtype)
        output = self.hook_out(output)
        return output