Commit 6ef21f0
[CUDA Plugin] Refactoring Einsum (microsoft#27606)
## Description
Refactor the CUDA Einsum kernel to unify the CUDA Einsum implementation
into a single code path that works for both the standard CUDA EP build
and the plugin EP build.
### Motivation
Before this refactoring, the CUDA Einsum kernel needed two parallel class
hierarchies and pervasive `#ifdef BUILD_CUDA_EP_AS_PLUGIN` guards:
almost every code block (kernel registration,
`Compute`/`ComputeInternal`, `DeviceCompute`, preprocessor/processor
instantiation for each data type) was duplicated. This made the code
hard to read, hard to maintain, and error-prone, because the two paths
could silently diverge.
Additionally, `EinsumComputePreprocessor` and
`EinsumTypedComputeProcessor<T>` crossed the shared-library boundary via
~15 virtual methods in `ProviderHostCPU` (factory, Run,
SetDeviceHelpers, custom deleters for each of float/double/MLFloat16),
adding fragile coupling.
### Key Changes
#### 1. Unified CUDA Einsum class (`einsum.h` / `einsum.cc`)
- **Before**: Two `Einsum` class definitions — one inheriting
`CudaKernel` (plugin path) and one inheriting `onnxruntime::Einsum`
(non-plugin path), with a vtable stub `DeviceCompute` that
`assert(false)`s.
- **After**: Single `class Einsum final : public CudaKernel` with
`ComputeInternal`. No conditional compilation in CUDA Einsum files (zero
`BUILD_CUDA_EP_AS_PLUGIN` guards remain).
- CUDA `einsum.cc`: 208 → 101 lines (−51%). CUDA `einsum.h`: 69 → 37
lines (−46%).
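
The "After" shape can be illustrated with a minimal sketch. All names ending in `Sketch` are hypothetical stand-ins, not the real ONNX Runtime types; the only assumption carried over from the text is that a single `final` kernel class derives from one base and implements `ComputeInternal`, with no conditional compilation.

```cpp
#include <cassert>
#include <string>

// Hypothetical status type standing in for the real return type.
enum class StatusSketch { kOk, kFail };

// Hypothetical stand-in for the kernel base class: the public Compute()
// entry point forwards to the virtual ComputeInternal().
class KernelBaseSketch {
 public:
  virtual ~KernelBaseSketch() = default;
  StatusSketch Compute(const std::string& equation) const {
    return ComputeInternal(equation);
  }
  virtual StatusSketch ComputeInternal(const std::string& equation) const = 0;
};

// One class, one code path: no #ifdef selects a different base class,
// and no stub method exists only to satisfy a second hierarchy.
class EinsumSketch final : public KernelBaseSketch {
 public:
  StatusSketch ComputeInternal(const std::string& equation) const override {
    // Placeholder for the real work: reject an empty equation.
    return equation.empty() ? StatusSketch::kFail : StatusSketch::kOk;
  }
};
```

Because both builds would compile the same class, there is only one implementation to keep correct.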
#### 2. Decoupled `EinsumCudaAssets` from `CUDAExecutionProvider*`
- **Before**: `EinsumCudaAssets` held a raw `const
CUDAExecutionProvider*` pointer.
- **After**: Stores only the specific values needed (`cudaDeviceProp*`,
`cublasHandle_t`, `bool use_tf32_`), extracted by the caller using
`CudaKernel` accessors (`GetDeviceProp()`, `GetCublasHandle()`,
`cuda_ep_->UseTF32()`). This is better dependency inversion — the shared
Einsum utilities no longer couple to the concrete EP class.
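
A minimal sketch of this decoupling, assuming stub types so that it compiles without the CUDA toolkit. `DevicePropStub`, `BlasHandleStub`, `EinsumAssetsSketch`, and `KernelSketch` are hypothetical; only the idea of copying three values out through kernel accessors comes from the description above.

```cpp
#include <cassert>

// Hypothetical stand-ins for cudaDeviceProp and cublasHandle_t.
struct DevicePropStub { int major = 8; };
using BlasHandleStub = void*;

// Sketch of the "after" layout: only the values the Einsum utilities need,
// and no pointer back to the execution provider.
struct EinsumAssetsSketch {
  const DevicePropStub* device_prop;
  BlasHandleStub cublas_handle;
  bool use_tf32;
};

// Hypothetical kernel with accessors in the spirit of GetDeviceProp() and
// GetCublasHandle(); the caller extracts the values and fills the struct.
class KernelSketch {
 public:
  KernelSketch(const DevicePropStub* prop, BlasHandleStub handle, bool tf32)
      : prop_(prop), handle_(handle), tf32_(tf32) {}

  const DevicePropStub* GetDeviceProp() const { return prop_; }
  BlasHandleStub GetCublasHandle() const { return handle_; }
  bool UseTF32() const { return tf32_; }

  EinsumAssetsSketch MakeAssets() const {
    return EinsumAssetsSketch{GetDeviceProp(), GetCublasHandle(), UseTF32()};
  }

 private:
  const DevicePropStub* prop_;
  BlasHandleStub handle_;
  bool tf32_;
};
```

Code that receives `EinsumAssetsSketch` never needs the provider's class definition, which is the property the refactor is after.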
#### 3. Added `GetComputeStream` helper to `CudaKernel`
- New inline helper `CudaKernel::GetComputeStream(OpKernelContext*)` for
kernels that need the underlying ORT `Stream*` object (not just the raw
`cudaStream_t`).
#### 4. Inlined Einsum utility logic into headers
- Moved implementation from `.cc` to `.h` for
`EinsumComputePreprocessor` and `EinsumTypedComputeProcessor<T>`
(guarded by `#ifndef SHARED_PROVIDER` / `#else`), so the plugin build
can compile the logic directly instead of going through the provider
bridge.
- `einsum_compute_preprocessor.cc`: 479 → 8 lines.
`einsum_typed_compute_processor.cc`: 459 → 19 lines.
- Corresponding `.h` files grew to absorb the inlined implementations.
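
The header layout described above might look roughly like the following sketch. `SHARED_PROVIDER` is the macro named in the text; everything else (`PreprocessorSketch`, `CountInputs`, `BridgeCountInputsSketch`) is hypothetical, and what each branch really contains in the commit is an assumption here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

#ifndef SHARED_PROVIDER
// Branch 1: the full implementation lives inline in the header, so any
// translation unit that includes it can compile the logic directly.
class PreprocessorSketch {
 public:
  explicit PreprocessorSketch(std::string equation)
      : equation_(std::move(equation)) {}

  // Counts the comma-separated operands on the left of "->".
  // If there is no "->", the whole equation is treated as the left side.
  int CountInputs() const {
    const std::size_t arrow = equation_.find("->");
    const std::string lhs = equation_.substr(0, arrow);
    return static_cast<int>(std::count(lhs.begin(), lhs.end(), ',')) + 1;
  }

 private:
  std::string equation_;
};
#else
// Branch 2 (assumed shape): a thin definition that forwards across the
// provider bridge instead of carrying the logic itself.
int BridgeCountInputsSketch(const std::string& equation);  // hypothetical

class PreprocessorSketch {
 public:
  explicit PreprocessorSketch(std::string equation)
      : equation_(std::move(equation)) {}
  int CountInputs() const { return BridgeCountInputsSketch(equation_); }

 private:
  std::string equation_;
};
#endif
```

Both branches expose the same class name and interface, so callers compile unchanged whichever way the macro is set.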
#### 5. Simplified provider bridge surface for Einsum
- **Removed** ~15 virtual methods from `ProviderHostCPU`:
`EinsumComputePreprocessor__operator_delete`,
`EinsumComputePreprocessor__Create`, `EinsumComputePreprocessor__Run`,
`EinsumComputePreprocessor__SetDeviceHelpers`,
`EinsumTypedComputeProcessor__operator_delete` (×3 types),
`EinsumTypedComputeProcessor_*__Create` (×3),
`EinsumTypedComputeProcessor__SetDeviceHelpers` (×3),
`EinsumTypedComputeProcessor__Run` (×3).
- **Replaced** with 3 simpler `EinsumTypedComputeProcessor_*_Compute`
pass-through methods (one per data type: float, double, MLFloat16).
- Removed `EinsumComputePreprocessor` and
`EinsumTypedComputeProcessor<T>` forward declarations from
`provider_api.h`.
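
To illustrate why a single pass-through per type is a smaller surface than create/configure/run/delete, here is a hedged sketch. `HostSketch`, `HostImplSketch`, and `ProcessorSketch` are hypothetical, and the summation is only a placeholder for the real Einsum computation.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical processor; in this sketch it never crosses the boundary.
template <typename T>
class ProcessorSketch {
 public:
  explicit ProcessorSketch(const std::vector<T>& in) : in_(in) {}
  T Run() const { return std::accumulate(in_.begin(), in_.end(), T{}); }

 private:
  const std::vector<T>& in_;
};

// Hypothetical host interface: one virtual per element type, instead of
// separate Create / SetDeviceHelpers / Run / operator delete entries.
struct HostSketch {
  virtual ~HostSketch() = default;
  virtual float Compute_float(const std::vector<float>& in) = 0;
  virtual double Compute_double(const std::vector<double>& in) = 0;
};

// The host side constructs, runs, and destroys the processor inside one
// call, so the caller never owns an object allocated by the other library.
struct HostImplSketch final : HostSketch {
  float Compute_float(const std::vector<float>& in) override {
    return ProcessorSketch<float>(in).Run();
  }
  double Compute_double(const std::vector<double>& in) override {
    return ProcessorSketch<double>(in).Run();
  }
};
```

With this shape the custom deleters become unnecessary, because no processor object is ever handed across the boundary.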
#### 6. Added `CreateTensor` device helper
- New `DeviceHelpers::CreateTensor` callback allows the caller to
control tensor allocation. Passed through `Transpose`, `MatMul`, and
`ReduceSum` wrappers, completing the device-abstraction pattern
(previously `std::make_unique<Tensor>` was hardcoded).
- Also inlined the `EinsumOp::Transpose`, `EinsumOp::MatMul`, and
`EinsumOp::ReduceSum` wrapper functions into the header (from
`einsum_auxiliary_ops.cc`).
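
The allocation callback can be sketched as follows. `TensorSketch`, `CreateTensorFn`, `HostAllocSketch`, and `TransposeSketch` are hypothetical stand-ins; the actual `DeviceHelpers::CreateTensor` signature is not shown in this description, so the one below is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical 2-D row-major tensor.
struct TensorSketch {
  std::size_t rows;
  std::size_t cols;
  std::vector<float> data;
};

// Assumed callback shape: the caller decides how output tensors are created.
using CreateTensorFn =
    std::function<std::unique_ptr<TensorSketch>(std::size_t, std::size_t)>;

// One possible allocator a caller could pass in (plain host memory).
inline std::unique_ptr<TensorSketch> HostAllocSketch(std::size_t rows,
                                                     std::size_t cols) {
  auto t = std::make_unique<TensorSketch>();
  t->rows = rows;
  t->cols = cols;
  t->data.assign(rows * cols, 0.0f);
  return t;
}

// The wrapper asks the callback for its output instead of hardcoding
// std::make_unique, then fills it with the transposed values.
inline std::unique_ptr<TensorSketch> TransposeSketch(
    const TensorSketch& in, const CreateTensorFn& create_tensor) {
  auto out = create_tensor(in.cols, in.rows);
  for (std::size_t r = 0; r < in.rows; ++r) {
    for (std::size_t c = 0; c < in.cols; ++c) {
      out->data[c * in.rows + r] = in.data[r * in.cols + c];
    }
  }
  return out;
}
```

A different build could pass a callback that allocates through its own allocator without touching the wrapper.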
### Files Changed (16 files, +1370 / −1333)
| File | Change |
|---|---|
| `cpu/cpu_provider_shared.h` | Replace ~15 bridge virtuals with 3 `Compute` methods; add `EinsumOp::DeviceHelpers` type aliases |
| `cpu/cpu_provider_shared.cc` | Update implementations matching the new `ProviderHostCPU` interface |
| `cpu/math/einsum.cc` | Minor adjustment to CPU Einsum |
| `cpu/math/einsum_utils/einsum_auxiliary_ops.cc` | Move `Transpose`/`MatMul`/`ReduceSum` wrappers to header (inline) |
| `cpu/math/einsum_utils/einsum_auxiliary_ops.h` | Inline wrapper functions; add `CreateTensor` helper signature |
| `cpu/math/einsum_utils/einsum_compute_preprocessor.cc` | Shrink to stub (logic moved to header) |
| `cpu/math/einsum_utils/einsum_compute_preprocessor.h` | Inline full implementation; add `#ifndef SHARED_PROVIDER` / `#else` dual definition |
| `cpu/math/einsum_utils/einsum_typed_compute_processor.cc` | Shrink to stub (logic moved to header) |
| `cpu/math/einsum_utils/einsum_typed_compute_processor.h` | Inline full implementation with `SHARED_PROVIDER` dual definition; add `FinalizeOutput` and `PairwiseOperandProcess` |
| `cuda/cuda_kernel.h` | Add `GetComputeStream(OpKernelContext*)` helper |
| `cuda/math/einsum.cc` | Unified single-path implementation (remove all `#ifdef` guards) |
| `cuda/math/einsum.h` | Single `Einsum : CudaKernel` class (remove dual hierarchy) |
| `cuda/math/einsum_utils/einsum_auxiliary_ops.cc` | Update to use new `EinsumCudaAssets` fields |
| `cuda/math/einsum_utils/einsum_auxiliary_ops.h` | Decouple `EinsumCudaAssets` from EP pointer; add `CreateTensor` |
| `shared_library/provider_api.h` | Remove Einsum forward declarations and header include |
| `shared_library/provider_bridge_provider.cc` | Replace factory-based bridge with direct `Run()` → `Compute` forwarding |
### Motivation and Context
This is a preparatory refactoring for the CUDA plugin EP work. By making
CUDA Einsum self-contained (not inheriting from the CPU
`onnxruntime::Einsum` base class), it can be compiled as part of a
plugin shared library without depending on the CPU EP's class vtable
layout across the ABI boundary.

1 parent a0f6f78 commit 6ef21f0
16 files changed
Lines changed: 1379 additions & 1335 deletions