Skip to content

Commit 49912be

Browse files
committed
Add cuFFTDx-backed FFT2 JIT support
Generate JIT classes and LTO IR for single-block C2C fft2/ifft2 fusions, including shared-memory tiling through cuFFTDx 1D passes. Teach the JIT launcher about grouped 2D blocks and vectorized EPT indexing so FFT2 operators can return multiple columns per thread. Document the supported FFT2 JIT shape/type limits and add forward/inverse FFT2 JIT fusion coverage.
1 parent 454f3e0 commit 49912be

9 files changed

Lines changed: 615 additions & 29 deletions

File tree

docs_input/api/dft/fft/fft2d.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ fft2
55

66
Perform a 2D FFT. Batching is supported for any tensor with a rank higher than 2.
77

8+
FFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square
9+
transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled.
10+
Unsupported 2D FFT sizes and real-valued 2D FFTs use the existing cuFFT execution path.
11+
812
.. versionadded:: 0.6.0
913

1014
.. doxygenfunction:: fft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD)
@@ -31,4 +35,4 @@ Examples
3135
:language: cpp
3236
:start-after: example-begin fft2-2
3337
:end-before: example-end fft2-2
34-
:dedent:
38+
:dedent:

docs_input/api/dft/fft/ifft2.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ ifft2
55

66
Perform a 2D inverse FFT. Batching is supported for any tensor with a rank higher than 2.
77

8+
IFFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square
9+
transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled.
10+
Unsupported 2D IFFT sizes and real-valued inverse 2D FFTs use the existing cuFFT execution path.
11+
812
.. versionadded:: 0.6.0
913

1014
.. doxygenfunction:: ifft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD)
@@ -31,4 +35,4 @@ Examples
3135
:language: cpp
3236
:start-after: example-begin ifft2-2
3337
:end-before: example-end ifft2-2
34-
:dedent:
38+
:dedent:

docs_input/basics/fusion.rst

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ CUDA JIT Kernel Fusion
5959

6060
CUDA JIT kernel fusion is considered an experimental feature. There may be bugs that don't occur with JIT disabled, and new features are being added over time.
6161

62-
MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled
63-
for all standard MatX element-wise operators and FFT and GEMM operations via MathDx. To enable fusion with MathDx,
62+
MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled
63+
for all standard MatX element-wise operators and FFT and GEMM operations via MathDx. cuFFTDx supports 1D FFT fusion and
64+
single-block complex-to-complex 2D ``fft2``/``ifft2`` fusion for supported power-of-two square transforms. To enable fusion with MathDx,
6465
the following option must be enabled: ``-DMATX_EN_MATHDX=ON``. Once enabled, the ``CUDAJITExecutor`` can be used to perform JIT compilation
6566
in supported situations. If the expression cannot be JIT compiled, the ``CUDAJITExecutor`` may throw an error.
6667

@@ -118,12 +119,10 @@ MathDx Compatibility
118119
- Enabled via ``-DMATX_EN_MATHDX=ON`` for GEMM fusion paths.
119120
* - cuFFTDx
120121
- Yes
121-
- Enabled via ``-DMATX_EN_MATHDX=ON`` for FFT fusion paths.
122+
- Enabled via ``-DMATX_EN_MATHDX=ON`` for 1D FFT fusion paths and supported single-block 2D C2C FFT fusion paths.
122123
* - cuSolverDx
123124
- No
124125
- Not supported yet by MatX CUDA JIT fusion.
125126
* - cuRandDx
126127
- No
127128
- Not supported yet by MatX CUDA JIT fusion.
128-
129-

include/matx/core/get_grid_dims.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,17 @@ inline bool get_grid_dims_block_reduce(dim3 &blocks, dim3 &threads, const cuda::
338338
template <int RANK>
339339
inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads,
340340
const cuda::std::array<index_t, RANK> &sizes,
341-
int block_dim) {
341+
int block_dim,
342+
int groups_per_block = 1) {
342343
// Threads are set to block_dim in x and groups_per_block in y; z is 1
343344
// All threads cooperate via flattened thread ID in the kernel
344345
threads.x = block_dim;
345-
threads.y = 1;
346+
threads.y = groups_per_block;
346347
threads.z = 1;
348+
349+
if (static_cast<int64_t>(threads.x) * static_cast<int64_t>(threads.y) > 1024) {
350+
MATX_THROW(matxInvalidParameter, "Block2D launch exceeds CUDA maximum threads per block");
351+
}
347352

348353
// Grid covers batch dimensions only (dims 0 to RANK-3)
349354
blocks.x = 1;
@@ -372,7 +377,8 @@ inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads,
372377
}
373378
}
374379

