Coverage for transformer_lens/model_bridge/generalized_components/qwen3_5_vision_encoder.py: 82%

18 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Qwen3.5 vision-tower bridges (``model.visual``). 

2 

3The Qwen vision tower differs from SigLIP/CLIP, so it needs its own bridge. The merger 

4(vision->text projector) is bridged separately as the adapter's ``vision_projector``, and 

5the paramless ``rotary_pos_emb`` is left native. 

6""" 

7from typing import Any, Dict, Optional 

8 

9from transformer_lens.model_bridge.generalized_components.base import ( 

10 GeneralizedComponent, 

11) 

12from transformer_lens.model_bridge.generalized_components.linear import LinearBridge 

13 

14 

class Qwen3_5VisionBlockBridge(GeneralizedComponent):
    """Bridge wrapping one block of the Qwen3.5 vision tower.

    The two norms are registered as plain ``GeneralizedComponent`` wrappers, so they
    are hooked as black boxes rather than recomputed: ``NormalizationBridge`` would
    recompute with the wrong eps and break parity.
    """

    # NOTE(review): presumably tells the bridge machinery that this component is a
    # per-item template for a list of blocks -- confirm against GeneralizedComponent.
    is_list_item: bool = True

    # Short hook names mapped onto the hook paths of this block and its submodules.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_in",
        "hook_attn_out": "attn.hook_out",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Build the default submodule tree and hand it to the base constructor.

        Args:
            name: Component name, forwarded unchanged to ``GeneralizedComponent``.
            config: Optional configuration object, forwarded unchanged.
            submodules: Optional mapping merged over the defaults by key; entries
                with an existing key replace the default, new keys are added.
        """
        # Children are created in the same order as their keys appear below.
        first_norm = GeneralizedComponent(name="norm1")
        second_norm = GeneralizedComponent(name="norm2")
        attention = GeneralizedComponent(
            name="attn",
            submodules={
                "qkv": LinearBridge(name="qkv"),
                "proj": LinearBridge(name="proj"),
            },
        )
        feed_forward = GeneralizedComponent(
            name="mlp",
            submodules={
                "linear_fc1": LinearBridge(name="linear_fc1"),
                "linear_fc2": LinearBridge(name="linear_fc2"),
            },
        )

        children: Dict[str, GeneralizedComponent] = {
            "norm1": first_norm,
            "norm2": second_norm,
            "attn": attention,
            "mlp": feed_forward,
        }
        # A ``None`` or empty mapping leaves the defaults untouched.
        if submodules:
            children.update(submodules)

        super().__init__(name, config, submodules=children)

59 

60 

class Qwen3_5VisionEncoderBridge(GeneralizedComponent):
    """Bridge wrapping the Qwen3.5 vision tower (``model.visual``).

    Registers the patch embedding, the position embedding and the blocks by default;
    the merger is bridged separately.
    """

    # Short hook names mapped onto the hook paths of the tower and its patch embedding.
    hook_aliases = {
        "hook_vision_embed": "patch_embed.hook_out",
        "hook_vision_out": "hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    ):
        """Build the default submodule tree and hand it to the base constructor.

        Args:
            name: Component name, forwarded unchanged to ``GeneralizedComponent``.
            config: Optional configuration object, forwarded unchanged.
            submodules: Optional mapping merged over the defaults by key; entries
                with an existing key replace the default, new keys are added.
        """
        children: Dict[str, GeneralizedComponent] = {}
        children["patch_embed"] = GeneralizedComponent(name="patch_embed")
        children["pos_embed"] = GeneralizedComponent(name="pos_embed")
        # A single block bridge is registered under the "blocks" key.
        children["blocks"] = Qwen3_5VisionBlockBridge(name="blocks")

        # A ``None`` or empty mapping leaves the defaults untouched.
        if submodules:
            children.update(submodules)

        super().__init__(name, config, submodules=children)