Skip to content

Commit 06456e1

Browse files
authored
[TRTLLM-10947][perf] eagle3: use cudaMemcpy2DAsync custom op for hidden-state capture (#14479)
Signed-off-by: Pietro Cicotti <5833013+pcicotti@users.noreply.github.com>
1 parent 441eaae commit 06456e1

7 files changed

Lines changed: 242 additions & 5 deletions

File tree

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ add_library(
119119
dsv3RopeOp.cpp
120120
fusedGemmAllreduceOp.cpp
121121
convertReqIndexToGlobalOp.cpp
122-
trtllmGenQKVProcessOp.cpp)
122+
trtllmGenQKVProcessOp.cpp
123+
inplaceSliceCopyOp.cpp)
123124
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
124125
target_link_libraries(
125126
th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "tensorrt_llm/common/cudaUtils.h"
18+
#include "tensorrt_llm/thop/thUtils.h"
19+
20+
#include <ATen/cuda/CUDAContext.h>
21+
#include <cuda_runtime_api.h>
22+
23+
namespace tensorrt_llm::torch_ext
24+
{
25+
26+
// Copy src[:, :] into dest[:numTokens, dim1Start:dim1End] using cudaMemcpy2D.
27+
// dest : 2-D contiguous CUDA tensor, shape [destRows, destCols]
28+
// src : 2-D contiguous CUDA tensor, shape [numTokens, sliceWidth] where sliceWidth == dim1End - dim1Start
29+
// dim1Start : first column index in dest to write into
30+
// dim1End : one-past-last column index in dest to write into
31+
// numTokens is inferred from src.size(0)
32+
void inplaceSliceCopy(at::Tensor& dest, at::Tensor const& src, int64_t dim1Start, int64_t dim1End)
33+
{
34+
CHECK_TH_CUDA(dest);
35+
CHECK_TH_CUDA(src);
36+
TORCH_CHECK(dest.get_device() == src.get_device(), "dest and src must be on the same CUDA device");
37+
TORCH_CHECK(dest.is_contiguous(), "dest must be contiguous");
38+
TORCH_CHECK(src.is_contiguous(), "src must be contiguous");
39+
TORCH_CHECK(dest.dim() == 2, "dest must be 2-D");
40+
TORCH_CHECK(src.dim() == 2, "src must be 2-D");
41+
TORCH_CHECK(dest.scalar_type() == src.scalar_type(), "dest and src must have the same dtype");
42+
43+
int64_t const numTokens = src.size(0);
44+
int64_t const sliceWidth = dim1End - dim1Start;
45+
TORCH_CHECK(dim1Start >= 0, "dim1Start must be non-negative");
46+
TORCH_CHECK(sliceWidth > 0, "dim1End must be greater than dim1Start");
47+
TORCH_CHECK(numTokens <= dest.size(0), "numTokens exceeds dest row count");
48+
TORCH_CHECK(dim1End <= dest.size(1), "dim1End exceeds dest column count");
49+
TORCH_CHECK(src.size(1) == sliceWidth, "src column count must equal dim1End - dim1Start");
50+
51+
if (numTokens == 0 || sliceWidth == 0)
52+
{
53+
return;
54+
}
55+
56+
int64_t const elemSize = dest.element_size();
57+
int64_t const destPitch = dest.size(1) * elemSize; // bytes per dest row
58+
int64_t const srcPitch = src.size(1) * elemSize; // bytes per src row
59+
int64_t const width = sliceWidth * elemSize; // bytes to copy per row
60+
61+
char* destPtr = static_cast<char*>(dest.data_ptr()) + dim1Start * elemSize;
62+
char const* srcPtr = static_cast<char const*>(src.data_ptr());
63+
64+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dest.get_device());
65+
TLLM_CUDA_CHECK(cudaMemcpy2DAsync(
66+
destPtr, destPitch, srcPtr, srcPitch, width, static_cast<size_t>(numTokens), cudaMemcpyDeviceToDevice, stream));
67+
}
68+
69+
} // namespace tensorrt_llm::torch_ext
70+
71+
TORCH_LIBRARY_FRAGMENT(trtllm, m)
72+
{
73+
// dest: destination tensor (mutated in-place)
74+
// src: source tensor (numTokens inferred from src.size(0))
75+
// dim1_start: first column index in dest
76+
// dim1_end: one-past-last column index in dest
77+
m.def("inplace_slice_copy(Tensor(a!) dest, Tensor src, int dim1_start, int dim1_end) -> ()");
78+
}
79+
80+
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
81+
{
82+
m.impl("inplace_slice_copy", TORCH_FN(tensorrt_llm::torch_ext::inplaceSliceCopy));
83+
}

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ def inplace_info():
121121
torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell.default: {
122122
1: "output"
123123
},
124+
torch.ops.trtllm.inplace_slice_copy.default: {
125+
1: "dest"
126+
}
124127
}
125128
if IS_CUDA_TILE_AVAILABLE:
126129
# cuda.tile availability depends on GPU capability thus runtime check.

