Skip to content

Commit 36d180a

Browse files
committed
Add cuFFTDx-backed FFT2 JIT support
Generate JIT classes and LTO IR for single-block C2C fft2/ifft2 fusions, including shared-memory tiling through cuFFTDx 1D passes. Teach the JIT launcher about grouped 2D blocks and vectorized EPT indexing so FFT2 operators can return multiple columns per thread. Document the supported FFT2 JIT shape/type limits and add forward/inverse FFT2 JIT fusion coverage.
1 parent 580b8cc commit 36d180a

9 files changed

Lines changed: 620 additions & 30 deletions

File tree

docs_input/api/dft/fft/fft2d.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ fft2
55

66
Perform a 2D FFT. Batching is supported for any tensor with a rank higher than 2.
77

8+
FFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square
9+
transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled.
10+
Unsupported 2D FFT sizes and real-valued 2D FFTs use the existing cuFFT execution path.
11+
812
.. versionadded:: 0.6.0
913

1014
.. doxygenfunction:: fft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD)
@@ -31,4 +35,4 @@ Examples
3135
:language: cpp
3236
:start-after: example-begin fft2-2
3337
:end-before: example-end fft2-2
34-
:dedent:
38+
:dedent:

docs_input/api/dft/fft/ifft2.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ ifft2
55

66
Perform a 2D inverse FFT. Batching is supported for any tensor with a rank higher than 2.
77

8+
IFFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square
9+
transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled.
10+
Unsupported 2D IFFT sizes and real-valued inverse 2D FFTs use the existing cuFFT execution path.
11+
812
.. versionadded:: 0.6.0
913

1014
.. doxygenfunction:: ifft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD)
@@ -31,4 +35,4 @@ Examples
3135
:language: cpp
3236
:start-after: example-begin ifft2-2
3337
:end-before: example-end ifft2-2
34-
:dedent:
38+
:dedent:

docs_input/basics/fusion.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ CUDA JIT Kernel Fusion
6161

6262
MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled
6363
for all standard MatX element-wise operators, FFT operations via cuFFTDx, GEMM operations via cuBLASDx, and selected
64-
solver operations via cuSolverDx. To enable fusion with MathDx, the following option must be enabled:
64+
solver operations via cuSolverDx. cuFFTDx supports 1D FFT fusion and single-block complex-to-complex 2D ``fft2``/``ifft2``
65+
fusion for supported power-of-two square transforms. To enable fusion with MathDx, the following option must be enabled:
6566
``-DMATX_EN_MATHDX=ON``. MathDx support also enables the NVRTC-based JIT support used by ``CUDAJITExecutor``.
6667
Once enabled, the ``CUDAJITExecutor`` can be used to perform JIT compilation in supported situations. If the expression
6768
cannot be JIT compiled, the ``CUDAJITExecutor`` may throw an error.
@@ -129,9 +130,9 @@ MathDx Compatibility
129130
cuSolverDx ``inv`` when their block-dimension ranges intersect.
130131
* - cuFFTDx
131132
- Yes
132-
- Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible FFT fusion paths. cuFFTDx has stricter launch-shape
133-
requirements than the generic element-wise JIT path, and does not generally compose with cuBLASDx/cuSolverDx
134-
operations that require a different block/grid model.
133+
- Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible 1D FFT fusion paths and supported single-block 2D C2C FFT
134+
fusion paths. cuFFTDx has stricter launch-shape requirements than the generic element-wise JIT path, and does not
135+
generally compose with cuBLASDx/cuSolverDx operations that require a different block/grid model.
135136
* - cuSolverDx
136137
- Partial
137138
- Enabled via ``-DMATX_EN_MATHDX=ON`` for ``chol``, ``inv``, and selected projection outputs from

include/matx/core/get_grid_dims.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,17 @@ inline bool get_grid_dims_block_reduce(dim3 &blocks, dim3 &threads, const cuda::
338338
template <int RANK>
339339
inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads,
340340
const cuda::std::array<index_t, RANK> &sizes,
341-
int block_dim) {
341+
int block_dim,
342+
int groups_per_block = 1) {
342343
// Threads are set to block_dim in x, y and z are 1
343344
// All threads cooperate via flattened thread ID in the kernel
344345
threads.x = block_dim;
345-
threads.y = 1;
346+
threads.y = groups_per_block;
346347
threads.z = 1;
348+
349+
if (static_cast<int64_t>(threads.x) * static_cast<int64_t>(threads.y) > 1024) {
350+
MATX_THROW(matxInvalidParameter, "Block2D launch exceeds CUDA maximum threads per block");
351+
}
347352

348353
// Grid covers batch dimensions only (dims 0 to RANK-3)
349354
blocks.x = 1;
@@ -372,7 +377,8 @@ inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads,
372377
}
373378
}
374379

375-
MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z);
380+
MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}",
381+
blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block);
376382

