Coverage for transformer_lens/model_bridge/compat.py: 62%
22 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Compatibility shims for transformers version differences.

These patches are applied on demand, and each method is installed only when
it is missing, so they are safe to call from multiple adapters: the first
caller installs the method and subsequent calls are no-ops.

WARNING: patches here mutate classes from the installed `transformers` package
in place. They are process-global and persist for the entire Python session:
every model loaded afterward, including ones unrelated to the caller, sees the
patched class. This is acceptable because the shims only *add* v4-era methods
that v5 removed; they do not change v5 behavior. It does mean that a bug in a
shim affects the whole session, not just the adapter that invoked it.

REMOVAL: drop the corresponding block (and its call sites) once the minimum
supported `transformers` version provides the method natively, or once all
remote-code models we support have been updated for v5. Track upstream status
against `transformers.cache_utils.DynamicCache`. When `from_legacy_cache`,
`to_legacy_cache`, and `get_usable_length` are restored or no longer needed,
`patch_dynamic_cache_v5` can be deleted outright.
"""
def patch_dynamic_cache_v5() -> None:
    """Backfill DynamicCache methods removed in transformers v5.

    Remote-code models written for transformers v4 call `from_legacy_cache`,
    `to_legacy_cache`, and `get_usable_length`, which were removed in v5.
    Call this from any adapter's prepare_loading() that needs them.

    Each method is installed only if `DynamicCache` does not already have it,
    so repeated calls are no-ops and native implementations are never
    overridden. If `transformers.cache_utils` cannot be imported, this
    returns without doing anything.

    Side effect: mutates `transformers.cache_utils.DynamicCache` for the whole
    process. See module docstring.
    """
    try:
        from transformers.cache_utils import DynamicCache
    except ImportError:
        # transformers (or one of its dependencies) is not importable, so there
        # is no class to patch. This is deliberately a no-op; other exceptions
        # indicate a broken install and are allowed to surface.
        return

    if not hasattr(DynamicCache, "from_legacy_cache"):

        @classmethod  # type: ignore[misc]
        def _from_legacy_cache(cls, past_key_values=None):  # type: ignore[no-untyped-def]
            """Build a cache from a sequence of per-layer `(key, value)` pairs.

            `None` (the default) yields an empty cache. Layer `i` of the input
            is written to layer index `i` via `update`.
            """
            cache = cls()
            if past_key_values is not None:
                for idx, layer_past in enumerate(past_key_values):
                    # Only the first two entries (key, value) of each layer are used.
                    cache.update(layer_past[0], layer_past[1], idx)
            return cache

        DynamicCache.from_legacy_cache = _from_legacy_cache  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "get_usable_length"):

        def _get_usable_length(  # type: ignore[no-untyped-def]
            self, new_seq_length: int = 0, layer_idx: int = 0
        ) -> int:
            """Return the cached length usable before appending new tokens.

            Parameter names match the v4 `get_usable_length(new_seq_length,
            layer_idx)` signature so keyword callers keep working. This shim
            ignores `new_seq_length` and returns the full cached length of
            `layer_idx`, i.e. it treats the cache as having no size limit.
            """
            return self.get_seq_length(layer_idx)

        DynamicCache.get_usable_length = _get_usable_length  # type: ignore[attr-defined]

    if not hasattr(DynamicCache, "to_legacy_cache"):

        def _to_legacy_cache(self):  # type: ignore[no-untyped-def]
            """Return the cache as a tuple of per-layer `(keys, values)` pairs."""
            return tuple((layer.keys, layer.values) for layer in self.layers)

        DynamicCache.to_legacy_cache = _to_legacy_cache  # type: ignore[attr-defined]