@@ -163,21 +163,55 @@ if os.environ.get("LMCACHE_ROCM_DEMAND_PINNED_ALLOCATOR") == "1":
163163 _LazyMemoryAllocator.batched_allocate = _patched_batched_allocate
164164 _LazyMemoryAllocator._agentic_rocm_demand_patch = True
165165
166+ def _patch_l1_memory_manager(_memory_manager) -> None:
167+ _L1MemoryManager = getattr(_memory_manager, "L1MemoryManager", None)
168+ _LazyMemoryAllocator = getattr(_memory_manager, "LazyMemoryAllocator", None)
169+ if _L1MemoryManager is None or _LazyMemoryAllocator is None:
170+ return
171+ if getattr(_L1MemoryManager, "_agentic_rocm_final_capacity_patch", False):
172+ return
173+
174+ _orig_get_memory_usage = _L1MemoryManager.get_memory_usage
175+
176+ def _patched_get_memory_usage(self):
177+ allocator = getattr(self, "_allocator", None)
178+ if isinstance(allocator, _LazyMemoryAllocator):
179+ address_manager = allocator.get_address_manager()
180+ used_size = (
181+ address_manager.get_heap_size() - address_manager.get_free_size()
182+ )
183+ return used_size, allocator._final_size
184+ return _orig_get_memory_usage(self)
185+
186+ _L1MemoryManager.get_memory_usage = _patched_get_memory_usage
187+ _L1MemoryManager._agentic_rocm_final_capacity_patch = True
188+
166189 def _maybe_patch_lazy_memory_allocator() -> None:
167190 module = sys.modules.get("lmcache.v1.lazy_memory_allocator")
168191 if module is not None and hasattr(module, "LazyMemoryAllocator"):
169192 _patch_lazy_memory_allocator(module)
170193
194+ def _maybe_patch_l1_memory_manager() -> None:
195+ module = sys.modules.get("lmcache.v1.distributed.memory_manager")
196+ if module is not None and hasattr(module, "L1MemoryManager"):
197+ _patch_l1_memory_manager(module)
198+
171199 def _agentic_rocm_import(name, globals=None, locals=None, fromlist=(), level=0):
172200 module = _orig_import(name, globals, locals, fromlist, level)
173201 if name == "lmcache.v1.lazy_memory_allocator" or (
174202 name.startswith("lmcache") and "lmcache.v1.lazy_memory_allocator" in sys.modules
175203 ):
176204 _maybe_patch_lazy_memory_allocator()
205+ if name == "lmcache.v1.distributed.memory_manager" or (
206+ name.startswith("lmcache")
207+ and "lmcache.v1.distributed.memory_manager" in sys.modules
208+ ):
209+ _maybe_patch_l1_memory_manager()
177210 return module
178211
179212 builtins.__import__ = _agentic_rocm_import
180213 _maybe_patch_lazy_memory_allocator()
214+ _maybe_patch_l1_memory_manager()
181215
182216if os.environ.get("LMCACHE_ROCM_MP_BLOCK_FALLBACK") == "1":
183217 import torch
0 commit comments