Skip to content

Commit c6e3a52

Browse files
committed
feat: move the scatter-add accumulate case to a plugin that calls the libtorch kernel
1 parent e8e622e commit c6e3a52

7 files changed

Lines changed: 546 additions & 197 deletions

File tree

core/plugins/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ cc_library(
6666
"//conditions:default": [
6767
"impl/interpolate_plugin.cpp",
6868
"impl/normalize_plugin.cpp",
69+
"impl/scatter_add_plugin.cpp",
6970
"register_plugins.cpp",
7071
],
7172
}),
@@ -75,6 +76,7 @@ cc_library(
7576
"//conditions:default": [
7677
"impl/interpolate_plugin.h",
7778
"impl/normalize_plugin.h",
79+
"impl/scatter_add_plugin.h",
7880
"plugins.h",
7981
],
8082
}),
@@ -132,6 +134,7 @@ filegroup(
132134
srcs = [
133135
"impl/interpolate_plugin.h",
134136
"impl/normalize_plugin.h",
137+
"impl/scatter_add_plugin.h",
135138
],
136139
visibility = ["//visibility:public"],
137140
)

core/plugins/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_library(${lib_name} OBJECT)
44
target_sources(${lib_name}
55
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/impl/interpolate_plugin.cpp"
66
"${CMAKE_CURRENT_SOURCE_DIR}/impl/normalize_plugin.cpp"
7+
"${CMAKE_CURRENT_SOURCE_DIR}/impl/scatter_add_plugin.cpp"
78
"${CMAKE_CURRENT_SOURCE_DIR}/register_plugins.cpp"
89
PUBLIC $<TARGET_OBJECTS:core_util>
910
)
@@ -33,6 +34,7 @@ install(
3334
FILES
3435
"${CMAKE_CURRENT_SOURCE_DIR}/impl/interpolate_plugin.h"
3536
"${CMAKE_CURRENT_SOURCE_DIR}/impl/normalize_plugin.h"
37+
"${CMAKE_CURRENT_SOURCE_DIR}/impl/scatter_add_plugin.h"
3638
DESTINATION
3739
"${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/plugins/impl"
3840
)
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
#include "core/plugins/impl/scatter_add_plugin.h"
2+
#include "core/plugins/plugins.h"
3+
#include "core/util/prelude.h"
4+
5+
#include <ATen/cuda/CUDAContext.h>
6+
#include <ATen/cuda/CUDAEvent.h>
7+
#include <c10/cuda/CUDAStream.h>
8+
9+
namespace torch_tensorrt {
10+
namespace core {
11+
namespace plugins {
12+
namespace impl {
13+
14+
// ---------------------------------------------------------------------------
15+
// Helpers
16+
// ---------------------------------------------------------------------------
17+
18+
19+
// ---------------------------------------------------------------------------
20+
// ScatterAddPlugin — construction
21+
// ---------------------------------------------------------------------------
22+
23+
// All runtime state (dtype, index count, shapes) is captured later in
// configurePlugin()/onShapeChange(), so default construction is sufficient.
ScatterAddPlugin::ScatterAddPlugin() = default;
24+
25+
// ---------------------------------------------------------------------------
26+
// IPluginV3
27+
// ---------------------------------------------------------------------------
28+
29+
nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(
    nvinfer1::PluginCapabilityType type) noexcept {
  // Route each capability query to the corresponding interface this class
  // implements; unknown capability types are not supported.
  if (type == nvinfer1::PluginCapabilityType::kCORE) {
    return static_cast<nvinfer1::IPluginV3OneCore*>(this);
  }
  if (type == nvinfer1::PluginCapabilityType::kBUILD) {
    return static_cast<nvinfer1::IPluginV3OneBuild*>(this);
  }
  if (type == nvinfer1::PluginCapabilityType::kRUNTIME) {
    return static_cast<nvinfer1::IPluginV3OneRuntime*>(this);
  }
  return nullptr;
}
42+
43+
nvinfer1::IPluginV3* ScatterAddPlugin::clone() noexcept {
  // Copy-construct a fresh instance; TensorRT takes ownership of the result.
  auto* copy = new ScatterAddPlugin(*this);
  return copy;
}
46+
47+
// ---------------------------------------------------------------------------
48+
// IPluginV3OneCore
49+
// ---------------------------------------------------------------------------
50+
51+
const char* ScatterAddPlugin::getPluginName() const noexcept {
  // Name used by the registry lookup; must match the creator's name.
  static constexpr const char* kPluginName = "ScatterAdd";
  return kPluginName;
}
54+
55+
const char* ScatterAddPlugin::getPluginVersion() const noexcept {
  // Version string paired with the plugin name for registry lookup.
  static constexpr const char* kPluginVersion = "1";
  return kPluginVersion;
}
58+
59+
const char* ScatterAddPlugin::getPluginNamespace() const noexcept {
  // All torch_tensorrt plugins live in this registry namespace.
  static constexpr const char* kPluginNamespace = "torch_tensorrt";
  return kPluginNamespace;
}
62+
63+
// ---------------------------------------------------------------------------
64+
// IPluginV3OneBuild
65+
// ---------------------------------------------------------------------------
66+
67+
int32_t ScatterAddPlugin::getNbOutputs() const noexcept {
  // Single output: the scatter-add result tensor.
  return 1;
}
70+
71+
int32_t ScatterAddPlugin::getOutputDataTypes(
    nvinfer1::DataType* outputTypes,
    int32_t nbOutputs,
    const nvinfer1::DataType* inputTypes,
    int32_t nbInputs) const noexcept {
  // Defensive guard: this plugin declares exactly one output and needs at
  // least the src input to derive its type. Previously the counts were
  // ignored and outputTypes[0]/inputTypes[0] were written/read blindly.
  if (nbOutputs != 1 || nbInputs < 1) {
    return -1;
  }
  // Output has the same dtype as src (input 0).
  outputTypes[0] = inputTypes[0];
  return 0;
}
80+
81+
int32_t ScatterAddPlugin::getOutputShapes(
82+
const nvinfer1::DimsExprs* inputs,
83+
int32_t nbInputs,
84+
const nvinfer1::DimsExprs* /*shapeInputs*/,
85+
int32_t /*nbShapeInputs*/,
86+
nvinfer1::DimsExprs* outputs,
87+
int32_t /*nbOutputs*/,
88+
nvinfer1::IExprBuilder& /*exprBuilder*/) noexcept {
89+
// Output shape == src shape (input 0).
90+
outputs[0] = inputs[0];
91+
return 0;
92+
}
93+
94+
bool ScatterAddPlugin::supportsFormatCombination(
95+
int32_t pos,
96+
const nvinfer1::DynamicPluginTensorDesc* inOut,
97+
int32_t nbInputs,
98+
int32_t nbOutputs) noexcept {
99+
const auto& desc = inOut[pos];
100+
101+
// All tensors must be row-major (linear) layout.
102+
if (desc.desc.format != nvinfer1::TensorFormat::kLINEAR) {
103+
return false;
104+
}
105+
106+
// Positions 1 through nbInputs-2 are index tensors: int32 or int64.
107+
if (pos >= 1 && pos <= nbInputs - 2) {
108+
return desc.desc.type == nvinfer1::DataType::kINT32 ||
109+
desc.desc.type == nvinfer1::DataType::kINT64;
110+
}
111+
112+
// pos 0 (src), pos nbInputs-1 (values), pos nbInputs (output):
113+
// float32 / float16 / bfloat16, all sharing the same type.
114+
const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT ||
115+
desc.desc.type == nvinfer1::DataType::kHALF ||
116+
desc.desc.type == nvinfer1::DataType::kBF16;
117+
if (!float_type) {
118+
return false;
119+
}
120+
121+
// src, values and output must have the same dtype.
122+
if (pos == 0) {
123+
return true;
124+
}
125+
return desc.desc.type == inOut[0].desc.type;
126+
}
127+
128+
int32_t ScatterAddPlugin::configurePlugin(
129+
const nvinfer1::DynamicPluginTensorDesc* in,
130+
int32_t nbInputs,
131+
const nvinfer1::DynamicPluginTensorDesc* /*out*/,
132+
int32_t /*nbOutputs*/) noexcept {
133+
dtype_ = in[0].desc.type;
134+
n_indices_ = nbInputs - 2; // exclude src and values
135+
idx_dtypes_.resize(n_indices_);
136+
for (int i = 0; i < n_indices_; ++i) {
137+
idx_dtypes_[i] = in[1 + i].desc.type;
138+
}
139+
return 0;
140+
}
141+
142+
size_t ScatterAddPlugin::getWorkspaceSize(
143+
const nvinfer1::DynamicPluginTensorDesc* /*inputs*/,
144+
int32_t /*nbInputs*/,
145+
const nvinfer1::DynamicPluginTensorDesc* /*outputs*/,
146+
int32_t /*nbOutputs*/) const noexcept {
147+
return 0;
148+
}
149+
150+
// ---------------------------------------------------------------------------
151+
// IPluginV3OneRuntime
152+
// ---------------------------------------------------------------------------
153+
154+
int32_t ScatterAddPlugin::onShapeChange(
155+
const nvinfer1::PluginTensorDesc* in,
156+
int32_t nbInputs,
157+
const nvinfer1::PluginTensorDesc* /*out*/,
158+
int32_t /*nbOutputs*/) noexcept {
159+
dtype_ = in[0].type;
160+
n_indices_ = nbInputs - 2;
161+
src_shape_ = util::toVec(in[0].dims);
162+
val_shape_ = util::toVec(in[nbInputs - 1].dims);
163+
idx_dtypes_.resize(n_indices_);
164+
idx_shapes_.resize(n_indices_);
165+
for (int i = 0; i < n_indices_; ++i) {
166+
idx_dtypes_[i] = in[1 + i].type;
167+
idx_shapes_[i] = util::toVec(in[1 + i].dims);
168+
}
169+
return 0;
170+
}
171+
172+
int32_t ScatterAddPlugin::enqueue(
173+
const nvinfer1::PluginTensorDesc* inputDesc,
174+
const nvinfer1::PluginTensorDesc* outputDesc,
175+
const void* const* inputs,
176+
void* const* outputs,
177+
void* /*workspace*/,
178+
cudaStream_t stream) noexcept {
179+
const at::ScalarType float_dtype = util::TRTDataTypeToScalarType(dtype_);
180+
const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(float_dtype);
181+
182+
at::Tensor src = at::from_blob(const_cast<void*>(inputs[0]), src_shape_, float_opts);
183+
at::Tensor val = at::from_blob(
184+
const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
185+
186+
// Build the indices list — one entry per index tensor, all cast to int64
187+
// as required by ATen's index_put kernel.
188+
c10::List<std::optional<at::Tensor>> indices;
189+
indices.reserve(n_indices_);
190+
for (int i = 0; i < n_indices_; ++i) {
191+
const at::ScalarType int_dtype = util::TRTDataTypeToScalarType(idx_dtypes_[i]);
192+
const auto int_opts = at::TensorOptions().device(at::kCUDA).dtype(int_dtype);
193+
at::Tensor idx = at::from_blob(
194+
const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
195+
indices.push_back(std::optional<at::Tensor>(idx.to(torch::kLong)));
196+
}
197+
198+
// Use a separate PyTorch CUDA stream and synchronise with the TRT stream via
199+
// CUDA events — same pattern as interpolate_plugin.cpp.
200+
at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
201+
at::cuda::CUDAStreamGuard torch_guard(torch_stream);
202+
203+
cudaEvent_t start_event;
204+
cudaEventCreate(&start_event);
205+
cudaEventRecord(start_event, stream);
206+
cudaStreamWaitEvent(torch_stream.stream(), start_event, 0);
207+
208+
// index_put with accumulate=true calls the atomicAdd-based CUDA kernel.
209+
at::Tensor result = at::index_put(src, indices, val, /*accumulate=*/true);
210+
211+
at::Tensor out_t = at::from_blob(outputs[0], src_shape_, float_opts);
212+
out_t.copy_(result);
213+
214+
cudaEvent_t done_event;
215+
cudaEventCreate(&done_event);
216+
cudaEventRecord(done_event, torch_stream.stream());
217+
cudaStreamWaitEvent(stream, done_event, 0);
218+
219+
cudaEventDestroy(start_event);
220+
cudaEventDestroy(done_event);
221+
222+
return 0;
223+
}
224+
225+
nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(
    nvinfer1::IPluginResourceContext* /*context*/) noexcept {
  // No per-context resources to acquire; a plain clone is sufficient.
  return this->clone();
}
229+
230+
const nvinfer1::PluginFieldCollection* ScatterAddPlugin::getFieldsToSerialize() noexcept {
  // Nothing to persist: dtype and shapes are re-captured from the tensor
  // descriptors at runtime, so an empty field collection is returned.
  return &empty_fc_;
}
235+
236+
// ---------------------------------------------------------------------------
237+
// ScatterAddPluginCreator
238+
// ---------------------------------------------------------------------------
239+
240+
// The creator carries no creation-time attributes; default construction
// leaves its (empty) field collection member default-initialized.
ScatterAddPluginCreator::ScatterAddPluginCreator() = default;
241+
242+
const char* ScatterAddPluginCreator::getPluginName() const noexcept {
  // Must match ScatterAddPlugin::getPluginName() for registry lookup.
  static constexpr const char* kPluginName = "ScatterAdd";
  return kPluginName;
}
245+
246+
const char* ScatterAddPluginCreator::getPluginVersion() const noexcept {
  // Must match ScatterAddPlugin::getPluginVersion().
  static constexpr const char* kPluginVersion = "1";
  return kPluginVersion;
}
249+
250+
const char* ScatterAddPluginCreator::getPluginNamespace() const noexcept {
  // Same registry namespace as the plugin itself.
  static constexpr const char* kPluginNamespace = "torch_tensorrt";
  return kPluginNamespace;
}
253+
254+
const nvinfer1::PluginFieldCollection* ScatterAddPluginCreator::getFieldNames() noexcept {
  // The plugin takes no creation-time attributes, so fc_ is empty.
  return &fc_;
}
257+
258+
nvinfer1::IPluginV3* ScatterAddPluginCreator::createPlugin(
    const char* /*name*/,
    const nvinfer1::PluginFieldCollection* /*fc*/,
    nvinfer1::TensorRTPhase /*phase*/) noexcept {
  // Stateless construction: every attribute is recovered later from tensor
  // descriptors, so the field collection and build phase are ignored.
  auto* plugin = new ScatterAddPlugin();
  return plugin;
}
264+
265+
// Register the creator with the torch_tensorrt plugin registry (see
// plugins.h) so "ScatterAdd"/"1" can be resolved at engine build/run time.
REGISTER_TORCHTRT_PLUGIN(ScatterAddPluginCreator);
267+
268+
} // namespace impl
269+
} // namespace plugins
270+
} // namespace core
271+
} // namespace torch_tensorrt

0 commit comments

Comments
 (0)