Coverage for transformer_lens/model_bridge/compat.py: 62%

22 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Compatibility shims for transformers version differences. 

2 

3These patches are applied lazily (only when missing) so they're safe to call 

4from multiple adapters — the first caller wins, subsequent calls are no-ops. 

5 

6WARNING: patches here mutate classes from the installed `transformers` package 

7in place. They are process-global and persist for the entire Python session — 

8every model loaded afterward, including ones unrelated to the caller, sees the 

9patched class. This is acceptable because the shims only *add* v4-era methods 

10that v5 removed; they do not change v5 behavior. But it means a bug in a shim 

11affects the whole session, not just the adapter that invoked it. 

12 

13REMOVAL: drop the corresponding block (and its call sites) once the minimum 

14supported `transformers` version provides the method natively, or once all 

15remote-code models we support have been updated for v5. Track upstream status 

16against `transformers.cache_utils.DynamicCache` — when `from_legacy_cache`, 

17`to_legacy_cache`, and `get_usable_length` are restored or no longer needed, 

18`patch_dynamic_cache_v5` can be deleted outright. 

19""" 

20 

21 

def patch_dynamic_cache_v5() -> None:
    """Restore v4-era ``DynamicCache`` methods that transformers v5 dropped.

    Remote-code models authored against transformers v4 still call
    ``from_legacy_cache``, ``to_legacy_cache``, and ``get_usable_length``,
    all of which v5 removed. Invoke this from any adapter's
    ``prepare_loading()`` that needs those entry points; repeated calls are
    no-ops thanks to the ``hasattr`` guards.

    Side effect: mutates ``transformers.cache_utils.DynamicCache`` for the
    whole process. See the module docstring for the full caveats.
    """
    try:
        from transformers.cache_utils import DynamicCache
    except Exception:
        # transformers is absent or failed to import — nothing to shim.
        return

    if not hasattr(DynamicCache, "from_legacy_cache"):

        @classmethod  # type: ignore[misc]
        def _rebuild_from_legacy(cls, past_key_values=None):  # type: ignore[no-untyped-def]
            # Recreate a cache from the v4 tuple-of-(key, value) layout by
            # replaying each layer's tensors through the v5 update() API.
            fresh = cls()
            if past_key_values is not None:
                for layer_idx, entry in enumerate(past_key_values):
                    fresh.update(entry[0], entry[1], layer_idx)
            return fresh

        DynamicCache.from_legacy_cache = _rebuild_from_legacy  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "get_usable_length"):

        def _usable_length(self, new_seq_len: int = 0, layer_idx: int = 0) -> int:  # type: ignore[no-untyped-def]
            # A DynamicCache has no maximum length, so the usable length is
            # just the current sequence length; ``new_seq_len`` exists only
            # for v4 signature compatibility and is intentionally unused.
            return self.get_seq_length(layer_idx)

        DynamicCache.get_usable_length = _usable_length  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "to_legacy_cache"):

        def _as_legacy_tuples(self):  # type: ignore[no-untyped-def]
            # Flatten the per-layer cache objects back into the v4 format:
            # a tuple of (keys, values) pairs, one per layer.
            return tuple((layer.keys, layer.values) for layer in self.layers)

        DynamicCache.to_legacy_cache = _as_legacy_tuples  # type: ignore[attr-defined]