375-
MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z);
380+
MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}",
381+
blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block);
376382

377383
// No stride needed for now - could be extended for very large batches
378384
return false;

include/matx/executors/jit_cuda.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,16 +289,20 @@ namespace matx
289289
if (block_dim_range[0] == detail::capability_attributes<detail::OperatorCapability::BLOCK_DIM>::invalid) {
290290
MATX_THROW(matxInvalidParameter, "No valid JIT block dimension satisfies the fused operator requirements");
291291
}
292+
auto group_range = detail::get_operator_capability<detail::OperatorCapability::GROUPS_PER_BLOCK>(op);
292293
block_size = block_dim_range[0];
293-
stride = detail::get_grid_dims_block_2d<RANK>(blocks, threads, sizes, block_size);
294+
groups_per_block = group_range[0];
295+
if (groups_per_block == detail::capability_attributes<detail::OperatorCapability::GROUPS_PER_BLOCK>::invalid) {
296+
MATX_THROW(matxInvalidParameter, "No valid JIT groups-per-block value satisfies the fused operator requirements");
297+
}
298+
stride = detail::get_grid_dims_block_2d<RANK>(blocks, threads, sizes, block_size, groups_per_block);
294299

295-
// EPT is 1 for 2D block operators - the operator handles elements internally
296-
best_ept = detail::ElementsPerThread::ONE;
300+
// Block-level operators can still return vectorized output lanes.
301+
best_ept = jit_ept_bounds[1];
297302
shm_size = detail::get_operator_capability<detail::OperatorCapability::DYN_SHM_SIZE>(op);
298-
groups_per_block = 1;
299303

300-
MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}",
301-
static_cast<int>(best_ept), shm_size, block_size);
304+
MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}, Groups per block {}",
305+
static_cast<int>(best_ept), shm_size, block_size, groups_per_block);
302306
} else if constexpr (is_dynamic_rank_op_v<Op>) {
303307
// Dynamic tensor expressions: pre-compiled kernels don't exist for this Op type,
304308
// so we cannot query register pressure. Use conservative defaults.

include/matx/executors/jit_kernel.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,10 @@ namespace matx {
395395
template <class Op>\n\
396396
__global__ void matxOpT2KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1) {\n\
397397
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
398-
matx::index_t idx = tid % size1;\n\
399-
matx::index_t idy = tid / size1;\n\
398+
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
399+
matx::index_t size1_vectors = (size1 + ept - 1) / ept;\n\
400+
matx::index_t idx = tid % size1_vectors;\n\
401+
matx::index_t idy = tid / size1_vectors;\n\
400402
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
401403
(*op).template operator()<CurrentCapabilities>(idy, idx);\n\
402404
} else {\n\
@@ -407,8 +409,10 @@ namespace matx {
407409
template <class Op>\n\
408410
__global__ void matxOpT3KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\
409411
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
410-
matx::index_t idx = tid % size2;\n\
411-
matx::index_t idy = tid / size2;\n\
412+
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
413+
matx::index_t size2_vectors = (size2 + ept - 1) / ept;\n\
414+
matx::index_t idx = tid % size2_vectors;\n\
415+
matx::index_t idy = tid / size2_vectors;\n\
412416
matx::index_t idz = blockIdx.x;\n\
413417
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
414418
(*op).template operator()<CurrentCapabilities>(idz, idy, idx);\n\
@@ -420,8 +424,10 @@ namespace matx {
420424
template <class Op>\n\
421425
__global__ void matxOpT4KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\
422426
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
423-
matx::index_t idx = tid % size3;\n\
424-
matx::index_t idy = tid / size3;\n\
427+
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
428+
matx::index_t size3_vectors = (size3 + ept - 1) / ept;\n\
429+
matx::index_t idx = tid % size3_vectors;\n\
430+
matx::index_t idy = tid / size3_vectors;\n\
425431
matx::index_t idz = blockIdx.x;\n\
426432
matx::index_t idw = blockIdx.y;\n\
427433
if constexpr (cuda::std::is_pointer_v<Op>) {\n\

0 commit comments

Comments
 (0)