Skip to content

Commit 0103241

Browse files
committed
fix(agentic): lazily patch ROCm LMCache allocator
1 parent 20d6508 commit 0103241

1 file changed

Lines changed: 25 additions & 3 deletions

File tree

benchmarks/single_node/agentic/kimik2.5_fp4_mi355x.sh

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,17 @@ import os
5858
import threading
5959
6060
if os.environ.get("LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR") == "1":
61-
from lmcache.v1 import lazy_memory_allocator as _lazy_memory_allocator
61+
import builtins
62+
import sys
6263
63-
_LazyMemoryAllocator = _lazy_memory_allocator.LazyMemoryAllocator
64+
_orig_import = builtins.__import__
65+
66+
def _patch_lazy_memory_allocator(_lazy_memory_allocator) -> None:
67+
_LazyMemoryAllocator = _lazy_memory_allocator.LazyMemoryAllocator
68+
69+
if getattr(_LazyMemoryAllocator, "_agentic_rocm_demand_patch", False):
70+
return
6471
65-
if not getattr(_LazyMemoryAllocator, "_agentic_rocm_demand_patch", False):
6672
_orig_init = _LazyMemoryAllocator.__init__
6773
_orig_allocate = _LazyMemoryAllocator.allocate
6874
_orig_batched_allocate = _LazyMemoryAllocator.batched_allocate
@@ -150,6 +156,22 @@ if os.environ.get("LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR") == "1":
150156
_LazyMemoryAllocator.batched_allocate = _patched_batched_allocate
151157
_LazyMemoryAllocator._agentic_rocm_demand_patch = True
152158
159+
def _maybe_patch_lazy_memory_allocator() -> None:
160+
module = sys.modules.get("lmcache.v1.lazy_memory_allocator")
161+
if module is not None:
162+
_patch_lazy_memory_allocator(module)
163+
164+
def _agentic_rocm_import(name, globals=None, locals=None, fromlist=(), level=0):
165+
module = _orig_import(name, globals, locals, fromlist, level)
166+
if name == "lmcache.v1.lazy_memory_allocator" or (
167+
name.startswith("lmcache") and "lmcache.v1.lazy_memory_allocator" in sys.modules
168+
):
169+
_maybe_patch_lazy_memory_allocator()
170+
return module
171+
172+
builtins.__import__ = _agentic_rocm_import
173+
_maybe_patch_lazy_memory_allocator()
174+
153175
if os.environ.get("LMCACHE_ROCM_MP_BLOCK_FALLBACK") == "1":
154176
import torch
155177
import lmcache.non_cuda_equivalents as lmc

0 commit comments

Comments
 (0)