1818from tensorrt_llm ._torch .modules .rotary_embedding import RotaryEmbedding
1919from tensorrt_llm ._torch .pyexecutor .resource_manager import KVCacheManager
2020from tensorrt_llm ._torch .utils import maybe_compile , maybe_compiled_cat
21- from tensorrt_llm ._utils import get_size_in_bytes , get_sm_version
21+ from tensorrt_llm ._utils import get_size_in_bytes , get_sm_version , prefer_pinned
2222from tensorrt_llm .bindings import DataType
2323from tensorrt_llm .bindings .executor import KvCacheConfig
2424from tensorrt_llm .bindings .internal .batch_manager import \
@@ -339,7 +339,7 @@ def __post_init__(self):
339339 self .host_indexer_k_cache_block_offsets = torch .zeros_like (
340340 self .indexer_k_cache_block_offsets ,
341341 device = 'cpu' ,
342- pin_memory = True ,
342+ pin_memory = prefer_pinned () ,
343343 )
344344
345345 if not self .enable_context_mla_with_cached_kv :
@@ -353,7 +353,7 @@ def __post_init__(self):
353353 self .host_ctx_cached_token_indptr = torch .zeros_like (
354354 self .ctx_cached_token_indptr ,
355355 device = 'cpu' ,
356- pin_memory = True ,
356+ pin_memory = prefer_pinned () ,
357357 )
358358 self .ctx_kv_indptr = self .get_empty (
359359 self .cuda_graph_buffers ,
@@ -365,7 +365,7 @@ def __post_init__(self):
365365 self .host_ctx_kv_indptr = torch .zeros_like (
366366 self .ctx_kv_indptr ,
367367 device = 'cpu' ,
368- pin_memory = True ,
368+ pin_memory = prefer_pinned () ,
369369 )
370370
371371 # Only when MLA chunked prefill is enabled, we need to gather the full KV for indexer's logit computation.
@@ -385,7 +385,7 @@ def __post_init__(self):
385385 self .host_gen_cached_token_indptr = torch .zeros_like (
386386 self .gen_cached_token_indptr ,
387387 device = 'cpu' ,
388- pin_memory = True ,
388+ pin_memory = prefer_pinned () ,
389389 )
390390 self .gen_kv_indptr = self .get_empty (
391391 self .cuda_graph_buffers ,
@@ -397,7 +397,7 @@ def __post_init__(self):
397397 self .host_gen_kv_indptr = torch .zeros_like (
398398 self .gen_kv_indptr ,
399399 device = 'cpu' ,
400- pin_memory = True ,
400+ pin_memory = prefer_pinned () ,
401401 )
402402 # Indexer metadata
403403 # Separate slot mappings for non-interleaved layout (flat byte indices)
@@ -411,7 +411,7 @@ def __post_init__(self):
411411 self .host_slot_mapping_fp8 = torch .zeros_like (
412412 self .slot_mapping_fp8 ,
413413 device = 'cpu' ,
414- pin_memory = True ,
414+ pin_memory = prefer_pinned () ,
415415 )
416416 self .slot_mapping_scale = self .get_empty (
417417 self .cuda_graph_buffers ,
@@ -423,7 +423,7 @@ def __post_init__(self):
423423 self .host_slot_mapping_scale = torch .zeros_like (
424424 self .slot_mapping_scale ,
425425 device = 'cpu' ,
426- pin_memory = True ,
426+ pin_memory = prefer_pinned () ,
427427 )
428428 # Per-token request index buffer for topk_indices conversion
429429 self .req_idx_per_token = self .get_empty (
@@ -474,7 +474,7 @@ def __post_init__(self):
474474 self .host_topk_indices_buffer = torch .zeros_like (
475475 self .topk_indices_buffer ,
476476 device = 'cpu' ,
477- pin_memory = True ,
477+ pin_memory = prefer_pinned () ,
478478 )
479479 # Create expanded buffers for MTP support
480480 self .create_expanded_buffers (capture_graph = capture_graph )
@@ -491,7 +491,7 @@ def create_expanded_buffers(self, capture_graph=False):
491491 self .kv_lens_expanded_host = torch .zeros_like (
492492 self .kv_lens_expanded_cuda ,
493493 device = 'cpu' ,
494- pin_memory = True ,
494+ pin_memory = prefer_pinned () ,
495495 )
496496 self .block_table_expanded = self .get_empty (
497497 self .cuda_graph_buffers ,
@@ -506,7 +506,7 @@ def create_expanded_buffers(self, capture_graph=False):
506506 self .host_block_table_expanded = torch .zeros_like (
507507 self .block_table_expanded ,
508508 device = 'cpu' ,
509- pin_memory = True ,
509+ pin_memory = prefer_pinned () ,
510510 )
511511 self .scheduler_metadata_buffer_expanded = self .get_empty (
512512 self .cuda_graph_buffers ,
@@ -1171,12 +1171,10 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
11711171 total_kv_per_request = seq_lens [:
11721172 num_contexts ] + start_positions [:
11731173 num_contexts ]
1174- host_slot_mapping_fp8_fullkv = torch .empty (total_kv_len ,
1175- dtype = torch .int64 ,
1176- pin_memory = True )
1177- host_slot_mapping_scale_fullkv = torch .empty (total_kv_len ,
1178- dtype = torch .int64 ,
1179- pin_memory = True )
1174+ host_slot_mapping_fp8_fullkv = torch .empty (
1175+ total_kv_len , dtype = torch .int64 , pin_memory = prefer_pinned ())
1176+ host_slot_mapping_scale_fullkv = torch .empty (
1177+ total_kv_len , dtype = torch .int64 , pin_memory = prefer_pinned ())
11801178
11811179 req_indices = torch .repeat_interleave (
11821180 torch .arange (num_contexts , dtype = torch .int64 , device = 'cpu' ),
0 commit comments