Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs_input/api/dft/fft/fft2d.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ fft2

Perform a 2D FFT. Batching is supported for any tensor with a rank higher than 2.

FFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square
transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled.
Unsupported 2D FFT sizes and real-valued 2D FFTs use the existing cuFFT execution path.

.. versionadded:: 0.6.0

.. doxygenfunction:: fft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD)
Expand All @@ -31,4 +35,4 @@ Examples
:language: cpp
:start-after: example-begin fft2-2
:end-before: example-end fft2-2
:dedent:
:dedent:
6 changes: 5 additions & 1 deletion docs_input/api/dft/fft/ifft2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ ifft2

Perform a 2D inverse FFT. Batching is supported for any tensor with a rank higher than 2.

IFFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square
transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled.
Unsupported 2D IFFT sizes and real-valued inverse 2D FFTs use the existing cuFFT execution path.

.. versionadded:: 0.6.0

.. doxygenfunction:: ifft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD)
Expand All @@ -31,4 +35,4 @@ Examples
:language: cpp
:start-after: example-begin ifft2-2
:end-before: example-end ifft2-2
:dedent:
:dedent:
9 changes: 5 additions & 4 deletions docs_input/basics/fusion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ CUDA JIT Kernel Fusion

MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled
for all standard MatX element-wise operators, FFT operations via cuFFTDx, GEMM operations via cuBLASDx, and selected
solver operations via cuSolverDx. To enable fusion with MathDx, the following option must be enabled:
solver operations via cuSolverDx. cuFFTDx supports 1D FFT fusion and single-block complex-to-complex 2D ``fft2``/``ifft2``
fusion for supported power-of-two square transforms. To enable fusion with MathDx, the following option must be enabled:
``-DMATX_EN_MATHDX=ON``. MathDx support also enables the NVRTC-based JIT support used by ``CUDAJITExecutor``.
Once enabled, the ``CUDAJITExecutor`` can be used to perform JIT compilation in supported situations. If the expression
cannot be JIT compiled, the ``CUDAJITExecutor`` may throw an error.
Expand Down Expand Up @@ -129,9 +130,9 @@ MathDx Compatibility
cuSolverDx ``inv`` when their block-dimension ranges intersect.
* - cuFFTDx
- Yes
- Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible FFT fusion paths. cuFFTDx has stricter launch-shape
requirements than the generic element-wise JIT path, and does not generally compose with cuBLASDx/cuSolverDx
operations that require a different block/grid model.
- Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible 1D FFT fusion paths and supported single-block 2D C2C FFT
fusion paths. cuFFTDx has stricter launch-shape requirements than the generic element-wise JIT path, and does not
generally compose with cuBLASDx/cuSolverDx operations that require a different block/grid model.
* - cuSolverDx
- Partial
- Enabled via ``-DMATX_EN_MATHDX=ON`` for ``chol``, ``inv``, and selected projection outputs from
Expand Down
17 changes: 12 additions & 5 deletions include/matx/core/get_grid_dims.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,17 @@ inline bool get_grid_dims_block_reduce(dim3 &blocks, dim3 &threads, const cuda::
template <int RANK>
inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads,
const cuda::std::array<index_t, RANK> &sizes,
int block_dim) {
int block_dim,
int groups_per_block = 1) {
// Threads are set to block_dim in x, y and z are 1
// All threads cooperate via flattened thread ID in the kernel
threads.x = block_dim;
threads.y = 1;
threads.y = groups_per_block;
threads.z = 1;

if (static_cast<int64_t>(threads.x) * static_cast<int64_t>(threads.y) > 1024) {
MATX_THROW(matxInvalidParameter, "Block2D launch exceeds CUDA maximum threads per block");
}

// Grid covers batch dimensions only (dims 0 to RANK-3)
blocks.x = 1;
Expand Down Expand Up @@ -372,7 +377,8 @@ inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads,
}
}

MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z);
MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}",
blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block);

// No stride needed for now - could be extended for very large batches
return false;
Expand Down Expand Up @@ -427,11 +433,12 @@ template <int RANK>
inline bool get_grid_dims_block_pass_through(dim3 &blocks, dim3 &threads,
const cuda::std::array<index_t, RANK> &sizes,
int block_dim,
int inner_rank) {
int inner_rank,
int groups_per_block = 1) {
if (inner_rank == 1) {
return get_grid_dims_block_1d<RANK>(blocks, threads, sizes, block_dim);
}
return get_grid_dims_block_2d<RANK>(blocks, threads, sizes, block_dim);
return get_grid_dims_block_2d<RANK>(blocks, threads, sizes, block_dim, groups_per_block);
}
} // end namespace detail
} // end namespace matx
22 changes: 15 additions & 7 deletions include/matx/executors/jit_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,15 +510,23 @@ namespace matx
// For pass-through operators (e.g., MathDx), block dimensions are constrained by the operator.
auto block_dim_range = detail::get_operator_capability<detail::OperatorCapability::BLOCK_DIM>(op);
block_size = detail::SelectJITPassThroughBlockDim(block_dim_range);
stride = detail::get_grid_dims_block_pass_through<RANK>(blocks, threads, sizes, block_size, pass_through_inner_rank);

