|
| 1 | +#include "core/plugins/impl/scatter_add_plugin.h" |
| 2 | +#include "core/plugins/plugins.h" |
| 3 | +#include "core/util/prelude.h" |
| 4 | + |
| 5 | +#include <ATen/cuda/CUDAContext.h> |
| 6 | +#include <ATen/cuda/CUDAEvent.h> |
| 7 | +#include <c10/cuda/CUDAStream.h> |
| 8 | + |
| 9 | +namespace torch_tensorrt { |
| 10 | +namespace core { |
| 11 | +namespace plugins { |
| 12 | +namespace impl { |
| 13 | + |
| 14 | +ScatterAddPlugin::ScatterAddPlugin() = default; |
| 15 | + |
| 16 | +nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(nvinfer1::PluginCapabilityType type) noexcept { |
| 17 | + switch (type) { |
| 18 | + case nvinfer1::PluginCapabilityType::kCORE: |
| 19 | + return static_cast<nvinfer1::IPluginV3OneCore*>(this); |
| 20 | + case nvinfer1::PluginCapabilityType::kBUILD: |
| 21 | + return static_cast<nvinfer1::IPluginV3OneBuild*>(this); |
| 22 | + case nvinfer1::PluginCapabilityType::kRUNTIME: |
| 23 | + return static_cast<nvinfer1::IPluginV3OneRuntime*>(this); |
| 24 | + default: |
| 25 | + return nullptr; |
| 26 | + } |
| 27 | +} |
| 28 | + |
| 29 | +nvinfer1::IPluginV3* ScatterAddPlugin::clone() noexcept { |
| 30 | + return new ScatterAddPlugin(*this); |
| 31 | +} |
| 32 | + |
| 33 | +// --------------------------------------------------------------------------- |
| 34 | +// IPluginV3OneCore |
| 35 | +// --------------------------------------------------------------------------- |
| 36 | + |
| 37 | +const char* ScatterAddPlugin::getPluginName() const noexcept { |
| 38 | + return "ScatterAdd"; |
| 39 | +} |
| 40 | + |
| 41 | +const char* ScatterAddPlugin::getPluginVersion() const noexcept { |
| 42 | + return "1"; |
| 43 | +} |
| 44 | + |
| 45 | +const char* ScatterAddPlugin::getPluginNamespace() const noexcept { |
| 46 | + return "torch_tensorrt"; |
| 47 | +} |
| 48 | + |
| 49 | +// --------------------------------------------------------------------------- |
| 50 | +// IPluginV3OneBuild |
| 51 | +// --------------------------------------------------------------------------- |
| 52 | + |
| 53 | +int32_t ScatterAddPlugin::getNbOutputs() const noexcept { |
| 54 | + return 1; |
| 55 | +} |
| 56 | + |
| 57 | +int32_t ScatterAddPlugin::getOutputDataTypes( |
| 58 | + nvinfer1::DataType* outputTypes, |
| 59 | + int32_t nbOutputs, |
| 60 | + const nvinfer1::DataType* inputTypes, |
| 61 | + int32_t nbInputs) const noexcept { |
| 62 | + // Output has the same dtype as src (input 0). |
| 63 | + outputTypes[0] = inputTypes[0]; |
| 64 | + return 0; |
| 65 | +} |
| 66 | + |
| 67 | +int32_t ScatterAddPlugin::getOutputShapes( |
| 68 | + const nvinfer1::DimsExprs* inputs, |
| 69 | + int32_t nbInputs, |
| 70 | + const nvinfer1::DimsExprs* /*shapeInputs*/, |
| 71 | + int32_t /*nbShapeInputs*/, |
| 72 | + nvinfer1::DimsExprs* outputs, |
| 73 | + int32_t /*nbOutputs*/, |
| 74 | + nvinfer1::IExprBuilder& /*exprBuilder*/) noexcept { |
| 75 | + // Output shape == src shape (input 0). |
| 76 | + outputs[0] = inputs[0]; |
| 77 | + return 0; |
| 78 | +} |
| 79 | + |
| 80 | +bool ScatterAddPlugin::supportsFormatCombination( |
| 81 | + int32_t pos, |
| 82 | + const nvinfer1::DynamicPluginTensorDesc* inOut, |
| 83 | + int32_t nbInputs, |
| 84 | + int32_t nbOutputs) noexcept { |
| 85 | + const auto& desc = inOut[pos]; |
| 86 | + |
| 87 | + // All tensors must be row-major (linear) layout. |
| 88 | + if (desc.desc.format != nvinfer1::TensorFormat::kLINEAR) { |
| 89 | + return false; |
| 90 | + } |
| 91 | + |
| 92 | + // Positions 1 through nbInputs-2 are index tensors: int32 or int64. |
| 93 | + if (pos >= 1 && pos <= nbInputs - 2) { |
| 94 | + return desc.desc.type == nvinfer1::DataType::kINT32 || desc.desc.type == nvinfer1::DataType::kINT64; |
| 95 | + } |
| 96 | + |
| 97 | + // pos 0 (src), pos nbInputs-1 (values), pos nbInputs (output): |
| 98 | + // float32 / float16 / bfloat16, all sharing the same type. |
| 99 | + const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT || desc.desc.type == nvinfer1::DataType::kHALF || |
| 100 | + desc.desc.type == nvinfer1::DataType::kBF16; |
| 101 | + if (!float_type) { |
| 102 | + return false; |
| 103 | + } |
| 104 | + |
| 105 | + // src, values and output must have the same dtype. |
| 106 | + if (pos == 0) { |
| 107 | + return true; |
| 108 | + } |
| 109 | + return desc.desc.type == inOut[0].desc.type; |
| 110 | +} |
| 111 | + |
| 112 | +int32_t ScatterAddPlugin::configurePlugin( |
| 113 | + const nvinfer1::DynamicPluginTensorDesc* in, |
| 114 | + int32_t nbInputs, |
| 115 | + const nvinfer1::DynamicPluginTensorDesc* /*out*/, |
| 116 | + int32_t /*nbOutputs*/) noexcept { |
| 117 | + dtype_ = in[0].desc.type; |
| 118 | + n_indices_ = nbInputs - 2; // exclude src and values |
| 119 | + idx_dtypes_.resize(n_indices_); |
| 120 | + for (int i = 0; i < n_indices_; ++i) { |
| 121 | + idx_dtypes_[i] = in[1 + i].desc.type; |
| 122 | + } |
| 123 | + return 0; |
| 124 | +} |
| 125 | + |
| 126 | +size_t ScatterAddPlugin::getWorkspaceSize( |
| 127 | + const nvinfer1::DynamicPluginTensorDesc* /*inputs*/, |
| 128 | + int32_t /*nbInputs*/, |
| 129 | + const nvinfer1::DynamicPluginTensorDesc* /*outputs*/, |
| 130 | + int32_t /*nbOutputs*/) const noexcept { |
| 131 | + return 0; |
| 132 | +} |
| 133 | + |
| 134 | +// --------------------------------------------------------------------------- |
| 135 | +// IPluginV3OneRuntime |
| 136 | +// --------------------------------------------------------------------------- |
| 137 | + |
| 138 | +int32_t ScatterAddPlugin::onShapeChange( |
| 139 | + const nvinfer1::PluginTensorDesc* in, |
| 140 | + int32_t nbInputs, |
| 141 | + const nvinfer1::PluginTensorDesc* /*out*/, |
| 142 | + int32_t /*nbOutputs*/) noexcept { |
| 143 | + dtype_ = in[0].type; |
| 144 | + n_indices_ = nbInputs - 2; |
| 145 | + src_shape_ = util::toVec(in[0].dims); |
| 146 | + val_shape_ = util::toVec(in[nbInputs - 1].dims); |
| 147 | + idx_dtypes_.resize(n_indices_); |
| 148 | + idx_shapes_.resize(n_indices_); |
| 149 | + for (int i = 0; i < n_indices_; ++i) { |
| 150 | + idx_dtypes_[i] = in[1 + i].type; |
| 151 | + idx_shapes_[i] = util::toVec(in[1 + i].dims); |
| 152 | + } |
| 153 | + return 0; |
| 154 | +} |
| 155 | + |
| 156 | +int32_t ScatterAddPlugin::enqueue( |
| 157 | + const nvinfer1::PluginTensorDesc* inputDesc, |
| 158 | + const nvinfer1::PluginTensorDesc* outputDesc, |
| 159 | + const void* const* inputs, |
| 160 | + void* const* outputs, |
| 161 | + void* /*workspace*/, |
| 162 | + cudaStream_t stream) noexcept { |
| 163 | + const at::ScalarType float_dtype = util::TRTDataTypeToScalarType(dtype_); |
| 164 | + const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(float_dtype); |
| 165 | + |
| 166 | + at::Tensor src = at::from_blob(const_cast<void*>(inputs[0]), src_shape_, float_opts); |
| 167 | + at::Tensor val = at::from_blob(const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts); |
| 168 | + |
| 169 | + // Build the indices list — one entry per index tensor, all cast to int64 |
| 170 | + // as required by ATen's index_put kernel. |
| 171 | + c10::List<std::optional<at::Tensor>> indices; |
| 172 | + indices.reserve(n_indices_); |
| 173 | + for (int i = 0; i < n_indices_; ++i) { |
| 174 | + const at::ScalarType int_dtype = util::TRTDataTypeToScalarType(idx_dtypes_[i]); |
| 175 | + const auto int_opts = at::TensorOptions().device(at::kCUDA).dtype(int_dtype); |
| 176 | + at::Tensor idx = at::from_blob(const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts); |
| 177 | + indices.push_back(std::optional<at::Tensor>(idx.to(torch::kLong))); |
| 178 | + } |
| 179 | + |
| 180 | + // Use a separate PyTorch CUDA stream and synchronise with the TRT stream via |
| 181 | + // CUDA events — same pattern as interpolate_plugin.cpp. |
| 182 | + at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool(); |
| 183 | + at::cuda::CUDAStreamGuard torch_guard(torch_stream); |
| 184 | + |
| 185 | + cudaEvent_t start_event; |
| 186 | + cudaEventCreate(&start_event); |
| 187 | + cudaEventRecord(start_event, stream); |
| 188 | + cudaStreamWaitEvent(torch_stream.stream(), start_event, 0); |
| 189 | + |
| 190 | + // index_put with accumulate=true calls the atomicAdd-based CUDA kernel. |
| 191 | + at::Tensor result = at::index_put(src, indices, val, /*accumulate=*/true); |
| 192 | + |
| 193 | + at::Tensor out_t = at::from_blob(outputs[0], src_shape_, float_opts); |
| 194 | + out_t.copy_(result); |
| 195 | + |
| 196 | + cudaEvent_t done_event; |
| 197 | + cudaEventCreate(&done_event); |
| 198 | + cudaEventRecord(done_event, torch_stream.stream()); |
| 199 | + cudaStreamWaitEvent(stream, done_event, 0); |
| 200 | + |
| 201 | + cudaEventDestroy(start_event); |
| 202 | + cudaEventDestroy(done_event); |
| 203 | + |
| 204 | + return 0; |
| 205 | +} |
| 206 | + |
| 207 | +nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(nvinfer1::IPluginResourceContext* /*context*/) noexcept { |
| 208 | + return clone(); |
| 209 | +} |
| 210 | + |
| 211 | +const nvinfer1::PluginFieldCollection* ScatterAddPlugin::getFieldsToSerialize() noexcept { |
| 212 | + // No configuration attributes to serialize — shapes and dtype are captured |
| 213 | + // from the tensor descriptors at runtime. |
| 214 | + return &empty_fc_; |
| 215 | +} |
| 216 | + |
| 217 | +// --------------------------------------------------------------------------- |
| 218 | +// ScatterAddPluginCreator |
| 219 | +// --------------------------------------------------------------------------- |
| 220 | + |
| 221 | +ScatterAddPluginCreator::ScatterAddPluginCreator() = default; |
| 222 | + |
| 223 | +const char* ScatterAddPluginCreator::getPluginName() const noexcept { |
| 224 | + return "ScatterAdd"; |
| 225 | +} |
| 226 | + |
| 227 | +const char* ScatterAddPluginCreator::getPluginVersion() const noexcept { |
| 228 | + return "1"; |
| 229 | +} |
| 230 | + |
| 231 | +const char* ScatterAddPluginCreator::getPluginNamespace() const noexcept { |
| 232 | + return "torch_tensorrt"; |
| 233 | +} |
| 234 | + |
| 235 | +const nvinfer1::PluginFieldCollection* ScatterAddPluginCreator::getFieldNames() noexcept { |
| 236 | + return &fc_; |
| 237 | +} |
| 238 | + |
| 239 | +nvinfer1::IPluginV3* ScatterAddPluginCreator::createPlugin( |
| 240 | + const char* /*name*/, |
| 241 | + const nvinfer1::PluginFieldCollection* /*fc*/, |
| 242 | + nvinfer1::TensorRTPhase /*phase*/) noexcept { |
| 243 | + return new ScatterAddPlugin(); |
| 244 | +} |
| 245 | + |
| 246 | +// Register with the torch_tensorrt namespace. |
| 247 | +REGISTER_TORCHTRT_PLUGIN(ScatterAddPluginCreator); |
| 248 | + |
| 249 | +} // namespace impl |
| 250 | +} // namespace plugins |
| 251 | +} // namespace core |
| 252 | +} // namespace torch_tensorrt |
0 commit comments