tensorrt_llm/_torch/custom_ops/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import torch
2+
13
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE
24
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
35
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
@@ -10,6 +12,12 @@
1012
# modules.attention and must be imported from there. They are not re-exported here to
1113
# avoid circular imports: custom_ops must not depend on modules.attention.
1214

15+
16+
def inplace_slice_copy(dest: torch.Tensor, src: torch.Tensor, dim1_start: int,
17+
dim1_end: int) -> None:
18+
torch.ops.trtllm.inplace_slice_copy(dest, src, dim1_start, dim1_end)
19+
20+
1321
__all__ = [
1422
'IS_FLASHINFER_AVAILABLE',
1523
'_register_fake',
@@ -20,6 +28,7 @@
2028
'copy_to_userbuffers',
2129
'matmul_to_ub',
2230
'IS_CUTLASS_DSL_AVAILABLE',
31+
'inplace_slice_copy',
2332
]
2433

2534
if IS_FLASHINFER_AVAILABLE:

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ def _(scores, scores_with_bias, n_group, topk_group, topk,
204204
dtype=scores_with_bias.dtype), scores.new_empty(
205205
shape, dtype=torch.int32)
206206

207+
@torch.library.register_fake("trtllm::inplace_slice_copy")
208+
def _(dest, src, dim1_start, dim1_end):
209+
pass
210+
207211
@torch.library.register_fake("trtllm::indexer_topk_prefill")
208212
def _(logits, row_starts, row_ends, indices, index_topk):
209213
# In-place operation, no return value (void function)

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch import nn
66

7+
from tensorrt_llm._torch.custom_ops import inplace_slice_copy
78
from tensorrt_llm._utils import prefer_pinned
89
from tensorrt_llm.mapping import Mapping
910

@@ -457,13 +458,13 @@ def maybe_capture_hidden_states(
457458
layer_id: int,
458459
hidden_states: torch.Tensor,
459460
residual: Optional[torch.Tensor] = None) -> None:
461+
460462
for i, captured_layer_id in enumerate(self.layers_to_capture):
461463
if captured_layer_id == layer_id:
462-
num_tokens = hidden_states.shape[0]
463464
to_save = hidden_states + residual if residual is not None else hidden_states
464-
self.hidden_states[:num_tokens, i * self.hidden_size:(i + 1) *
465-
self.hidden_size].copy_(to_save,
466-
non_blocking=True)
465+
inplace_slice_copy(self.hidden_states, to_save,
466+
i * self.hidden_size,
467+
(i + 1) * self.hidden_size)
467468
break
468469