// EPT is 1 for 2D block operators - the operator handles elements internally
best_ept = detail::ElementsPerThread::ONE;
auto group_range = detail::get_operator_capability<detail::OperatorCapability::GROUPS_PER_BLOCK>(op);
groups_per_block = group_range[0];
if (groups_per_block == detail::capability_attributes<detail::OperatorCapability::GROUPS_PER_BLOCK>::invalid) {
MATX_THROW(matxInvalidParameter, "No valid JIT groups-per-block value satisfies the fused operator requirements");
}
stride = detail::get_grid_dims_block_pass_through<RANK>(
blocks, threads, sizes, block_size, pass_through_inner_rank, groups_per_block);

// Pass-through operators do not run the normal SET_ELEMENTS_PER_THREAD query path.
// Use the lower bound so existing operators with a default [ONE, MAX] range keep
// the historical EPT=ONE behavior. Operators that require vector lanes, such as
// FFT2, advertise a fixed [N, N] range.
best_ept = jit_ept_bounds[0];
shm_size = detail::get_operator_capability<detail::OperatorCapability::DYN_SHM_SIZE>(op);
groups_per_block = 1;

MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}",
static_cast<int>(best_ept), shm_size, block_size);
MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}, Groups per block {}",
static_cast<int>(best_ept), shm_size, block_size, groups_per_block);
} else if constexpr (is_dynamic_rank_op_v<Op>) {
// Dynamic tensor expressions: pre-compiled kernels don't exist for this Op type,
// so we cannot query register pressure. Use conservative defaults.
Expand Down
51 changes: 33 additions & 18 deletions include/matx/executors/jit_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,39 +445,54 @@ namespace matx {
template <class Op>\n\
__global__ void matxOpT2KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1) {\n\
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
matx::index_t idx = tid % size1;\n\
matx::index_t idy = tid / size1;\n\
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
(*op).template operator()<CurrentCapabilities>(idy, idx);\n\
} else {\n\
op.template operator()<CurrentCapabilities>(idy, idx);\n\
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
matx::index_t size1_vectors = (size1 + ept - 1) / ept;\n\
if (size1_vectors == 0) return;\n\
matx::index_t idx = tid % size1_vectors;\n\
matx::index_t idy = tid / size1_vectors;\n\
if (idy < size0 && idx * static_cast<matx::index_t>(ept) < size1) {\n\
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
(*op).template operator()<CurrentCapabilities>(idy, idx);\n\
} else {\n\
op.template operator()<CurrentCapabilities>(idy, idx);\n\
}\n\
}\n\
}\n\
\n\
template <class Op>\n\
__global__ void matxOpT3KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
matx::index_t idx = tid % size2;\n\
matx::index_t idy = tid / size2;\n\
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
matx::index_t size2_vectors = (size2 + ept - 1) / ept;\n\
if (size2_vectors == 0) return;\n\
matx::index_t idx = tid % size2_vectors;\n\
matx::index_t idy = tid / size2_vectors;\n\
matx::index_t idz = blockIdx.x;\n\
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
(*op).template operator()<CurrentCapabilities>(idz, idy, idx);\n\
} else {\n\
op.template operator()<CurrentCapabilities>(idz, idy, idx);\n\
if (idz < size0 && idy < size1 && idx * static_cast<matx::index_t>(ept) < size2) {\n\
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
(*op).template operator()<CurrentCapabilities>(idz, idy, idx);\n\
} else {\n\
op.template operator()<CurrentCapabilities>(idz, idy, idx);\n\
}\n\
}\n\
}\n\
\n\
template <class Op>\n\
__global__ void matxOpT4KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\
int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\
matx::index_t idx = tid % size3;\n\
matx::index_t idy = tid / size3;\n\
constexpr int ept = static_cast<int>(CurrentCapabilities::ept);\n\
matx::index_t size3_vectors = (size3 + ept - 1) / ept;\n\
if (size3_vectors == 0) return;\n\
matx::index_t idx = tid % size3_vectors;\n\
matx::index_t idy = tid / size3_vectors;\n\
matx::index_t idz = blockIdx.x;\n\
matx::index_t idw = blockIdx.y;\n\
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
(*op).template operator()<CurrentCapabilities>(idw, idz, idy, idx);\n\
} else {\n\
op.template operator()<CurrentCapabilities>(idw, idz, idy, idx);\n\
if (idw < size0 && idz < size1 && idy < size2 && idx * static_cast<matx::index_t>(ept) < size3) {\n\
if constexpr (cuda::std::is_pointer_v<Op>) {\n\
(*op).template operator()<CurrentCapabilities>(idw, idz, idy, idx);\n\
} else {\n\
op.template operator()<CurrentCapabilities>(idw, idz, idy, idx);\n\
}\n\
}\n\
}\n\
}\n\
Expand Down
Loading