Skip to content

Commit 949ed54

Browse files
authored
Add dynamic shape support to index_put (#4143)
1 parent 0168301 commit 949ed54

File tree

17 files changed

+1502
-198
lines changed

17 files changed

+1502
-198
lines changed

=2.12.0.dev

Whitespace-only changes.

core/plugins/BUILD

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
load("@rules_cc//cc:defs.bzl", "cc_library")
22
load("@rules_pkg//:pkg.bzl", "pkg_tar")
33
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")
4+
45
package(default_visibility = ["//visibility:public"])
56

67
config_setting(
@@ -61,20 +62,32 @@ config_setting(
6162
cc_library(
6263
name = "torch_tensorrt_plugins",
6364
srcs = select({
65+
":jetpack": [
66+
"impl/interpolate_plugin.cpp",
67+
"impl/normalize_plugin.cpp",
68+
"register_plugins.cpp",
69+
],
6470
":rtx_win": [],
6571
":rtx_x86_64": [],
6672
"//conditions:default": [
6773
"impl/interpolate_plugin.cpp",
6874
"impl/normalize_plugin.cpp",
75+
"impl/scatter_add_plugin.cpp",
6976
"register_plugins.cpp",
7077
],
7178
}),
7279
hdrs = select({
80+
":jetpack": [
81+
"impl/interpolate_plugin.h",
82+
"impl/normalize_plugin.h",
83+
"plugins.h",
84+
],
7385
":rtx_win": [],
7486
":rtx_x86_64": [],
7587
"//conditions:default": [
7688
"impl/interpolate_plugin.h",
7789
"impl/normalize_plugin.h",
90+
"impl/scatter_add_plugin.h",
7891
"plugins.h",
7992
],
8093
}),
@@ -132,15 +145,16 @@ filegroup(
132145
srcs = [
133146
"impl/interpolate_plugin.h",
134147
"impl/normalize_plugin.h",
148+
"impl/scatter_add_plugin.h",
135149
],
136150
visibility = ["//visibility:public"],
137151
)
138152

139153
pkg_files(
140154
name = "impl_include_pkg_files",
141155
srcs = [":impl_include_files"],
142-
visibility = ["//visibility:public"],
143156
prefix = "include/torch_tensorrt/core/plugins/impl",
157+
visibility = ["//visibility:public"],
144158
)
145159

146160
pkg_tar(

core/plugins/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_library(${lib_name} OBJECT)
44
target_sources(${lib_name}
55
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/impl/interpolate_plugin.cpp"
66
"${CMAKE_CURRENT_SOURCE_DIR}/impl/normalize_plugin.cpp"
7+
"${CMAKE_CURRENT_SOURCE_DIR}/impl/scatter_add_plugin.cpp"
78
"${CMAKE_CURRENT_SOURCE_DIR}/register_plugins.cpp"
89
PUBLIC $<TARGET_OBJECTS:core_util>
910
)
@@ -33,6 +34,7 @@ install(
3334
FILES
3435
"${CMAKE_CURRENT_SOURCE_DIR}/impl/interpolate_plugin.h"
3536
"${CMAKE_CURRENT_SOURCE_DIR}/impl/normalize_plugin.h"
37+
"${CMAKE_CURRENT_SOURCE_DIR}/impl/scatter_add_plugin.h"
3638
DESTINATION
3739
"${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/plugins/impl"
3840
)
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
#include "core/plugins/impl/scatter_add_plugin.h"
2+
#include "core/plugins/plugins.h"
3+
#include "core/util/prelude.h"
4+
5+
#include <ATen/cuda/CUDAContext.h>
6+
#include <ATen/cuda/CUDAEvent.h>
7+
#include <c10/cuda/CUDAStream.h>
8+
9+
namespace torch_tensorrt {
10+
namespace core {
11+
namespace plugins {
12+
namespace impl {
13+
14+
ScatterAddPlugin::ScatterAddPlugin() = default;
15+
16+
nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(nvinfer1::PluginCapabilityType type) noexcept {
17+
switch (type) {
18+
case nvinfer1::PluginCapabilityType::kCORE:
19+
return static_cast<nvinfer1::IPluginV3OneCore*>(this);
20+
case nvinfer1::PluginCapabilityType::kBUILD:
21+
return static_cast<nvinfer1::IPluginV3OneBuild*>(this);
22+
case nvinfer1::PluginCapabilityType::kRUNTIME:
23+
return static_cast<nvinfer1::IPluginV3OneRuntime*>(this);
24+
default:
25+
return nullptr;
26+
}
27+
}
28+
29+
nvinfer1::IPluginV3* ScatterAddPlugin::clone() noexcept {
30+
return new ScatterAddPlugin(*this);
31+
}
32+
33+
// ---------------------------------------------------------------------------
34+
// IPluginV3OneCore
35+
// ---------------------------------------------------------------------------
36+
37+
const char* ScatterAddPlugin::getPluginName() const noexcept {
38+
return "ScatterAdd";
39+
}
40+
41+
const char* ScatterAddPlugin::getPluginVersion() const noexcept {
42+
return "1";
43+
}
44+
45+
const char* ScatterAddPlugin::getPluginNamespace() const noexcept {
46+
return "torch_tensorrt";
47+
}
48+
49+
// ---------------------------------------------------------------------------
50+
// IPluginV3OneBuild
51+
// ---------------------------------------------------------------------------
52+
53+
int32_t ScatterAddPlugin::getNbOutputs() const noexcept {
54+
return 1;
55+
}
56+
57+
int32_t ScatterAddPlugin::getOutputDataTypes(
58+
nvinfer1::DataType* outputTypes,
59+
int32_t nbOutputs,
60+
const nvinfer1::DataType* inputTypes,
61+
int32_t nbInputs) const noexcept {
62+
// Output has the same dtype as src (input 0).
63+
outputTypes[0] = inputTypes[0];
64+
return 0;
65+
}
66+
67+
int32_t ScatterAddPlugin::getOutputShapes(
68+
const nvinfer1::DimsExprs* inputs,
69+
int32_t nbInputs,
70+
const nvinfer1::DimsExprs* /*shapeInputs*/,
71+
int32_t /*nbShapeInputs*/,
72+
nvinfer1::DimsExprs* outputs,
73+
int32_t /*nbOutputs*/,
74+
nvinfer1::IExprBuilder& /*exprBuilder*/) noexcept {
75+
// Output shape == src shape (input 0).
76+
outputs[0] = inputs[0];
77+
return 0;
78+
}
79+
80+
bool ScatterAddPlugin::supportsFormatCombination(
81+
int32_t pos,
82+
const nvinfer1::DynamicPluginTensorDesc* inOut,
83+
int32_t nbInputs,
84+
int32_t nbOutputs) noexcept {
85+
const auto& desc = inOut[pos];
86+
87+
// All tensors must be row-major (linear) layout.
88+
if (desc.desc.format != nvinfer1::TensorFormat::kLINEAR) {
89+
return false;
90+
}
91+
92+
// Positions 1 through nbInputs-2 are index tensors: int32 or int64.
93+
if (pos >= 1 && pos <= nbInputs - 2) {
94+
return desc.desc.type == nvinfer1::DataType::kINT32 || desc.desc.type == nvinfer1::DataType::kINT64;
95+
}
96+
97+
// pos 0 (src), pos nbInputs-1 (values), pos nbInputs (output):
98+
// float32 / float16 / bfloat16, all sharing the same type.
99+
const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT || desc.desc.type == nvinfer1::DataType::kHALF ||
100+
desc.desc.type == nvinfer1::DataType::kBF16;
101+
if (!float_type) {
102+
return false;
103+
}
104+
105+
// src, values and output must have the same dtype.
106+
if (pos == 0) {
107+
return true;
108+
}
109+
return desc.desc.type == inOut[0].desc.type;
110+
}
111+
112+
int32_t ScatterAddPlugin::configurePlugin(
113+
const nvinfer1::DynamicPluginTensorDesc* in,
114+
int32_t nbInputs,
115+
const nvinfer1::DynamicPluginTensorDesc* /*out*/,
116+
int32_t /*nbOutputs*/) noexcept {
117+
dtype_ = in[0].desc.type;
118+
n_indices_ = nbInputs - 2; // exclude src and values
119+
idx_dtypes_.resize(n_indices_);
120+
for (int i = 0; i < n_indices_; ++i) {
121+
idx_dtypes_[i] = in[1 + i].desc.type;
122+
}
123+
return 0;
124+
}
125+
126+
size_t ScatterAddPlugin::getWorkspaceSize(
127+
const nvinfer1::DynamicPluginTensorDesc* /*inputs*/,
128+
int32_t /*nbInputs*/,
129+
const nvinfer1::DynamicPluginTensorDesc* /*outputs*/,
130+
int32_t /*nbOutputs*/) const noexcept {
131+
return 0;
132+
}
133+
134+
// ---------------------------------------------------------------------------
135+
// IPluginV3OneRuntime
136+
// ---------------------------------------------------------------------------
137+
138+
int32_t ScatterAddPlugin::onShapeChange(
139+
const nvinfer1::PluginTensorDesc* in,
140+
int32_t nbInputs,
141+
const nvinfer1::PluginTensorDesc* /*out*/,
142+
int32_t /*nbOutputs*/) noexcept {
143+
dtype_ = in[0].type;
144+
n_indices_ = nbInputs - 2;
145+
src_shape_ = util::toVec(in[0].dims);
146+
val_shape_ = util::toVec(in[nbInputs - 1].dims);
147+
idx_dtypes_.resize(n_indices_);
148+
idx_shapes_.resize(n_indices_);
149+
for (int i = 0; i < n_indices_; ++i) {
150+
idx_dtypes_[i] = in[1 + i].type;
151+
idx_shapes_[i] = util::toVec(in[1 + i].dims);
152+
}
153+
return 0;
154+
}
155+
156+
int32_t ScatterAddPlugin::enqueue(
157+
const nvinfer1::PluginTensorDesc* inputDesc,
158+
const nvinfer1::PluginTensorDesc* outputDesc,
159+
const void* const* inputs,
160+
void* const* outputs,
161+
void* /*workspace*/,
162+
cudaStream_t stream) noexcept {
163+
const at::ScalarType float_dtype = util::TRTDataTypeToScalarType(dtype_);
164+
const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(float_dtype);
165+
166+
at::Tensor src = at::from_blob(const_cast<void*>(inputs[0]), src_shape_, float_opts);
167+
at::Tensor val = at::from_blob(const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
168+
169+
// Build the indices list — one entry per index tensor, all cast to int64
170+
// as required by ATen's index_put kernel.
171+
c10::List<std::optional<at::Tensor>> indices;
172+
indices.reserve(n_indices_);
173+
for (int i = 0; i < n_indices_; ++i) {
174+
const at::ScalarType int_dtype = util::TRTDataTypeToScalarType(idx_dtypes_[i]);
175+
const auto int_opts = at::TensorOptions().device(at::kCUDA).dtype(int_dtype);
176+
at::Tensor idx = at::from_blob(const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
177+
indices.push_back(std::optional<at::Tensor>(idx.to(torch::kLong)));
178+
}
179+
180+
// Use a separate PyTorch CUDA stream and synchronise with the TRT stream via
181+
// CUDA events — same pattern as interpolate_plugin.cpp.
182+
at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
183+
at::cuda::CUDAStreamGuard torch_guard(torch_stream);
184+
185+
cudaEvent_t start_event;
186+
cudaEventCreate(&start_event);
187+
cudaEventRecord(start_event, stream);
188+
cudaStreamWaitEvent(torch_stream.stream(), start_event, 0);
189+
190+
// index_put with accumulate=true calls the atomicAdd-based CUDA kernel.
191+
at::Tensor result = at::index_put(src, indices, val, /*accumulate=*/true);
192+
193+
at::Tensor out_t = at::from_blob(outputs[0], src_shape_, float_opts);
194+
out_t.copy_(result);
195+
196+
cudaEvent_t done_event;
197+
cudaEventCreate(&done_event);
198+
cudaEventRecord(done_event, torch_stream.stream());
199+
cudaStreamWaitEvent(stream, done_event, 0);
200+
201+
cudaEventDestroy(start_event);
202+
cudaEventDestroy(done_event);
203+
204+
return 0;
205+
}
206+
207+
nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(nvinfer1::IPluginResourceContext* /*context*/) noexcept {
208+
return clone();
209+
}
210+
211+
const nvinfer1::PluginFieldCollection* ScatterAddPlugin::getFieldsToSerialize() noexcept {
212+
// No configuration attributes to serialize — shapes and dtype are captured
213+
// from the tensor descriptors at runtime.
214+
return &empty_fc_;
215+
}
216+
217+
// ---------------------------------------------------------------------------
218+
// ScatterAddPluginCreator
219+
// ---------------------------------------------------------------------------
220+
221+
ScatterAddPluginCreator::ScatterAddPluginCreator() = default;
222+
223+
const char* ScatterAddPluginCreator::getPluginName() const noexcept {
224+
return "ScatterAdd";
225+
}
226+
227+
const char* ScatterAddPluginCreator::getPluginVersion() const noexcept {
228+
return "1";
229+
}
230+
231+
const char* ScatterAddPluginCreator::getPluginNamespace() const noexcept {
232+
return "torch_tensorrt";
233+
}
234+
235+
const nvinfer1::PluginFieldCollection* ScatterAddPluginCreator::getFieldNames() noexcept {
236+
return &fc_;
237+
}
238+
239+
nvinfer1::IPluginV3* ScatterAddPluginCreator::createPlugin(
240+
const char* /*name*/,
241+
const nvinfer1::PluginFieldCollection* /*fc*/,
242+
nvinfer1::TensorRTPhase /*phase*/) noexcept {
243+
return new ScatterAddPlugin();
244+
}
245+
246+
// Register with the torch_tensorrt namespace.
247+
REGISTER_TORCHTRT_PLUGIN(ScatterAddPluginCreator);
248+
249+
} // namespace impl
250+
} // namespace plugins
251+
} // namespace core
252+
} // namespace torch_tensorrt

0 commit comments

Comments
 (0)