Skip to content

Commit c9b6518

Browse files
[None][feat] Parallelize host KV cache pool prefault and add THP control (#15431)
Signed-off-by: Md Nafis Ul Haque Shifat <nafis@deepinfra.com> Co-authored-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
1 parent 833ddd2 commit c9b6518

1 file changed

Lines changed: 65 additions & 3 deletions

File tree

  • tensorrt_llm/runtime/kv_cache_manager_v2

tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import array
17+
import concurrent.futures
1718
import ctypes
1819
import errno
1920
import functools
@@ -394,6 +395,21 @@ def find_index(seq: Iterable[T], predicate: Callable[[T], bool]) -> int:
394395
_libc.posix_fallocate.argtypes = [ctypes.c_int, ctypes.c_longlong, ctypes.c_longlong]
395396

396397
# Linux madvise(2) advice values passed to _libc.madvise for the host pools.
MADV_HUGEPAGE: Final[int] = 14
MADV_NOHUGEPAGE: Final[int] = 15
MADV_POPULATE_WRITE: Final[int] = 23

# Page-size policy for host pools. Transparent huge pages are requested only
# when TLLM_KV_CACHE_MANAGER_V2_THP is unset or exactly "1"; setting it to 0
# backs the pools with regular 4KB pages (MADV_NOHUGEPAGE) instead.
# Rationale: on nodes with fragmented physical memory and THP defrag=madvise,
# every 2MB THP fault stalls in direct compaction that rarely succeeds,
# slowing pool population from GB/s to GB/min.
USE_THP: Final[bool] = os.getenv("TLLM_KV_CACHE_MANAGER_V2_THP", "1") == "1"

# Number of threads used to prefault host pools. The default is half of
# os.cpu_count() (treated as 32 when it returns None), capped at 64.
# TLLM_KV_CACHE_MANAGER_V2_PREFAULT_THREADS=0 disables prefaulting; pages are
# then faulted in lazily, single-threaded, inside cuMemHostRegister.
PREFAULT_THREADS: Final[int] = int(
    os.getenv(
        "TLLM_KV_CACHE_MANAGER_V2_PREFAULT_THREADS",
        str(min(64, (os.cpu_count() or 32) // 2)),
    )
)
397413

398414

399415
def _madvise(ptr: int, size: int, advice: int) -> None:
@@ -508,8 +524,10 @@ def __init__(self, size: int) -> None:
508524

509525
# Advise the configured page mode (see USE_THP) for the whole range.
510526
# With MADV_HUGEPAGE, the kernel will use huge pages for aligned 2MB chunks within this range.
511-
_madvise(self._address, self._size, MADV_HUGEPAGE)
527+
_madvise(self._address, self._size, MADV_HUGEPAGE if USE_THP else MADV_NOHUGEPAGE)
512528

529+
if PREFAULT_THREADS > 0:
530+
self._parallel_prefault(PREFAULT_THREADS)
513531
self._register_to_cuda()
514532

515533
def resize(self, new_size: int) -> None:
@@ -519,8 +537,8 @@ def resize(self, new_size: int) -> None:
519537
assert self._address % self.ALIGNMENT == 0
520538
self._size = new_size
521539

522-
# Re-advise HUGEPAGE for the new range
523-
_madvise(self._address, self._size, MADV_HUGEPAGE)
540+
# Re-advise the configured page mode for the new range.
541+
_madvise(self._address, self._size, MADV_HUGEPAGE if USE_THP else MADV_NOHUGEPAGE)
524542
finally:
525543
self._register_to_cuda()
526544

@@ -535,6 +553,50 @@ def destroy(self) -> None:
535553
def __del__(self) -> None:
    """Finalizer hook: hand cleanup off to ``destroy()``."""
    self.destroy()
537555

556+
def _parallel_prefault(self, nthreads: int) -> None:
    """Fault in all pages with parallel threads before cuMemHostRegister,
    so registration only pins pages (never allocates them).

    Lazy faulting inside cuMemHostRegister is single-threaded and, under
    memory pressure or THP compaction stalls, can take minutes for
    multi-hundred-GiB pools. MADV_POPULATE_WRITE populates in bulk; small
    chunks keep mmap_lock hold times short (one giant madvise per thread
    serializes every other thread behind it) and let threads
    load-balance.

    As soon as one chunk fails, every chunk that has not started yet is
    cancelled, so a failure (e.g. ENOMEM) is reported without first trying
    to populate the remainder of the pool. Chunks already running are
    allowed to finish before the error propagates.

    Args:
        nthreads: Maximum number of worker threads. Must be positive
            (ThreadPoolExecutor rejects other values).

    Raises:
        HostOOMError: madvise(MADV_POPULATE_WRITE) failed with ENOMEM.
        OSError: madvise failed with any errno other than ENOMEM, EINVAL
            or ENOSYS (the latter two trigger the memset fallback).
    """
    chunk = 512 << 20
    # Snapshot the range once so every worker sees the same bounds.
    base = self._address
    total = self._size

    def populate(off: int) -> None:
        # `off` always comes from range(0, total, chunk), so ln > 0.
        ln = min(chunk, total - off)
        ret = _libc.madvise(
            ctypes.c_void_p(base + off),
            ctypes.c_size_t(ln),
            ctypes.c_int(MADV_POPULATE_WRITE),
        )
        if ret == 0:
            return
        error_code = ctypes.get_errno()
        if error_code in (errno.EINVAL, getattr(errno, "ENOSYS", -1)):
            # MADV_POPULATE_WRITE requires Linux >= 5.14; on older
            # kernels fall back to touching every page.
            ctypes.memset(base + off, 0, ln)
            return
        error_name = errno.errorcode.get(error_code, "Unknown error")
        if error_code == errno.ENOMEM:
            # Surface real allocation failures instead of masking them
            # with a memset that would trigger a system OOM kill.
            raise HostOOMError(
                f"madvise(MADV_POPULATE_WRITE) failed with errno {error_code}: {error_name}"
            )
        raise OSError(
            error_code,
            f"madvise(MADV_POPULATE_WRITE) failed: {error_name}",
        )

    with concurrent.futures.ThreadPoolExecutor(max_workers=nthreads) as executor:
        futures = [executor.submit(populate, off) for off in range(0, total, chunk)]
        try:
            for future in concurrent.futures.as_completed(futures):
                # Re-raises the worker's exception, if any.
                future.result()
        except BaseException:
            # Drop all queued chunks; cancel() is a no-op for futures that
            # are already running or finished. The executor's exit then
            # only waits for the in-flight chunks.
            for future in futures:
                future.cancel()
            raise
599+
538600
def _register_to_cuda(self) -> None:
539601
assert self._num_registered_chunks == 0
540602
for addr, size in self._iterate_chunks():

0 commit comments

Comments
 (0)