Skip to content

Commit 69cdbc2

Browse files
committed
fix(agentic): use final LMCache capacity on ROCm
1 parent e80a843 commit 69cdbc2

1 file changed

Lines changed: 34 additions & 0 deletions

File tree

benchmarks/single_node/agentic/kimik2.5_fp4_mi355x.sh

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,21 +163,55 @@ if os.environ.get("LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR") == "1":
163163
_LazyMemoryAllocator.batched_allocate = _patched_batched_allocate
164164
_LazyMemoryAllocator._agentic_rocm_demand_patch = True
165165
166+
def _patch_l1_memory_manager(_memory_manager) -> None:
167+
_L1MemoryManager = getattr(_memory_manager, "L1MemoryManager", None)
168+
_LazyMemoryAllocator = getattr(_memory_manager, "LazyMemoryAllocator", None)
169+
if _L1MemoryManager is None or _LazyMemoryAllocator is None:
170+
return
171+
if getattr(_L1MemoryManager, "_agentic_rocm_final_capacity_patch", False):
172+
return
173+
174+
_orig_get_memory_usage = _L1MemoryManager.get_memory_usage
175+
176+
def _patched_get_memory_usage(self):
177+
allocator = getattr(self, "_allocator", None)
178+
if isinstance(allocator, _LazyMemoryAllocator):
179+
address_manager = allocator.get_address_manager()
180+
used_size = (
181+
address_manager.get_heap_size() - address_manager.get_free_size()
182+
)
183+
return used_size, allocator._final_size
184+
return _orig_get_memory_usage(self)
185+
186+
_L1MemoryManager.get_memory_usage = _patched_get_memory_usage
187+
_L1MemoryManager._agentic_rocm_final_capacity_patch = True
188+
166189
def _maybe_patch_lazy_memory_allocator() -> None:
167190
module = sys.modules.get("lmcache.v1.lazy_memory_allocator")
168191
if module is not None and hasattr(module, "LazyMemoryAllocator"):
169192
_patch_lazy_memory_allocator(module)
170193
194+
def _maybe_patch_l1_memory_manager() -> None:
195+
module = sys.modules.get("lmcache.v1.distributed.memory_manager")
196+
if module is not None and hasattr(module, "L1MemoryManager"):
197+
_patch_l1_memory_manager(module)
198+
171199
def _agentic_rocm_import(name, globals=None, locals=None, fromlist=(), level=0):
172200
module = _orig_import(name, globals, locals, fromlist, level)
173201
if name == "lmcache.v1.lazy_memory_allocator" or (
174202
name.startswith("lmcache") and "lmcache.v1.lazy_memory_allocator" in sys.modules
175203
):
176204
_maybe_patch_lazy_memory_allocator()
205+
if name == "lmcache.v1.distributed.memory_manager" or (
206+
name.startswith("lmcache")
207+
and "lmcache.v1.distributed.memory_manager" in sys.modules
208+
):
209+
_maybe_patch_l1_memory_manager()
177210
return module
178211
179212
builtins.__import__ = _agentic_rocm_import
180213
_maybe_patch_lazy_memory_allocator()
214+
_maybe_patch_l1_memory_manager()
181215
182216
if os.environ.get("LMCACHE_ROCM_MP_BLOCK_FALLBACK") == "1":
183217
import torch

0 commit comments

Comments
 (0)