Coverage for transformer_lens/model_bridge/generalized_components/position_embedding_hooks_mixin.py: 95% (17 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Mixin for position embedding hooks (cos/sin) shared across attention bridges."""
2from __future__ import annotations
4from typing import Any
6from transformer_lens.hook_points import HookPoint
9class PositionEmbeddingHooksMixin:
10 """Mixin providing hook_cos/hook_sin and _apply_position_embedding_hooks().
12 Used by both PositionEmbeddingsAttentionBridge and
13 JointQKVPositionEmbeddingsAttentionBridge to avoid duplicating this logic.
14 """
16 def _init_position_embedding_hooks(self):
17 """Initialize rotary embedding state and hooks. Call from __init__."""
18 self._rotary_emb = None
19 self.hook_cos = HookPoint()
20 self.hook_sin = HookPoint()
22 def set_rotary_emb(self, rotary_emb: Any) -> None:
23 """Set reference to the model's rotary embedding component."""
24 self._rotary_emb = rotary_emb
26 def _apply_position_embedding_hooks(self, position_embeddings):
27 """Apply hook_cos/hook_sin to a (cos, sin) position embeddings tuple."""
28 if isinstance(position_embeddings, tuple) and len(position_embeddings) == 2: 28 ↛ 33line 28 didn't jump to line 33 because the condition on line 28 was always true
29 cos, sin = position_embeddings
30 hooked_cos = self.hook_cos(cos)
31 hooked_sin = self.hook_sin(sin)
32 return (hooked_cos, hooked_sin)
33 return position_embeddings
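For orientation, below is a minimal usage sketch of how an attention bridge might consume this mixin. It is an assumption for illustration only: ToyAttentionBridge, its forward signature, and the tensor shapes are hypothetical and not part of TransformerLens; only the mixin methods, the module path, and HookPoint come from the source above.

# Hypothetical sketch, not library code: ToyAttentionBridge is invented here
# purely to show the intended call pattern for the mixin.
import torch
import torch.nn as nn

from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
    PositionEmbeddingHooksMixin,
)


class ToyAttentionBridge(PositionEmbeddingHooksMixin, nn.Module):
    def __init__(self):
        super().__init__()
        # Creates self._rotary_emb, self.hook_cos, and self.hook_sin.
        self._init_position_embedding_hooks()

    def forward(self, position_embeddings):
        # Route the (cos, sin) tuple through hook_cos / hook_sin before use.
        return self._apply_position_embedding_hooks(position_embeddings)


# With no hooks registered, each HookPoint returns its input unchanged,
# so this just passes the (cos, sin) tuple straight through.
cos = torch.zeros(1, 8, 16)
sin = torch.zeros(1, 8, 16)
bridge = ToyAttentionBridge()
hooked_cos, hooked_sin = bridge((cos, sin))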