377383
// No stride needed for now - could be extended for very large batches
378384
return false;
@@ -427,11 +433,12 @@ template <int RANK>
427433
inline bool get_grid_dims_block_pass_through(dim3 &blocks, dim3 &threads,
428434
const cuda::std::array<index_t, RANK> &sizes,
429435
int block_dim,
430-
int inner_rank) {
436+
int inner_rank,
437+
int groups_per_block = 1) {
431438
if (inner_rank == 1) {
432439
return get_grid_dims_block_1d<RANK>(blocks, threads, sizes, block_dim);
433440
}
434-
return get_grid_dims_block_2d<RANK>(blocks, threads, sizes, block_dim);
441+
return get_grid_dims_block_2d<RANK>(blocks, threads, sizes, block_dim, groups_per_block);
435442
}
436443
} // end namespace detail
437444
} // end namespace matx

include/matx/executors/jit_cuda.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -510,15 +510,20 @@ namespace matx
510510
// For pass-through operators (e.g., MathDx), block dimensions are constrained by the operator.
511511
auto block_dim_range = detail::get_operator_capability<detail::OperatorCapability::BLOCK_DIM>(op);
512512
block_size = detail::SelectJITPassThroughBlockDim(block_dim_range);
513-
stride = detail::get_grid_dims_block_pass_through<RANK>(blocks, threads, sizes, block_size, pass_through_inner_rank);
513+
auto group_range = detail::get_operator_capability<detail::OperatorCapability::GROUPS_PER_BLOCK>(op);
514+
groups_per_block = group_range[0];
515+
if (groups_per_block == detail::capability_attributes<detail::OperatorCapability::GROUPS_PER_BLOCK>::invalid) {
516+
MATX_THROW(matxInvalidParameter, "No valid JIT groups-per-block value satisfies the fused operator requirements");
517+
}
518+
stride = detail::get_grid_dims_block_pass_through<RANK>(
519+
blocks, threads, sizes, block_size, pass_through_inner_rank, groups_per_block);
514520

515-
// EPT is 1 for 2D block operators - the operator handles elements internally
516-
best_ept = detail::ElementsPerThread::ONE;
521+
// Block-level operators can still return vectorized output lanes.
522+
best_ept = jit_ept_bounds[1];
517523
shm_size = detail::get_operator_capability<detail::OperatorCapability::DYN_SHM_SIZE>(op);
518-
groups_per_block = 1;
519524

520-
MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}",
521-
static_cast<int>(best_ept), shm_size, block_size);
525+
MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}, Groups per block {}",
526+
static_cast<int>(best_ept), shm_size, block_size, groups_per_block);
522527
} else if constexpr (is_dynamic_rank_op_v<Op>) {
523528
// Dynamic tensor expressions: pre-compiled kernels don't exist for this Op type,
524529
// so we cannot query register pressure. Use conservative defaults.

include/matx/executors/jit_kernel.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,10 @@ namespace matx {
445445
template <class Op>\n\
446446
__global__ void matxOpT2KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1) {\n\
447447
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
448-
matx::index_t idx = tid % size1;\n\
449-
matx::index_t idy = tid / size1;\n\
448+
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
449+
matx::index_t size1_vectors = (size1 + ept - 1) / ept;\n\
450+
matx::index_t idx = tid % size1_vectors;\n\
451+
matx::index_t idy = tid / size1_vectors;\n\
450452
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
451453
(*op).template operator()<CurrentCapabilities>(idy, idx);\n\
452454
} else {\n\
@@ -457,8 +459,10 @@ namespace matx {
457459
template <class Op>\n\
458460
__global__ void matxOpT3KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\
459461
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
460-
matx::index_t idx = tid % size2;\n\
461-
matx::index_t idy = tid / size2;\n\
462+
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
463+
matx::index_t size2_vectors = (size2 + ept - 1) / ept;\n\
464+
matx::index_t idx = tid % size2_vectors;\n\
465+
matx::index_t idy = tid / size2_vectors;\n\
462466
matx::index_t idz = blockIdx.x;\n\
463467
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
464468
(*op).template operator()<CurrentCapabilities>(idz, idy, idx);\n\
@@ -470,8 +474,10 @@ namespace matx {
470474
template <class Op>\n\
471475
__global__ void matxOpT4KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\
472476
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
473-
matx::index_t idx = tid % size3;\n\
474-
matx::index_t idy = tid / size3;\n\
477+
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
478+
matx::index_t size3_vectors = (size3 + ept - 1) / ept;\n\
479+
matx::index_t idx = tid % size3_vectors;\n\
480+
matx::index_t idy = tid / size3_vectors;\n\
475481
matx::index_t idz = blockIdx.x;\n\
476482
matx::index_t idw = blockIdx.y;\n\
477483
if constexpr (cuda::std::is_pointer_v<Op>) {\n\

0 commit comments

Comments
 (0)