Commit 6ef21f0

[CUDA Plugin] Refactoring Einsum (microsoft#27606)
## Description

Refactor the CUDA Einsum kernel to unify the CUDA Einsum implementation into a single code path that works for both the standard CUDA EP build and the plugin EP build.

### Motivation

Without this refactoring, the CUDA Einsum kernel needed two parallel class hierarchies and pervasive `#ifdef BUILD_CUDA_EP_AS_PLUGIN` guards: almost every code block (kernel registration, `Compute`/`ComputeInternal`, `DeviceCompute`, preprocessor/processor instantiation for each data type) was duplicated. This made the code hard to read, hard to maintain, and error-prone because the two paths could diverge.

Additionally, `EinsumComputePreprocessor` and `EinsumTypedComputeProcessor<T>` crossed the shared-library boundary via ~15 virtual methods in `ProviderHostCPU` (factory, Run, SetDeviceHelpers, custom deleters for each of float/double/MLFloat16), adding fragile coupling.

### Key Changes

#### 1. Unified CUDA Einsum class (`einsum.h` / `einsum.cc`)

- **Before**: Two `Einsum` class definitions, one inheriting `CudaKernel` (plugin path) and one inheriting `onnxruntime::Einsum` (non-plugin path), with a vtable stub `DeviceCompute` that `assert(false)`s.
- **After**: Single `class Einsum final : public CudaKernel` with `ComputeInternal`. No conditional compilation in CUDA Einsum files (zero `BUILD_CUDA_EP_AS_PLUGIN` guards remain).
- CUDA `einsum.cc`: 208 → 101 lines (−51%). CUDA `einsum.h`: 69 → 37 lines (−46%).

#### 2. Decoupled `EinsumCudaAssets` from `CUDAExecutionProvider*`

- **Before**: `EinsumCudaAssets` held a raw `const CUDAExecutionProvider*` pointer.
- **After**: Stores only the specific values needed (`cudaDeviceProp*`, `cublasHandle_t`, `bool use_tf32_`), extracted by the caller using `CudaKernel` accessors (`GetDeviceProp()`, `GetCublasHandle()`, `cuda_ep_->UseTF32()`). This is better dependency inversion: the shared Einsum utilities no longer couple to the concrete EP class.

#### 3. Added `GetComputeStream` helper to `CudaKernel`

- New inline helper `CudaKernel::GetComputeStream(OpKernelContext*)` for kernels that need the underlying ORT `Stream*` object (not just the raw `cudaStream_t`).

#### 4. Inlined Einsum utility logic into headers

- Moved implementation from `.cc` to `.h` for `EinsumComputePreprocessor` and `EinsumTypedComputeProcessor<T>` (guarded by `#ifndef SHARED_PROVIDER` / `#else`), so the plugin build can compile the logic directly instead of going through the provider bridge.
- `einsum_compute_preprocessor.cc`: 479 → 8 lines. `einsum_typed_compute_processor.cc`: 459 → 19 lines.
- Corresponding `.h` files grew to absorb the inlined implementations.

#### 5. Simplified provider bridge surface for Einsum

- **Removed** ~15 virtual methods from `ProviderHostCPU`: `EinsumComputePreprocessor__operator_delete`, `EinsumComputePreprocessor__Create`, `EinsumComputePreprocessor__Run`, `EinsumComputePreprocessor__SetDeviceHelpers`, `EinsumTypedComputeProcessor__operator_delete` (×3 types), `EinsumTypedComputeProcessor_*__Create` (×3), `EinsumTypedComputeProcessor__SetDeviceHelpers` (×3), `EinsumTypedComputeProcessor__Run` (×3).
- **Replaced** with 3 simpler `EinsumTypedComputeProcessor_*_Compute` pass-through methods (one per data type: float, double, MLFloat16).
- Removed `EinsumComputePreprocessor` and `EinsumTypedComputeProcessor<T>` forward declarations from `provider_api.h`.

#### 6. Added `CreateTensor` device helper

- New `DeviceHelpers::CreateTensor` callback allows the caller to control tensor allocation. Passed through `Transpose`, `MatMul`, and `ReduceSum` wrappers, completing the device-abstraction pattern (previously `std::make_unique<Tensor>` was hardcoded).
- Also inlined the `EinsumOp::Transpose`, `EinsumOp::MatMul`, and `EinsumOp::ReduceSum` wrapper functions into the header (from `einsum_auxiliary_ops.cc`).
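The `CreateTensor` callback idea from Key Change 6 can be illustrated with a minimal, self-contained sketch. Everything here is a hypothetical stand-in: `ToyTensor`, `toy_device_helpers::CreateTensor`, `Transpose2D`, and `HostCreateTensor` are invented for illustration and are not the ONNX Runtime types or signatures. The sketch only shows the general shape of the pattern, where a wrapper asks the caller for its output tensor instead of allocating it directly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor; NOT onnxruntime::Tensor.
struct ToyTensor {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

namespace toy_device_helpers {
// Illustrative callback type (an assumption, not the real signature):
// the caller decides how an output tensor of a given shape is allocated.
using CreateTensor =
    std::function<std::unique_ptr<ToyTensor>(const std::vector<int64_t>& shape)>;
}  // namespace toy_device_helpers

// A 2-D transpose wrapper that obtains its output through the callback
// rather than hardcoding std::make_unique<ToyTensor> internally.
inline std::unique_ptr<ToyTensor> Transpose2D(
    const ToyTensor& input, const toy_device_helpers::CreateTensor& create_tensor) {
  assert(input.shape.size() == 2);
  const int64_t rows = input.shape[0];
  const int64_t cols = input.shape[1];
  const std::vector<int64_t> out_shape{cols, rows};
  std::unique_ptr<ToyTensor> output = create_tensor(out_shape);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      output->data[static_cast<std::size_t>(c * rows + r)] =
          input.data[static_cast<std::size_t>(r * cols + c)];
    }
  }
  return output;
}

// One possible callback: plain host allocation, zero-filled and sized from the shape.
inline std::unique_ptr<ToyTensor> HostCreateTensor(const std::vector<int64_t>& shape) {
  int64_t count = 1;
  for (int64_t d : shape) count *= d;
  auto tensor = std::make_unique<ToyTensor>();
  tensor->shape = shape;
  tensor->data.assign(static_cast<std::size_t>(count), 0.0f);
  return tensor;
}
```

A caller would pass `HostCreateTensor` (or a lambda wrapping a different allocator) as the second argument to `Transpose2D`; the wrapper itself never needs to know which allocation strategy is in use.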
### Files Changed (16 files, +1370 / −1333)

| File | Change |
|---|---|
| `cpu/cpu_provider_shared.h` | Replace ~15 bridge virtuals with 3 `Compute` methods; add `EinsumOp::DeviceHelpers` type aliases |
| `cpu/cpu_provider_shared.cc` | Update implementations matching the new `ProviderHostCPU` interface |
| `cpu/math/einsum.cc` | Minor adjustment to CPU Einsum |
| `cpu/math/einsum_utils/einsum_auxiliary_ops.cc` | Move `Transpose`/`MatMul`/`ReduceSum` wrappers to header (inline) |
| `cpu/math/einsum_utils/einsum_auxiliary_ops.h` | Inline wrapper functions; add `CreateTensor` helper signature |
| `cpu/math/einsum_utils/einsum_compute_preprocessor.cc` | Shrink to stub (logic moved to header) |
| `cpu/math/einsum_utils/einsum_compute_preprocessor.h` | Inline full implementation; add `#ifndef SHARED_PROVIDER` / `#else` dual definition |
| `cpu/math/einsum_utils/einsum_typed_compute_processor.cc` | Shrink to stub (logic moved to header) |
| `cpu/math/einsum_utils/einsum_typed_compute_processor.h` | Inline full implementation with `SHARED_PROVIDER` dual definition; add `FinalizeOutput` and `PairwiseOperandProcess` |
| `cuda/cuda_kernel.h` | Add `GetComputeStream(OpKernelContext*)` helper |
| `cuda/math/einsum.cc` | Unified single-path implementation (remove all `#ifdef` guards) |
| `cuda/math/einsum.h` | Single `Einsum : CudaKernel` class (remove dual hierarchy) |
| `cuda/math/einsum_utils/einsum_auxiliary_ops.cc` | Update to use new `EinsumCudaAssets` fields |
| `cuda/math/einsum_utils/einsum_auxiliary_ops.h` | Decouple `EinsumCudaAssets` from EP pointer; add `CreateTensor` |
| `shared_library/provider_api.h` | Remove Einsum forward declarations and header include |
| `shared_library/provider_bridge_provider.cc` | Replace factory-based bridge with direct `Run()` → `Compute` forwarding |

### Motivation and Context

This is a preparatory refactoring for the CUDA plugin EP work.
By making CUDA Einsum self-contained (not inheriting from the CPU `onnxruntime::Einsum` base class), it can be compiled as part of a plugin shared library without depending on the CPU EP's class vtable layout across the ABI boundary.
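
The decoupling described under Key Change 2 can be sketched without CUDA headers, assuming simplified stand-ins. `ToyDeviceProp`, `ToyBlasHandle`, `ToyExecutionProvider`, `ToyKernel`, `ToyEinsumAssets`, and `ShouldUseTF32` are all hypothetical names invented for this illustration; they do not mirror the real `EinsumCudaAssets` layout or the `CudaKernel` API. The point is only that the shared utility reads plain values and never sees the provider class.

```cpp
#include <cassert>

// Hypothetical stand-ins for cudaDeviceProp and cublasHandle_t so the
// sketch builds without the CUDA toolkit.
struct ToyDeviceProp {
  int major = 0;
};
using ToyBlasHandle = void*;

// Stand-in for the concrete execution provider class.
class ToyExecutionProvider {
 public:
  explicit ToyExecutionProvider(bool use_tf32) : use_tf32_(use_tf32) {}
  bool UseTF32() const { return use_tf32_; }

 private:
  bool use_tf32_;
};

// The assets carry only the values the helpers read, not a provider pointer.
struct ToyEinsumAssets {
  const ToyDeviceProp* device_prop;
  ToyBlasHandle blas_handle;
  bool use_tf32;
};

// Stand-in for a kernel base class that exposes accessors. The kernel is the
// only place that touches the provider when building the assets.
class ToyKernel {
 public:
  ToyKernel(const ToyExecutionProvider* ep, const ToyDeviceProp* prop, ToyBlasHandle handle)
      : ep_(ep), prop_(prop), handle_(handle) {}

  const ToyDeviceProp& GetDeviceProp() const { return *prop_; }
  ToyBlasHandle GetBlasHandle() const { return handle_; }

  ToyEinsumAssets MakeAssets() const {
    return ToyEinsumAssets{&GetDeviceProp(), GetBlasHandle(), ep_->UseTF32()};
  }

 private:
  const ToyExecutionProvider* ep_;
  const ToyDeviceProp* prop_;
  ToyBlasHandle handle_;
};

// A shared utility that depends only on the assets struct. The ">= 8" check
// is an illustrative policy for this sketch, not taken from the commit.
inline bool ShouldUseTF32(const ToyEinsumAssets& assets) {
  assert(assets.device_prop != nullptr);
  return assets.use_tf32 && assets.device_prop->major >= 8;
}
```

Because `ShouldUseTF32` takes only `ToyEinsumAssets`, it can be compiled into a separate library without including the provider's header, which is the property the refactoring is after.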
1 parent a0f6f78 commit 6ef21f0

16 files changed

Lines changed: 1379 additions & 1335 deletions

onnxruntime/core/providers/cpu/cpu_provider_shared.cc

Lines changed: 53 additions & 23 deletions
@@ -12,6 +12,8 @@
 #include "core/providers/cpu/controlflow/scan.h"
 #include "core/providers/cpu/math/cumsum.h"
 #include "core/providers/cpu/math/einsum.h"
+#include "core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h"
+#include "core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h"
 #include "core/providers/cpu/object_detection/non_max_suppression.h"
 #include "core/providers/cpu/object_detection/roialign.h"
 #include "core/providers/cpu/tensor/concatbase.h"
@@ -169,35 +171,63 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
 
   Status Einsum__Compute(const Einsum* p, OpKernelContext* context) override { return p->Einsum::Compute(context); }
 
-  // EinsumComputePreprocessor (wrapped)
-  void EinsumComputePreprocessor__operator_delete(EinsumComputePreprocessor* p) override { delete p; }
-  std::unique_ptr<EinsumComputePreprocessor> EinsumComputePreprocessor__Create(EinsumEquationPreprocessor& equation_preprocessor,
-                                                                               const std::vector<const Tensor*>& inputs,
-                                                                               AllocatorPtr allocator,
-                                                                               void* einsum_cuda_assets) override { return std::make_unique<EinsumComputePreprocessor>(equation_preprocessor, inputs, allocator, einsum_cuda_assets); }
-
-  Status EinsumComputePreprocessor__Run(EinsumComputePreprocessor* p) override { return p->Run(); }
-  void EinsumComputePreprocessor__SetDeviceHelpers(EinsumComputePreprocessor* p, const EinsumOp::DeviceHelpers::Diagonal& diagonal_func, const EinsumOp::DeviceHelpers::Transpose& transpose_func) override { return p->SetDeviceHelpers(diagonal_func, transpose_func); }
-
-  // EinsumTypedComputeProcessor (wrapped)
-  void EinsumTypedComputeProcessor__operator_delete(EinsumTypedComputeProcessor<float>* p) override { delete p; }
-  void EinsumTypedComputeProcessor__operator_delete(EinsumTypedComputeProcessor<double>* p) override { delete p; }
-  void EinsumTypedComputeProcessor__operator_delete(EinsumTypedComputeProcessor<MLFloat16>* p) override { delete p; }
-  std::unique_ptr<EinsumTypedComputeProcessor<float>> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, const void* mlas_backend_config, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<float>>(context, allocator, tp, mlas_backend_config, einsum_compute_preprocessor, einsum_cuda_assets); }
-  std::unique_ptr<EinsumTypedComputeProcessor<double>> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, const void* mlas_backend_config, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<double>>(context, allocator, tp, mlas_backend_config, einsum_compute_preprocessor, einsum_cuda_assets); }
-  std::unique_ptr<EinsumTypedComputeProcessor<MLFloat16>> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, const void* mlas_backend_config, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<MLFloat16>>(context, allocator, tp, mlas_backend_config, einsum_compute_preprocessor, einsum_cuda_assets); }
-  void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
-  void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
-  void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zero_buffer_func); }
-  Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) override { return p->Run(); }
-  Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) override { return p->Run(); }
-  Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) override { return p->Run(); }
   void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
                                               gsl::span<const int64_t> input_dims,
                                               InlinedVector<float>& scales) const override {
     p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
   }
 
+  Status EinsumTypedComputeProcessor_float_Compute(
+      OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp,
+      const void* mlas_backend_config, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets,
+      const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
+      const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func,
+      const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func,
+      const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
+      const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func,
+      const EinsumOp::DeviceHelpers::CreateTensor& device_create_tensor_func) override {
+    EinsumTypedComputeProcessor<float> einsum_compute_processor(
+        context, allocator, tp, mlas_backend_config, einsum_compute_preprocessor, einsum_cuda_assets);
+    einsum_compute_processor.SetDeviceHelpers(
+        device_transpose_func, device_matmul_func, device_reduce_sum_func,
+        device_data_copy_func, device_zero_buffer_func, device_create_tensor_func);
+    return einsum_compute_processor.Run();
+  }
+
+  Status EinsumTypedComputeProcessor_double_Compute(
+      OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp,
+      const void* mlas_backend_config, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets,
+      const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
+      const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func,
+      const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func,
+      const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
+      const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func,
+      const EinsumOp::DeviceHelpers::CreateTensor& device_create_tensor_func) override {
+    EinsumTypedComputeProcessor<double> einsum_compute_processor(
+        context, allocator, tp, mlas_backend_config, einsum_compute_preprocessor, einsum_cuda_assets);
+    einsum_compute_processor.SetDeviceHelpers(
+        device_transpose_func, device_matmul_func, device_reduce_sum_func,
+        device_data_copy_func, device_zero_buffer_func, device_create_tensor_func);
+    return einsum_compute_processor.Run();
+  }
+
+  Status EinsumTypedComputeProcessor_MLFloat16_Compute(
+      OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp,
+      const void* mlas_backend_config, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets,
+      const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
+      const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func,
+      const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func,
+      const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
+      const EinsumOp::DeviceHelpers::ZeroBuffer& device_zero_buffer_func,
+      const EinsumOp::DeviceHelpers::CreateTensor& device_create_tensor_func) override {
+    EinsumTypedComputeProcessor<MLFloat16> einsum_compute_processor(
+        context, allocator, tp, mlas_backend_config, einsum_compute_preprocessor, einsum_cuda_assets);
+    einsum_compute_processor.SetDeviceHelpers(
+        device_transpose_func, device_matmul_func, device_reduce_sum_func,
+        device_data_copy_func, device_zero_buffer_func, device_create_tensor_func);
+    return einsum_compute_processor.Run();
+  }
+
 #ifndef DISABLE_CONTRIB_OPS
   Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) override {
     return contrib::embed_layer_norm::CheckInputs(context, quantizedVersion);
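
The three added `*_Compute` methods in the hunk above share one shape: construct the processor on the stack, set its device helpers, and run it. The sketch below illustrates that shape with hypothetical stand-ins (`ToyProcessor`, `ComputeImpl`, `ToyHost`, `ToyHostImpl` are invented here and are not part of the commit or of ONNX Runtime). It also shows one way, assumed for illustration only, that the repeated body could be written once as a function template while the bridge keeps one named virtual per data type, since virtual member functions cannot be templates.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical stand-in for a typed compute processor.
template <typename T>
class ToyProcessor {
 public:
  using Helper = std::function<T(T)>;

  explicit ToyProcessor(T input) : input_(input) {}

  void SetDeviceHelpers(Helper helper) { helper_ = std::move(helper); }

  // Returns true on success, loosely mimicking a Status-returning Run().
  bool Run(T& output) const {
    if (!helper_) return false;  // no helper configured
    output = helper_(input_);
    return true;
  }

 private:
  T input_;
  Helper helper_;
};

// Shared body: the processor lives on the stack, so no factory, no custom
// deleter, and no separate Run call has to cross the library boundary.
template <typename T>
bool ComputeImpl(T input, const typename ToyProcessor<T>::Helper& helper, T& output) {
  ToyProcessor<T> processor(input);
  processor.SetDeviceHelpers(helper);
  return processor.Run(output);
}

// The bridge interface keeps one named, non-template virtual entry per type.
struct ToyHost {
  virtual ~ToyHost() = default;
  virtual bool Compute_float(float in, const ToyProcessor<float>::Helper& h, float& out) = 0;
  virtual bool Compute_double(double in, const ToyProcessor<double>::Helper& h, double& out) = 0;
};

struct ToyHostImpl final : ToyHost {
  bool Compute_float(float in, const ToyProcessor<float>::Helper& h, float& out) override {
    return ComputeImpl<float>(in, h, out);
  }
  bool Compute_double(double in, const ToyProcessor<double>::Helper& h, double& out) override {
    return ComputeImpl<double>(in, h, out);
  }
};
```

With this layout a caller makes a single virtual call per Einsum invocation, compared with the separate create, configure, run, and delete calls that the removed methods required.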
