Coverage for transformer_lens/model_bridge/generalized_components/position_embedding_hooks_mixin.py: 95%

17 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Mixin for position embedding hooks (cos/sin) shared across attention bridges.""" 

2from __future__ import annotations 

3 

4from typing import Any 

5 

6from transformer_lens.hook_points import HookPoint 

7 

8 

9class PositionEmbeddingHooksMixin: 

10 """Mixin providing hook_cos/hook_sin and _apply_position_embedding_hooks(). 

11 

12 Used by both PositionEmbeddingsAttentionBridge and 

13 JointQKVPositionEmbeddingsAttentionBridge to avoid duplicating this logic. 

14 """ 

15 

16 def _init_position_embedding_hooks(self): 

17 """Initialize rotary embedding state and hooks. Call from __init__.""" 

18 self._rotary_emb = None 

19 self.hook_cos = HookPoint() 

20 self.hook_sin = HookPoint() 

21 

22 def set_rotary_emb(self, rotary_emb: Any) -> None: 

23 """Set reference to the model's rotary embedding component.""" 

24 self._rotary_emb = rotary_emb 

25 

26 def _apply_position_embedding_hooks(self, position_embeddings): 

27 """Apply hook_cos/hook_sin to a (cos, sin) position embeddings tuple.""" 

28 if isinstance(position_embeddings, tuple) and len(position_embeddings) == 2: 28 ↛ 33line 28 didn't jump to line 33 because the condition on line 28 was always true

29 cos, sin = position_embeddings 

30 hooked_cos = self.hook_cos(cos) 

31 hooked_sin = self.hook_sin(sin) 

32 return (hooked_cos, hooked_sin) 

33 return position_embeddings
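For context, here is a minimal sketch of how an attention bridge might consume this mixin and how a TransformerLens hook can observe the rotary tables. MyAttentionBridge, log_cos, and the tensor shapes are illustrative assumptions rather than anything from the source; the import path is taken from the report title, and HookPoint.add_hook with the (value, hook) callback signature is the standard TransformerLens registration mechanism.

import torch

from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
    PositionEmbeddingHooksMixin,
)


class MyAttentionBridge(PositionEmbeddingHooksMixin):
    """Hypothetical bridge that routes rotary (cos, sin) tables through hooks."""

    def __init__(self):
        self._init_position_embedding_hooks()  # creates hook_cos / hook_sin

    def forward(self, position_embeddings):
        # Pass the (cos, sin) tuple through hook_cos / hook_sin before use.
        return self._apply_position_embedding_hooks(position_embeddings)


def log_cos(tensor, hook):
    # Standard TransformerLens hook: receives the value and the HookPoint;
    # whatever it returns replaces the value downstream.
    print("cos table shape:", tuple(tensor.shape))
    return tensor


bridge = MyAttentionBridge()
bridge.hook_cos.add_hook(log_cos)

cos = torch.ones(8, 16)   # illustrative (seq_len, rotary_dim) cosine table
sin = torch.zeros(8, 16)  # illustrative sine table
hooked_cos, hooked_sin = bridge.forward((cos, sin))  # log_cos fires on cos

Because both bridge classes named in the docstring share these hook points via the mixin, downstream code can intercept or patch the cos/sin tables the same way regardless of which attention bridge variant is active.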