@@ -58,11 +58,17 @@ import os
5858import threading
5959
6060if os.environ.get("LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR") == "1":
61- from lmcache.v1 import lazy_memory_allocator as _lazy_memory_allocator
61+ import builtins
62+ import sys
6263
63- _LazyMemoryAllocator = _lazy_memory_allocator.LazyMemoryAllocator
64+ _orig_import = builtins.__import__
65+
66+ def _patch_lazy_memory_allocator(_lazy_memory_allocator) -> None:
67+ _LazyMemoryAllocator = _lazy_memory_allocator.LazyMemoryAllocator
68+
69+ if getattr(_LazyMemoryAllocator, "_agentic_rocm_demand_patch", False):
70+ return
6471
65- if not getattr(_LazyMemoryAllocator, "_agentic_rocm_demand_patch", False):
6672 _orig_init = _LazyMemoryAllocator.__init__
6773 _orig_allocate = _LazyMemoryAllocator.allocate
6874 _orig_batched_allocate = _LazyMemoryAllocator.batched_allocate
@@ -150,6 +156,22 @@ if os.environ.get("LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR") == "1":
150156 _LazyMemoryAllocator.batched_allocate = _patched_batched_allocate
151157 _LazyMemoryAllocator._agentic_rocm_demand_patch = True
152158
159+ def _maybe_patch_lazy_memory_allocator() -> None:
160+ module = sys.modules.get("lmcache.v1.lazy_memory_allocator")
161+ if module is not None:
162+ _patch_lazy_memory_allocator(module)
163+
164+ def _agentic_rocm_import(name, globals=None, locals=None, fromlist=(), level=0):
165+ module = _orig_import(name, globals, locals, fromlist, level)
166+ if name == "lmcache.v1.lazy_memory_allocator" or (
167+ name.startswith("lmcache") and "lmcache.v1.lazy_memory_allocator" in sys.modules
168+ ):
169+ _maybe_patch_lazy_memory_allocator()
170+ return module
171+
172+ builtins.__import__ = _agentic_rocm_import
173+ _maybe_patch_lazy_memory_allocator()
174+
153175if os.environ.get("LMCACHE_ROCM_MP_BLOCK_FALLBACK") == "1":
154176 import torch
155177 import lmcache.non_cuda_equivalents as lmc
0 commit comments