469470

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Unit tests for trtllm::inplace_slice_copy.
16+
17+
Verifies that the cudaMemcpy2DAsync-backed op produces the same result as a
18+
reference Python slice + Tensor.copy_, for the row-prefix / column-slice
19+
write pattern used in EAGLE3 hidden-state capture.
20+
"""
21+
22+
import pytest
23+
import torch
24+
25+
import tensorrt_llm # noqa: F401
26+
27+
28+
def _reference(dest_shape, src, dim1_start, dim1_end, dtype):
29+
dest = torch.zeros(dest_shape, dtype=dtype, device="cuda")
30+
num_tokens = src.shape[0]
31+
dest[:num_tokens, dim1_start:dim1_end].copy_(src)
32+
return dest
33+
34+
35+
def _run(dest_shape, src, dim1_start, dim1_end, dtype):
36+
dest = torch.zeros(dest_shape, dtype=dtype, device="cuda")
37+
torch.ops.trtllm.inplace_slice_copy(dest, src, dim1_start, dim1_end)
38+
return dest
39+
40+
41+
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
42+
def test_full_dest_full_width(dtype):
43+
"""num_tokens == dest.size(0) and slice == full dest width."""
44+
dest_shape = (16, 64)
45+
src = torch.randn(16, 64, dtype=dtype, device="cuda")
46+
out = _run(dest_shape, src, 0, 64, dtype)
47+
ref = _reference(dest_shape, src, 0, 64, dtype)
48+
torch.testing.assert_close(out, ref)
49+
50+
51+
def test_partial_rows():
52+
"""num_tokens < dest.size(0): trailing rows must stay zero."""
53+
dtype = torch.bfloat16
54+
dest_shape = (32, 64)
55+
src = torch.randn(8, 64, dtype=dtype, device="cuda")
56+
out = _run(dest_shape, src, 0, 64, dtype)
57+
ref = _reference(dest_shape, src, 0, 64, dtype)
58+
torch.testing.assert_close(out, ref)
59+
assert torch.all(out[8:] == 0)
60+
61+
62+
def test_column_slice_middle():
63+
"""Write to a middle column band; flanking columns must stay zero."""
64+
dtype = torch.bfloat16
65+
dest_shape = (16, 96)
66+
src = torch.randn(16, 32, dtype=dtype, device="cuda")
67+
out = _run(dest_shape, src, 32, 64, dtype)
68+
ref = _reference(dest_shape, src, 32, 64, dtype)
69+
torch.testing.assert_close(out, ref)
70+
assert torch.all(out[:, :32] == 0)
71+
assert torch.all(out[:, 64:] == 0)
72+
73+
74+
def test_layered_capture_pattern():
75+
"""Mimic EAGLE3 hidden-state capture: write each layer into its band."""
76+
dtype = torch.bfloat16
77+
num_tokens, hidden_size, num_layers = 12, 48, 3
78+
dest_shape = (24, hidden_size * num_layers)
79+
srcs = [
80+
torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda") for _ in range(num_layers)
81+
]
82+
83+
out = torch.zeros(dest_shape, dtype=dtype, device="cuda")
84+
for i, s in enumerate(srcs):
85+
torch.ops.trtllm.inplace_slice_copy(out, s, i * hidden_size, (i + 1) * hidden_size)
86+
87+
ref = torch.zeros(dest_shape, dtype=dtype, device="cuda")
88+
for i, s in enumerate(srcs):
89+
ref[:num_tokens, i * hidden_size : (i + 1) * hidden_size].copy_(s)
90+
91+
torch.testing.assert_close(out, ref)
92+
93+
94+
def test_empty_src_is_noop():
95+
"""num_tokens == 0 must not modify dest and must not raise."""
96+
dtype = torch.bfloat16
97+
dest_shape = (16, 64)
98+
dest = torch.full(dest_shape, 7, dtype=dtype, device="cuda")
99+
src = torch.empty(0, 32, dtype=dtype, device="cuda")
100+
torch.ops.trtllm.inplace_slice_copy(dest, src, 16, 48)
101+
assert torch.all(dest == 7)
102+
103+
104+
def test_dtype_mismatch_raises():
105+
dest = torch.zeros(8, 32, dtype=torch.bfloat16, device="cuda")
106+
src = torch.randn(8, 32, dtype=torch.float16, device="cuda")
107+
with pytest.raises(RuntimeError):
108+
torch.ops.trtllm.inplace_slice_copy(dest, src, 0, 32)
109+
110+
111+
def test_out_of_bounds_raises():
112+
dtype = torch.bfloat16
113+
dest = torch.zeros(8, 32, dtype=dtype, device="cuda")
114+
src = torch.randn(8, 8, dtype=dtype, device="cuda")
115+
with pytest.raises(RuntimeError):
116+
torch.ops.trtllm.inplace_slice_copy(dest, src, 28, 36)
117+
118+
119+
def test_negative_dim1_start_raises():
120+
"""A negative dim1_start would underflow the dest pointer."""
121+
dtype = torch.bfloat16
122+
dest = torch.zeros(8, 32, dtype=dtype, device="cuda")
123+
src = torch.randn(8, 8, dtype=dtype, device="cuda")
124+
with pytest.raises(RuntimeError):
125+
torch.ops.trtllm.inplace_slice_copy(dest, src, -8, 0)
126+
127+
128+
def test_device_mismatch_raises():
129+
"""dest and src on different CUDA devices must be rejected."""
130+
if torch.cuda.device_count() < 2:
131+
pytest.skip("requires >= 2 CUDA devices")
132+
dtype = torch.bfloat16
133+
dest = torch.zeros(8, 32, dtype=dtype, device="cuda:0")
134+
src = torch.randn(8, 32, dtype=dtype, device="cuda:1")
135+
with pytest.raises(RuntimeError):
136+
torch.ops.trtllm.inplace_slice_copy(dest, src, 0, 32)

0 commit comments

Comments
 (0)