From 36d180aca30c3197a0b3ca4eefa540515dec2947 Mon Sep 17 00:00:00 2001 From: Cliff Burdick Date: Tue, 26 May 2026 13:34:04 -0700 Subject: [PATCH 1/6] Add cuFFTDx-backed FFT2 JIT support Generate JIT classes and LTO IR for single-block C2C fft2/ifft2 fusions, including shared-memory tiling through cuFFTDx 1D passes. Teach the JIT launcher about grouped 2D blocks and vectorized EPT indexing so FFT2 operators can return multiple columns per thread. Document the supported FFT2 JIT shape/type limits and add forward/inverse FFT2 JIT fusion coverage. --- docs_input/api/dft/fft/fft2d.rst | 6 +- docs_input/api/dft/fft/ifft2.rst | 6 +- docs_input/basics/fusion.rst | 9 +- include/matx/core/get_grid_dims.h | 17 +- include/matx/executors/jit_cuda.h | 17 +- include/matx/executors/jit_kernel.h | 18 +- include/matx/operators/fft.h | 225 +++++++++++++++++- include/matx/transforms/fft/fft_cufftdx.h | 268 ++++++++++++++++++++++ test/00_transform/FFT.cu | 84 ++++++- 9 files changed, 620 insertions(+), 30 deletions(-) diff --git a/docs_input/api/dft/fft/fft2d.rst b/docs_input/api/dft/fft/fft2d.rst index 8b3636446..297322100 100644 --- a/docs_input/api/dft/fft/fft2d.rst +++ b/docs_input/api/dft/fft/fft2d.rst @@ -5,6 +5,10 @@ fft2 Perform a 2D FFT. Batching is supported for any tensor with a rank higher than 2. +FFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square +transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled. +Unsupported 2D FFT sizes and real-valued 2D FFTs use the existing cuFFT execution path. + .. versionadded:: 0.6.0 .. doxygenfunction:: fft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD) @@ -31,4 +35,4 @@ Examples :language: cpp :start-after: example-begin fft2-2 :end-before: example-end fft2-2 - :dedent: \ No newline at end of file + :dedent: diff --git a/docs_input/api/dft/fft/ifft2.rst b/docs_input/api/dft/fft/ifft2.rst index be910ff2d..7129bb32a 100644 --- a/docs_input/api/dft/fft/ifft2.rst +++ b/docs_input/api/dft/fft/ifft2.rst @@ -5,6 +5,10 @@ ifft2 Perform a 2D inverse FFT. Batching is supported for any tensor with a rank higher than 2. +IFFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square +transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled. +Unsupported 2D IFFT sizes and real-valued inverse 2D FFTs use the existing cuFFT execution path. + .. versionadded:: 0.6.0 .. doxygenfunction:: ifft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD) @@ -31,4 +35,4 @@ Examples :language: cpp :start-after: example-begin ifft2-2 :end-before: example-end ifft2-2 - :dedent: \ No newline at end of file + :dedent: diff --git a/docs_input/basics/fusion.rst b/docs_input/basics/fusion.rst index 83eb740b4..4989c09f3 100644 --- a/docs_input/basics/fusion.rst +++ b/docs_input/basics/fusion.rst @@ -61,7 +61,8 @@ CUDA JIT Kernel Fusion MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled for all standard MatX element-wise operators, FFT operations via cuFFTDx, GEMM operations via cuBLASDx, and selected -solver operations via cuSolverDx. To enable fusion with MathDx, the following option must be enabled: +solver operations via cuSolverDx. cuFFTDx supports 1D FFT fusion and single-block complex-to-complex 2D ``fft2``/``ifft2`` +fusion for supported power-of-two square transforms. To enable fusion with MathDx, the following option must be enabled: ``-DMATX_EN_MATHDX=ON``. MathDx support also enables the NVRTC-based JIT support used by ``CUDAJITExecutor``. Once enabled, the ``CUDAJITExecutor`` can be used to perform JIT compilation in supported situations. If the expression cannot be JIT compiled, the ``CUDAJITExecutor`` may throw an error. @@ -129,9 +130,9 @@ MathDx Compatibility cuSolverDx ``inv`` when their block-dimension ranges intersect. * - cuFFTDx - Yes - - Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible FFT fusion paths. cuFFTDx has stricter launch-shape - requirements than the generic element-wise JIT path, and does not generally compose with cuBLASDx/cuSolverDx - operations that require a different block/grid model. + - Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible 1D FFT fusion paths and supported single-block 2D C2C FFT + fusion paths. cuFFTDx has stricter launch-shape requirements than the generic element-wise JIT path, and does not + generally compose with cuBLASDx/cuSolverDx operations that require a different block/grid model. * - cuSolverDx - Partial - Enabled via ``-DMATX_EN_MATHDX=ON`` for ``chol``, ``inv``, and selected projection outputs from diff --git a/include/matx/core/get_grid_dims.h b/include/matx/core/get_grid_dims.h index b75eeed1d..a224a3b93 100644 --- a/include/matx/core/get_grid_dims.h +++ b/include/matx/core/get_grid_dims.h @@ -338,12 +338,17 @@ inline bool get_grid_dims_block_reduce(dim3 &blocks, dim3 &threads, const cuda:: template inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads, const cuda::std::array &sizes, - int block_dim) { + int block_dim, + int groups_per_block = 1) { // Threads are set to block_dim in x, y and z are 1 // All threads cooperate via flattened thread ID in the kernel threads.x = block_dim; - threads.y = 1; + threads.y = groups_per_block; threads.z = 1; + + if (static_cast(threads.x) * static_cast(threads.y) > 1024) { + MATX_THROW(matxInvalidParameter, "Block2D launch exceeds CUDA maximum threads per block"); + } // Grid covers batch dimensions only (dims 0 to RANK-3) blocks.x = 1; @@ -372,7 +377,8 @@ inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads, } } - MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z); + MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}", + blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block); // No stride needed for now - could be extended for very large batches return false; @@ -427,11 +433,12 @@ template inline bool get_grid_dims_block_pass_through(dim3 &blocks, dim3 &threads, const cuda::std::array &sizes, int block_dim, - int inner_rank) { + int inner_rank, + int groups_per_block = 1) { if (inner_rank == 1) { return get_grid_dims_block_1d(blocks, threads, sizes, block_dim); } - return get_grid_dims_block_2d(blocks, threads, sizes, block_dim); + return get_grid_dims_block_2d(blocks, threads, sizes, block_dim, groups_per_block); } } // end namespace detail } // end namespace matx diff --git a/include/matx/executors/jit_cuda.h b/include/matx/executors/jit_cuda.h index cddf00678..a58ec837e 100644 --- a/include/matx/executors/jit_cuda.h +++ b/include/matx/executors/jit_cuda.h @@ -510,15 +510,20 @@ namespace matx // For pass-through operators (e.g., MathDx), block dimensions are constrained by the operator. auto block_dim_range = detail::get_operator_capability(op); block_size = detail::SelectJITPassThroughBlockDim(block_dim_range); - stride = detail::get_grid_dims_block_pass_through(blocks, threads, sizes, block_size, pass_through_inner_rank); + auto group_range = detail::get_operator_capability(op); + groups_per_block = group_range[0]; + if (groups_per_block == detail::capability_attributes::invalid) { + MATX_THROW(matxInvalidParameter, "No valid JIT groups-per-block value satisfies the fused operator requirements"); + } + stride = detail::get_grid_dims_block_pass_through( + blocks, threads, sizes, block_size, pass_through_inner_rank, groups_per_block); - // EPT is 1 for 2D block operators - the operator handles elements internally - best_ept = detail::ElementsPerThread::ONE; + // Block-level operators can still return vectorized output lanes. + best_ept = jit_ept_bounds[1]; shm_size = detail::get_operator_capability(op); - groups_per_block = 1; - MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}", - static_cast(best_ept), shm_size, block_size); + MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}, Groups per block {}", + static_cast(best_ept), shm_size, block_size, groups_per_block); } else if constexpr (is_dynamic_rank_op_v) { // Dynamic tensor expressions: pre-compiled kernels don't exist for this Op type, // so we cannot query register pressure. Use conservative defaults. diff --git a/include/matx/executors/jit_kernel.h b/include/matx/executors/jit_kernel.h index cec56ba56..ec6a5103a 100644 --- a/include/matx/executors/jit_kernel.h +++ b/include/matx/executors/jit_kernel.h @@ -445,8 +445,10 @@ namespace matx { template \n\ __global__ void matxOpT2KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1) {\n\ int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ - matx::index_t idx = tid % size1;\n\ - matx::index_t idy = tid / size1;\n\ + constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ + matx::index_t size1_vectors = (size1 + ept - 1) / ept;\n\ + matx::index_t idx = tid % size1_vectors;\n\ + matx::index_t idy = tid / size1_vectors;\n\ if constexpr (cuda::std::is_pointer_v) {\n\ (*op).template operator()(idy, idx);\n\ } else {\n\ @@ -457,8 +459,10 @@ namespace matx { template \n\ __global__ void matxOpT3KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ - matx::index_t idx = tid % size2;\n\ - matx::index_t idy = tid / size2;\n\ + constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ + matx::index_t size2_vectors = (size2 + ept - 1) / ept;\n\ + matx::index_t idx = tid % size2_vectors;\n\ + matx::index_t idy = tid / size2_vectors;\n\ matx::index_t idz = blockIdx.x;\n\ if constexpr (cuda::std::is_pointer_v) {\n\ (*op).template operator()(idz, idy, idx);\n\ @@ -470,8 +474,10 @@ namespace matx { template \n\ __global__ void matxOpT4KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ - matx::index_t idx = tid % size3;\n\ - matx::index_t idy = tid / size3;\n\ + constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ + matx::index_t size3_vectors = (size3 + ept - 1) / ept;\n\ + matx::index_t idx = tid % size3_vectors;\n\ + matx::index_t idy = tid / size3_vectors;\n\ matx::index_t idz = blockIdx.x;\n\ matx::index_t idw = blockIdx.y;\n\ if constexpr (cuda::std::is_pointer_v) {\n\ diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index 15706ba26..b616d5442 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -239,7 +239,7 @@ namespace matx " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const\n" + " {\n" + " return out_dims_[dim];\n " + - " }\n" + + " }\n" + "};\n") ); } @@ -754,12 +754,31 @@ namespace matx mutable detail::tensor_impl_t tmp_out_; mutable ttype *ptr = nullptr; mutable bool prerun_done_ = false; +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + mutable cuFFTDx2DHelper dx_fft2_helper_; + bool jit_axes_supported_ = true; +#endif public: using matxop = bool; using value_type = typename OpA::value_type; using matx_transform_op = bool; using fft2_xform_op = bool; + using self_type = FFT2Op; + + // Propagate dynamic tensor marker through expression tree + using dynamic_tensor_expr = cuda::std::bool_constant< + is_dynamic_tensor_v || is_dynamic_rank_op_v>; + +#ifdef MATX_EN_JIT + struct JIT_Storage { + typename detail::inner_storage_or_self_t> a_; + }; + + JIT_Storage ToJITStorage() const { + return JIT_Storage{detail::to_jit_storage(a_)}; + } +#endif __MATX_INLINE__ std::string str() const { if constexpr (Direction == detail::FFTDirection::FORWARD) { @@ -802,9 +821,87 @@ namespace matx out_dims_[Rank() - 2] = out_dims_[Rank() - 2]; } } + +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + if constexpr (!std::is_same_v) { + for (int32_t i = 0; i < Rank(); i++) { + if (perm_[static_cast(i)] != i) { + jit_axes_supported_ = false; + break; + } + } + } + + int major = 0; + int minor = 0; + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); + int cc = major * 100 + minor * 10; + + dx_fft2_helper_.set_fft_size_y(out_dims_[Rank() - 2]); + dx_fft2_helper_.set_fft_size_x(out_dims_[Rank() - 1]); + dx_fft2_helper_.set_fft_type(Type); + dx_fft2_helper_.set_direction(Direction); + dx_fft2_helper_.set_cc(cc); +#endif } - +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + __MATX_INLINE__ std::string get_jit_class_name() const { + std::string symbol_name = "JITFFT2Op_"; + symbol_name += std::to_string(dx_fft2_helper_.get_fft_size_y()); + symbol_name += "_"; + symbol_name += std::to_string(dx_fft2_helper_.get_fft_size_x()); + symbol_name += "_T"; + symbol_name += std::to_string(static_cast(Type)); + symbol_name += "_D"; + symbol_name += Direction == detail::FFTDirection::FORWARD ? std::string("F") : std::string("B"); + return symbol_name; + } + + __MATX_INLINE__ auto get_jit_op_str() const { + const int actual_rank = jit_rank(); + const std::string class_name = get_jit_class_name(); + const std::string fft_x_func_name = dx_fft2_helper_.GetXFuncName(); + const std::string fft_y_func_name = dx_fft2_helper_.GetYFuncName(); + + std::string declarations = + " extern \"C\" __device__ void " + fft_x_func_name + "(" + detail::type_to_string() + "*);\n"; + if (fft_y_func_name != fft_x_func_name) { + declarations += + " extern \"C\" __device__ void " + fft_y_func_name + "(" + detail::type_to_string() + "*);\n"; + } + + return cuda::std::make_tuple( + class_name, + std::string( + declarations + + " template struct " + class_name + " {\n" + + " using input_type = typename OpA::value_type;\n" + + " using matxop = bool;\n" + + " using value_type = input_type;\n" + + " typename detail::inner_storage_or_self_t> a_;\n" + + " constexpr static cuda::std::array out_dims_ = { " + + detail::array_to_string(out_dims_, actual_rank) + " };\n" + + " template \n" + + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const\n" + + " {\n" + + " " + dx_fft2_helper_.GetFuncStr(fft_x_func_name, fft_y_func_name, static_cast(norm_), actual_rank) + "\n" + + " }\n" + + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank()\n" + + " {\n" + + " return " + std::to_string(actual_rank) + ";\n" + + " }\n" + + " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const\n" + + " {\n" + + " return out_dims_[dim];\n " + + " }\n" + + "};\n") + ); + } +#endif template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const @@ -825,6 +922,14 @@ namespace matx return out_dims_[dim]; } + __MATX_INLINE__ __MATX_HOST__ int32_t DynRank() const { + return detail::get_dyn_rank(a_); + } + + __MATX_INLINE__ __MATX_HOST__ int32_t jit_rank() const { + if constexpr (is_dynamic_rank_op_v) return DynRank(); + else return Rank(); + } __MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; } @@ -865,10 +970,118 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType &in) const { - // 1. Determine if the binary operation ITSELF intrinsically has this capability. - auto self_has_cap = capability_attributes::default_value; - auto result = combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in)); - return result; +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + if constexpr (Cap == OperatorCapability::DYN_SHM_SIZE) { + const int self_shm = (jit_axes_supported_ && dx_fft2_helper_.template CheckJITSizeAndTypeRequirements()) ? + dx_fft2_helper_.GetShmRequired() : capability_attributes::default_value; + auto result = combine_capabilities(self_shm, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D DYN_SHM_SIZE: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { + bool supported = jit_axes_supported_; + if (supported && !dx_fft2_helper_.template CheckJITSizeAndTypeRequirements()) { + supported = false; + } + else if (supported) { + supported = dx_fft2_helper_.IsSupported(); + if (supported && dx_fft2_helper_.GetElementsPerThread() == ElementsPerThread::INVALID) { + supported = false; + } + } + + auto result = combine_capabilities(supported, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D SUPPORTS_JIT: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + const auto [key, value] = get_jit_op_str(); + if (in.find(key) == in.end()) { + in[key] = value; + } + detail::get_operator_capability(a_, in); + MATX_LOG_DEBUG("cuFFTDx 2D JIT_CLASS_QUERY: true"); + return true; + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + ElementsPerThread ept = ElementsPerThread::ONE; + if (in.jit) { + ept = (jit_axes_supported_ && dx_fft2_helper_.template CheckJITSizeAndTypeRequirements()) ? + dx_fft2_helper_.GetElementsPerThread() : ElementsPerThread::INVALID; + } + const auto my_cap = cuda::std::array{ept, ept}; + auto result = combine_capabilities(my_cap, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D ELEMENTS_PER_THREAD: [{},{}]", + static_cast(result[0]), static_cast(result[1])); + return result; + } + else if constexpr (Cap == OperatorCapability::MAX_EPT_VEC_LOAD) { + auto result = combine_capabilities(1, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D MAX_EPT_VEC_LOAD: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::GLOBAL_KERNEL) { + MATX_LOG_DEBUG("cuFFTDx 2D GLOBAL_KERNEL: false"); + return false; + } + else if constexpr (Cap == OperatorCapability::PASS_THROUGH_THREADS) { + MATX_LOG_DEBUG("cuFFTDx 2D PASS_THROUGH_THREADS: true"); + return true; + } + else if constexpr (Cap == OperatorCapability::GROUPS_PER_BLOCK) { + const int ffts_per_block = dx_fft2_helper_.GetFFTsPerBlock(); + const auto my_cap = cuda::std::array{ffts_per_block, ffts_per_block}; + auto result = combine_capabilities(my_cap, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D GROUPS_PER_BLOCK: [{},{}]", result[0], result[1]); + return result; + } + else if constexpr (Cap == OperatorCapability::SET_GROUPS_PER_BLOCK || + Cap == OperatorCapability::SET_ELEMENTS_PER_THREAD) { + auto result = combine_capabilities(capability_attributes::default_value, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D SET_GROUPS/EPT: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::BLOCK_DIM) { + const int block_dim = dx_fft2_helper_.GetBlockDim(); + const auto my_block = block_dim > 0 ? + cuda::std::array{block_dim, block_dim} : + cuda::std::array{capability_attributes::invalid, capability_attributes::invalid}; + auto result = combine_capabilities(my_block, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D BLOCK_DIM: [{},{}]", result[0], result[1]); + return result; + } + else if constexpr (Cap == OperatorCapability::GENERATE_LTOIR) { + auto result = combine_capabilities( + dx_fft2_helper_.GenerateLTOIR(in.ltoir_symbols), + detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D GENERATE_LTOIR: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + const auto inner_op_jit_name = detail::get_operator_capability(a_, in); + auto result = get_jit_class_name() + "<" + inner_op_jit_name + ">"; + MATX_LOG_DEBUG("cuFFTDx 2D JIT_TYPE_QUERY: {}", result); + return result; + } + else { + auto self_has_cap = capability_attributes::default_value; + auto result = combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in)); + return result; + } +#else + if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { + bool supported = false; + return combine_capabilities(supported, detail::get_operator_capability(a_, in)); + } + else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + return ""; + } + else { + auto self_has_cap = capability_attributes::default_value; + auto result = combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in)); + return result; + } +#endif } template diff --git a/include/matx/transforms/fft/fft_cufftdx.h b/include/matx/transforms/fft/fft_cufftdx.h index f1f8bbd94..c8bdd7673 100644 --- a/include/matx/transforms/fft/fft_cufftdx.h +++ b/include/matx/transforms/fft/fft_cufftdx.h @@ -38,6 +38,7 @@ #include "matx/core/operator_options.h" #include "matx/core/capabilities.h" #include "matx/core/log.h" +#include #include #define LIBMATHDX_CHECK(ans) \ @@ -435,5 +436,272 @@ namespace matx { #endif }; + template + class cuFFTDx2DHelper { + private: + index_t fft_size_x_ = 0; + index_t fft_size_y_ = 0; + FFTType fft_type_ = FFTType::C2C; + FFTDirection direction_ = FFTDirection::FORWARD; + int cc_ = 0; + cuFFTDxHelper fft_x_helper_; + cuFFTDxHelper fft_y_helper_; + + static ElementsPerThread IntToElementsPerThread(int ept) { + switch (ept) { + case 1: return ElementsPerThread::ONE; + case 2: return ElementsPerThread::TWO; + case 4: return ElementsPerThread::FOUR; + case 8: return ElementsPerThread::EIGHT; + case 16: return ElementsPerThread::SIXTEEN; + case 32: return ElementsPerThread::THIRTY_TWO; + default: return ElementsPerThread::INVALID; + } + } + + void Configure1DHelpers() { + fft_x_helper_.set_fft_size(fft_size_x_); + fft_x_helper_.set_fft_type(fft_type_); + fft_x_helper_.set_direction(direction_); + fft_x_helper_.set_ffts_per_block(static_cast(fft_size_y_)); + fft_x_helper_.set_cc(cc_); + fft_x_helper_.set_contiguous_input(false); + fft_x_helper_.set_method(cuFFTDxMethod::SHARED); + + fft_y_helper_.set_fft_size(fft_size_y_); + fft_y_helper_.set_fft_type(fft_type_); + fft_y_helper_.set_direction(direction_); + fft_y_helper_.set_ffts_per_block(static_cast(fft_size_x_)); + fft_y_helper_.set_cc(cc_); + fft_y_helper_.set_contiguous_input(false); + fft_y_helper_.set_method(cuFFTDxMethod::SHARED); + } + + public: + cuFFTDx2DHelper() = default; + + index_t get_fft_size_x() const { return fft_size_x_; } + index_t get_fft_size_y() const { return fft_size_y_; } + FFTType get_fft_type() const { return fft_type_; } + FFTDirection get_direction() const { return direction_; } + int get_cc() const { return cc_; } + + void set_fft_size_x(index_t size) { fft_size_x_ = size; Configure1DHelpers(); } + void set_fft_size_y(index_t size) { fft_size_y_ = size; Configure1DHelpers(); } + void set_fft_type(FFTType type) { fft_type_ = type; Configure1DHelpers(); } + void set_direction(FFTDirection dir) { direction_ = dir; Configure1DHelpers(); } + void set_cc(int cc) { cc_ = cc; Configure1DHelpers(); } + +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + std::string GetSymbolName() const { + std::string symbol_name; + symbol_name += std::to_string(fft_size_x_); + symbol_name += "_"; + symbol_name += std::to_string(fft_size_y_); + symbol_name += "_T"; + symbol_name += std::to_string(static_cast(fft_type_)); + symbol_name += "_D"; + symbol_name += std::to_string(static_cast(direction_)); + symbol_name += "_CC"; + symbol_name += std::to_string(cc_); + +#if defined(CUDA_VERSION) + symbol_name += "_CUDA"; + symbol_name += std::to_string(CUDART_VERSION); +#else + symbol_name += "_CUDAUNKNOWN"; +#endif + + return symbol_name; + } + + template + bool CheckJITSizeAndTypeRequirements() const { + using OpInputType = typename OpType::value_type; + + if (fft_type_ != FFTType::C2C) { + return false; + } + + if ((fft_size_x_ & (fft_size_x_ - 1)) != 0 || fft_size_x_ == 0 || + (fft_size_y_ & (fft_size_y_ - 1)) != 0 || fft_size_y_ == 0) { + return false; + } + + // The single-kernel JIT path uses one CUDA block for a complete 2D tile. + if (fft_size_x_ != fft_size_y_ || (fft_size_x_ * fft_size_y_) > 1024) { + return false; + } + + if constexpr (is_complex_half_v || !is_complex_v) { + return false; + } + + return true; + } + + bool IsSupported() const { + return fft_x_helper_.IsSupported() && fft_y_helper_.IsSupported(); + } + + int GetShmRequired() const { + const auto data_size = static_cast(fft_size_x_) * + static_cast(fft_size_y_) * + static_cast(sizeof(InputType)); + const auto x_shm = static_cast(fft_x_helper_.GetShmRequired()); + const auto y_shm = static_cast(fft_y_helper_.GetShmRequired()); + const auto extra_shm = std::max(0, std::max(x_shm, y_shm) - data_size); + const auto total = data_size * 2 + extra_shm; + MATX_LOG_DEBUG("cuFFTDx 2D shared memory: data={}, scratch={}, extra={}, total={}", + data_size, data_size, extra_shm, total); + return static_cast(total); + } + + int GetBlockDim() const { + const auto block_x = fft_x_helper_.GetBlockDim(); + const auto block_y = fft_y_helper_.GetBlockDim(); + if (block_x != block_y) { + MATX_LOG_DEBUG("cuFFTDx 2D block dims differ: x={}, y={}", block_x, block_y); + return -1; + } + + return block_x; + } + + int GetFFTsPerBlock() const { + return static_cast(fft_size_y_); + } + + ElementsPerThread GetElementsPerThread() const { + const auto block_dim = GetBlockDim(); + if (block_dim <= 0 || fft_size_x_ % block_dim != 0) { + return ElementsPerThread::INVALID; + } + + return IntToElementsPerThread(static_cast(fft_size_x_ / block_dim)); + } + + bool GenerateLTOIR(std::set <oir_symbols) { + return fft_x_helper_.GenerateLTOIR(ltoir_symbols) && + fft_y_helper_.GenerateLTOIR(ltoir_symbols); + } + + std::string GetXFuncName() { + return std::string(FFT_DX_FUNC_PREFIX) + "_" + fft_x_helper_.GetSymbolName(); + } + + std::string GetYFuncName() { + return std::string(FFT_DX_FUNC_PREFIX) + "_" + fft_y_helper_.GetSymbolName(); + } + + std::string GetFuncStr(const std::string &fft_x_func_name, + const std::string &fft_y_func_name, + int fft_norm, + int actual_rank) { + const int fft_forward = (direction_ == FFTDirection::FORWARD) ? 1 : 0; + + std::string result = R"( + using input_type = )"; + result += detail::type_to_string(); + result += R"(; + using input_type_converted = typename detail::convert_matx_type_t; + using precision = typename detail::inner_precision::type; + using ScalarCap = typename CapType::scalar_cap; + [[maybe_unused]] static constexpr int fft_size_x = )"; + result += std::to_string(static_cast(fft_size_x_)); + result += R"(; + [[maybe_unused]] static constexpr int fft_size_y = )"; + result += std::to_string(static_cast(fft_size_y_)); + result += R"(; + [[maybe_unused]] static constexpr int fft_forward = )"; + result += std::to_string(fft_forward); + result += R"(; + [[maybe_unused]] static constexpr int fft_norm = )"; + result += std::to_string(fft_norm); + result += R"(; + [[maybe_unused]] static constexpr int fft_rank = )"; + result += std::to_string(actual_rank); + result += R"(; + [[maybe_unused]] static constexpr int fft_elements = fft_size_x * fft_size_y; + + extern __shared__ __align__(16) unsigned char fft2_smem_raw[]; + auto *fft_data = reinterpret_cast(fft2_smem_raw); + auto *fft_scratch = fft_data + fft_elements; + + const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + cuda::std::array fft_indices{static_cast(indices)...}; + + for (int elem = tid; elem < fft_elements; elem += total_threads) { + const int row = elem / fft_size_x; + const int col = elem - row * fft_size_x; + fft_indices[fft_rank - 2] = row; + fft_indices[fft_rank - 1] = col; + fft_data[elem] = static_cast( + cuda::std::apply([&](auto... args) { + return a_.template operator()(args...); + }, fft_indices)); + } + + __syncthreads(); + )"; + result += fft_x_func_name; + result += R"((reinterpret_cast(fft_data)); + __syncthreads(); + + for (int elem = tid; elem < fft_elements; elem += total_threads) { + const int row = elem / fft_size_x; + const int col = elem - row * fft_size_x; + fft_scratch[col * fft_size_y + row] = fft_data[row * fft_size_x + col]; + } + + __syncthreads(); + )"; + result += fft_y_func_name; + result += R"((reinterpret_cast(fft_scratch)); + __syncthreads(); + + static constexpr int fft_output_ept = static_cast(CapType::ept); + const int out_row = static_cast(cuda::std::get(cuda::std::make_tuple(indices...))); + const int out_col_base = static_cast(cuda::std::get(cuda::std::make_tuple(indices...))) * fft_output_ept; + + if constexpr (CapType::ept == ElementsPerThread::ONE) { + input_type_converted result = fft_scratch[out_col_base * fft_size_y + out_row]; + if constexpr (fft_norm == 2) { + result = result * static_cast(1.f) / static_cast(cuda::std::sqrt(static_cast(fft_elements))); + } + else if constexpr ((fft_norm == 1 && fft_forward) || (fft_norm == 0 && !fft_forward)) { + result = result * static_cast(1.f) / static_cast(fft_elements); + } + + return static_cast(result); + } + else { + Vector result_vec; + #pragma unroll + for (int i = 0; i < fft_output_ept; i++) { + const int out_col = out_col_base + i; + input_type_converted result = input_type_converted{}; + if (out_col < fft_size_x) { + result = fft_scratch[out_col * fft_size_y + out_row]; + if constexpr (fft_norm == 2) { + result = result * static_cast(1.f) / static_cast(cuda::std::sqrt(static_cast(fft_elements))); + } + else if constexpr ((fft_norm == 1 && fft_forward) || (fft_norm == 0 && !fft_forward)) { + result = result * static_cast(1.f) / static_cast(fft_elements); + } + } + result_vec.data[i] = static_cast(result); + } + + return result_vec; + } + )"; + + return result; + } +#endif + }; + } // namespace detail } // namespace matx diff --git a/test/00_transform/FFT.cu b/test/00_transform/FFT.cu index 796218d14..2640c6b92 100644 --- a/test/00_transform/FFT.cu +++ b/test/00_transform/FFT.cu @@ -919,6 +919,88 @@ TYPED_TEST(FFTTestComplexTypes, FFT2D16R2C) MATX_EXIT_HANDLER(); } +#if defined(MATX_EN_JIT) && defined(MATX_EN_MATHDX) +TEST(FFTJIT, CuFFTDx2DFFT2Fusion) +{ + MATX_ENTER_HANDLER(); + + using complex_type = cuda::std::complex; + constexpr index_t fft_dim = 4; + + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); + + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + in(row, col) = complex_type{static_cast(row + 2 * col), + static_cast((row + col) % 3)}; + } + } + + auto expr = fft2(in) + complex_type{1.0f, -1.0f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); + } + + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = fft2(in) + complex_type{1.0f, -1.0f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + ASSERT_NEAR(jit_out(row, col).real(), ref_out(row, col).real(), 0.01f); + ASSERT_NEAR(jit_out(row, col).imag(), ref_out(row, col).imag(), 0.01f); + } + } + + MATX_EXIT_HANDLER(); +} + +TEST(FFTJIT, CuFFTDx2DIFFT2Fusion) +{ + MATX_ENTER_HANDLER(); + + using complex_type = cuda::std::complex; + constexpr index_t fft_dim = 4; + + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); + + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + in(row, col) = complex_type{static_cast(row - col), + static_cast(1 + row + col)}; + } + } + + auto expr = ifft2(in, FFTNorm::BACKWARD) * complex_type{0.5f, 0.0f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); + } + + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = ifft2(in, FFTNorm::BACKWARD) * complex_type{0.5f, 0.0f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + ASSERT_NEAR(jit_out(row, col).real(), ref_out(row, col).real(), 0.01f); + ASSERT_NEAR(jit_out(row, col).imag(), ref_out(row, col).imag(), 0.01f); + } + } + + MATX_EXIT_HANDLER(); +} +#endif + TYPED_TEST(FFTTestComplexTypes, FFT2D16x32R2C) { MATX_ENTER_HANDLER(); @@ -1023,4 +1105,4 @@ TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, IFFT1D1024C2CShort) MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); MATX_EXIT_HANDLER(); -} \ No newline at end of file +} From 202d0ecac82fe858586db8b4ca2c964acf69f851 Mon Sep 17 00:00:00 2001 From: Cliff Burdick Date: Tue, 26 May 2026 14:14:03 -0700 Subject: [PATCH 2/6] address greptile review feedback (greploop iteration 1) --- include/matx/executors/jit_cuda.h | 7 +- include/matx/operators/fft.h | 4 + include/matx/transforms/fft/fft_cufftdx.h | 4 + test/00_transform/FFT.cu | 155 +++++++++++++++++----- 4 files changed, 134 insertions(+), 36 deletions(-) diff --git a/include/matx/executors/jit_cuda.h b/include/matx/executors/jit_cuda.h index a58ec837e..4fdee6362 100644 --- a/include/matx/executors/jit_cuda.h +++ b/include/matx/executors/jit_cuda.h @@ -518,8 +518,11 @@ namespace matx stride = detail::get_grid_dims_block_pass_through( blocks, threads, sizes, block_size, pass_through_inner_rank, groups_per_block); - // Block-level operators can still return vectorized output lanes. - best_ept = jit_ept_bounds[1]; + // Pass-through operators do not run the normal SET_ELEMENTS_PER_THREAD query path. + // Use the lower bound so existing operators with a default [ONE, MAX] range keep + // the historical EPT=ONE behavior. Operators that require vector lanes, such as + // FFT2, advertise a fixed [N, N] range. + best_ept = jit_ept_bounds[0]; shm_size = detail::get_operator_capability(op); MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}, Groups per block {}", diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index b616d5442..43f913bee 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -1028,6 +1028,10 @@ namespace matx MATX_LOG_DEBUG("cuFFTDx 2D PASS_THROUGH_THREADS: true"); return true; } + else if constexpr (Cap == OperatorCapability::PASS_THROUGH_INNER_RANK) { + MATX_LOG_DEBUG("cuFFTDx 2D PASS_THROUGH_INNER_RANK: 2"); + return 2; + } else if constexpr (Cap == OperatorCapability::GROUPS_PER_BLOCK) { const int ffts_per_block = dx_fft2_helper_.GetFFTsPerBlock(); const auto my_cap = cuda::std::array{ffts_per_block, ffts_per_block}; diff --git a/include/matx/transforms/fft/fft_cufftdx.h b/include/matx/transforms/fft/fft_cufftdx.h index c8bdd7673..f7f76fb9b 100644 --- a/include/matx/transforms/fft/fft_cufftdx.h +++ b/include/matx/transforms/fft/fft_cufftdx.h @@ -460,6 +460,10 @@ namespace matx { } void Configure1DHelpers() { + if (fft_size_x_ <= 0 || fft_size_y_ <= 0 || cc_ <= 0) { + return; + } + fft_x_helper_.set_fft_size(fft_size_x_); fft_x_helper_.set_fft_type(fft_type_); fft_x_helper_.set_direction(direction_); diff --git a/test/00_transform/FFT.cu b/test/00_transform/FFT.cu index 2640c6b92..294181aff 100644 --- a/test/00_transform/FFT.cu +++ b/test/00_transform/FFT.cu @@ -920,25 +920,69 @@ TYPED_TEST(FFTTestComplexTypes, FFT2D16R2C) } #if defined(MATX_EN_JIT) && defined(MATX_EN_MATHDX) -TEST(FFTJIT, CuFFTDx2DFFT2Fusion) -{ - MATX_ENTER_HANDLER(); +namespace { +using FFT2JITComplex = cuda::std::complex; - using complex_type = cuda::std::complex; - constexpr index_t fft_dim = 4; +void FillFFT2JITInput2D(auto &in, index_t fft_dim) +{ + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + in(row, col) = FFT2JITComplex{ + static_cast((row + 1) * (col + 2) % 17), + static_cast((2 * row + col) % 11)}; + } + } +} - auto in = make_tensor({fft_dim, fft_dim}); - auto jit_out = make_tensor({fft_dim, fft_dim}); - auto ref_out = make_tensor({fft_dim, fft_dim}); +void FillFFT2JITInput3D(auto &in, index_t batches, index_t fft_dim) +{ + for (index_t batch = 0; batch < batches; batch++) { + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + in(batch, row, col) = FFT2JITComplex{ + static_cast((batch + 1) * (row + 2) + col), + static_cast((batch + row + 2 * col) % 13)}; + } + } + } +} +void AssertFFT2JITClose2D(auto &jit_out, auto &ref_out, index_t fft_dim, float tol) +{ for (index_t row = 0; row < fft_dim; row++) { for (index_t col = 0; col < fft_dim; col++) { - in(row, col) = complex_type{static_cast(row + 2 * col), - static_cast((row + col) % 3)}; + ASSERT_NEAR(jit_out(row, col).real(), ref_out(row, col).real(), tol); + ASSERT_NEAR(jit_out(row, col).imag(), ref_out(row, col).imag(), tol); + } + } +} + +void AssertFFT2JITClose3D(auto &jit_out, auto &ref_out, index_t batches, index_t fft_dim, float tol) +{ + for (index_t batch = 0; batch < batches; batch++) { + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + ASSERT_NEAR(jit_out(batch, row, col).real(), ref_out(batch, row, col).real(), tol); + ASSERT_NEAR(jit_out(batch, row, col).imag(), ref_out(batch, row, col).imag(), tol); + } } } +} +} // namespace + +TEST(FFTJIT, CuFFTDx2DFFT2Fusion) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t fft_dim = 4; + + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); - auto expr = fft2(in) + complex_type{1.0f, -1.0f}; + FillFFT2JITInput2D(in, fft_dim); + + auto expr = fft2(in) + FFT2JITComplex{1.0f, -1.0f}; if (!jit_supported(expr)) { GTEST_SKIP(); } @@ -946,16 +990,11 @@ TEST(FFTJIT, CuFFTDx2DFFT2Fusion) CUDAJITExecutor jit_exec{}; cudaExecutor cuda_exec{}; (jit_out = expr).run(jit_exec); - (ref_out = fft2(in) + complex_type{1.0f, -1.0f}).run(cuda_exec); + (ref_out = fft2(in) + FFT2JITComplex{1.0f, -1.0f}).run(cuda_exec); jit_exec.sync(); cuda_exec.sync(); - for (index_t row = 0; row < fft_dim; row++) { - for (index_t col = 0; col < fft_dim; col++) { - ASSERT_NEAR(jit_out(row, col).real(), ref_out(row, col).real(), 0.01f); - ASSERT_NEAR(jit_out(row, col).imag(), ref_out(row, col).imag(), 0.01f); - } - } + AssertFFT2JITClose2D(jit_out, ref_out, fft_dim, 0.01f); MATX_EXIT_HANDLER(); } @@ -964,21 +1003,44 @@ TEST(FFTJIT, CuFFTDx2DIFFT2Fusion) { MATX_ENTER_HANDLER(); - using complex_type = cuda::std::complex; constexpr index_t fft_dim = 4; - auto in = make_tensor({fft_dim, fft_dim}); - auto jit_out = make_tensor({fft_dim, fft_dim}); - auto ref_out = make_tensor({fft_dim, fft_dim}); + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); - for (index_t row = 0; row < fft_dim; row++) { - for (index_t col = 0; col < fft_dim; col++) { - in(row, col) = complex_type{static_cast(row - col), - static_cast(1 + row + col)}; - } + FillFFT2JITInput2D(in, fft_dim); + + auto expr = ifft2(in, FFTNorm::BACKWARD) * FFT2JITComplex{0.5f, 0.0f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); } - auto expr = ifft2(in, FFTNorm::BACKWARD) * complex_type{0.5f, 0.0f}; + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = ifft2(in, FFTNorm::BACKWARD) * FFT2JITComplex{0.5f, 0.0f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + AssertFFT2JITClose2D(jit_out, ref_out, fft_dim, 0.01f); + + MATX_EXIT_HANDLER(); +} + +TEST(FFTJIT, CuFFTDx2DFFT2OrthoBoundary) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t fft_dim = 32; + + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); + + FillFFT2JITInput2D(in, fft_dim); + + auto expr = fft2(in, FFTNorm::ORTHO) - FFT2JITComplex{0.25f, 0.75f}; if (!jit_supported(expr)) { GTEST_SKIP(); } @@ -986,17 +1048,42 @@ TEST(FFTJIT, CuFFTDx2DIFFT2Fusion) CUDAJITExecutor jit_exec{}; cudaExecutor cuda_exec{}; (jit_out = expr).run(jit_exec); - (ref_out = ifft2(in, FFTNorm::BACKWARD) * complex_type{0.5f, 0.0f}).run(cuda_exec); + (ref_out = fft2(in, FFTNorm::ORTHO) - FFT2JITComplex{0.25f, 0.75f}).run(cuda_exec); jit_exec.sync(); cuda_exec.sync(); - for (index_t row = 0; row < fft_dim; row++) { - for (index_t col = 0; col < fft_dim; col++) { - ASSERT_NEAR(jit_out(row, col).real(), ref_out(row, col).real(), 0.01f); - ASSERT_NEAR(jit_out(row, col).imag(), ref_out(row, col).imag(), 0.01f); - } + AssertFFT2JITClose2D(jit_out, ref_out, fft_dim, 0.05f); + + MATX_EXIT_HANDLER(); +} + +TEST(FFTJIT, CuFFTDx2DFFT2BatchedForwardNorm) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t batches = 3; + constexpr index_t fft_dim = 8; + + auto in = make_tensor({batches, fft_dim, fft_dim}); + auto jit_out = make_tensor({batches, fft_dim, fft_dim}); + auto ref_out = make_tensor({batches, fft_dim, fft_dim}); + + FillFFT2JITInput3D(in, batches, fft_dim); + + auto expr = fft2(in, FFTNorm::FORWARD) + FFT2JITComplex{0.125f, -0.5f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); } + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = fft2(in, FFTNorm::FORWARD) + FFT2JITComplex{0.125f, -0.5f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + AssertFFT2JITClose3D(jit_out, ref_out, batches, fft_dim, 0.01f); + MATX_EXIT_HANDLER(); } #endif From e9aa8d1777141e1d3192baf3d09ae2b9fbc69561 Mon Sep 17 00:00:00 2001 From: Cliff Burdick Date: Tue, 26 May 2026 16:33:30 -0700 Subject: [PATCH 3/6] address greptile review feedback (greploop iteration 2) --- include/matx/transforms/fft/fft_cufftdx.h | 28 ++++++++++++++++------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/include/matx/transforms/fft/fft_cufftdx.h b/include/matx/transforms/fft/fft_cufftdx.h index f7f76fb9b..6abdd0941 100644 --- a/include/matx/transforms/fft/fft_cufftdx.h +++ b/include/matx/transforms/fft/fft_cufftdx.h @@ -444,8 +444,9 @@ namespace matx { FFTType fft_type_ = FFTType::C2C; FFTDirection direction_ = FFTDirection::FORWARD; int cc_ = 0; - cuFFTDxHelper fft_x_helper_; - cuFFTDxHelper fft_y_helper_; + mutable cuFFTDxHelper fft_x_helper_; + mutable cuFFTDxHelper fft_y_helper_; + mutable bool helpers_configured_ = false; static ElementsPerThread IntToElementsPerThread(int ept) { switch (ept) { @@ -459,7 +460,11 @@ namespace matx { } } - void Configure1DHelpers() { + void Configure1DHelpers() const { + if (helpers_configured_) { + return; + } + if (fft_size_x_ <= 0 || fft_size_y_ <= 0 || cc_ <= 0) { return; } @@ -479,6 +484,7 @@ namespace matx { fft_y_helper_.set_cc(cc_); fft_y_helper_.set_contiguous_input(false); fft_y_helper_.set_method(cuFFTDxMethod::SHARED); + helpers_configured_ = true; } public: @@ -490,11 +496,11 @@ namespace matx { FFTDirection get_direction() const { return direction_; } int get_cc() const { return cc_; } - void set_fft_size_x(index_t size) { fft_size_x_ = size; Configure1DHelpers(); } - void set_fft_size_y(index_t size) { fft_size_y_ = size; Configure1DHelpers(); } - void set_fft_type(FFTType type) { fft_type_ = type; Configure1DHelpers(); } - void set_direction(FFTDirection dir) { direction_ = dir; Configure1DHelpers(); } - void set_cc(int cc) { cc_ = cc; Configure1DHelpers(); } + void set_fft_size_x(index_t size) { fft_size_x_ = size; helpers_configured_ = false; } + void set_fft_size_y(index_t size) { fft_size_y_ = size; helpers_configured_ = false; } + void set_fft_type(FFTType type) { fft_type_ = type; helpers_configured_ = false; } + void set_direction(FFTDirection dir) { direction_ = dir; helpers_configured_ = false; } + void set_cc(int cc) { cc_ = cc; helpers_configured_ = false; } #if defined(MATX_EN_MATHDX) && defined(__CUDACC__) std::string GetSymbolName() const { @@ -545,10 +551,12 @@ namespace matx { } bool IsSupported() const { + Configure1DHelpers(); return fft_x_helper_.IsSupported() && fft_y_helper_.IsSupported(); } int GetShmRequired() const { + Configure1DHelpers(); const auto data_size = static_cast(fft_size_x_) * static_cast(fft_size_y_) * static_cast(sizeof(InputType)); @@ -562,6 +570,7 @@ namespace matx { } int GetBlockDim() const { + Configure1DHelpers(); const auto block_x = fft_x_helper_.GetBlockDim(); const auto block_y = fft_y_helper_.GetBlockDim(); if (block_x != block_y) { @@ -586,15 +595,18 @@ namespace matx { } bool GenerateLTOIR(std::set <oir_symbols) { + Configure1DHelpers(); return fft_x_helper_.GenerateLTOIR(ltoir_symbols) && fft_y_helper_.GenerateLTOIR(ltoir_symbols); } std::string GetXFuncName() { + Configure1DHelpers(); return std::string(FFT_DX_FUNC_PREFIX) + "_" + fft_x_helper_.GetSymbolName(); } std::string GetYFuncName() { + Configure1DHelpers(); return std::string(FFT_DX_FUNC_PREFIX) + "_" + fft_y_helper_.GetSymbolName(); } From c25ed6fa014f51c23c2a5c4e16c43e2268bbd999 Mon Sep 17 00:00:00 2001 From: Cliff Burdick Date: Tue, 26 May 2026 16:53:01 -0700 Subject: [PATCH 4/6] address greptile review feedback (greploop iteration 3) --- include/matx/operators/fft.h | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index 43f913bee..ab9f6dbdf 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -832,13 +832,7 @@ namespace matx } } - int major = 0; - int minor = 0; - int device; - cudaGetDevice(&device); - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); - int cc = major * 100 + minor * 10; + const int cc = GetComputeCapability(); dx_fft2_helper_.set_fft_size_y(out_dims_[Rank() - 2]); dx_fft2_helper_.set_fft_size_x(out_dims_[Rank() - 1]); @@ -858,6 +852,8 @@ namespace matx symbol_name += std::to_string(static_cast(Type)); symbol_name += "_D"; symbol_name += Direction == detail::FFTDirection::FORWARD ? std::string("F") : std::string("B"); + symbol_name += "_SM"; + symbol_name += std::to_string(dx_fft2_helper_.get_cc()); return symbol_name; } From 325d394e8bc091db7624e48e5c387748773f2b4c Mon Sep 17 00:00:00 2001 From: Cliff Burdick Date: Tue, 26 May 2026 17:13:28 -0700 Subject: [PATCH 5/6] address greptile review feedback (greploop iteration 4) --- include/matx/executors/jit_kernel.h | 33 ++++++++++++++++++----------- include/matx/operators/fft.h | 3 +++ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/include/matx/executors/jit_kernel.h b/include/matx/executors/jit_kernel.h index ec6a5103a..4ed516d91 100644 --- a/include/matx/executors/jit_kernel.h +++ b/include/matx/executors/jit_kernel.h @@ -447,12 +447,15 @@ namespace matx { int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ matx::index_t size1_vectors = (size1 + ept - 1) / ept;\n\ + if (size1_vectors == 0) return;\n\ matx::index_t idx = tid % size1_vectors;\n\ matx::index_t idy = tid / size1_vectors;\n\ - if constexpr (cuda::std::is_pointer_v) {\n\ - (*op).template operator()(idy, idx);\n\ - } else {\n\ - op.template operator()(idy, idx);\n\ + if (idy < size0 && idx * static_cast(ept) < size1) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idy, idx);\n\ + } else {\n\ + op.template operator()(idy, idx);\n\ + }\n\ }\n\ }\n\ \n\ @@ -461,13 +464,16 @@ namespace matx { int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ matx::index_t size2_vectors = (size2 + ept - 1) / ept;\n\ + if (size2_vectors == 0) return;\n\ matx::index_t idx = tid % size2_vectors;\n\ matx::index_t idy = tid / size2_vectors;\n\ matx::index_t idz = blockIdx.x;\n\ - if constexpr (cuda::std::is_pointer_v) {\n\ - (*op).template operator()(idz, idy, idx);\n\ - } else {\n\ - op.template operator()(idz, idy, idx);\n\ + if (idz < size0 && idy < size1 && idx * static_cast(ept) < size2) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idz, idy, idx);\n\ + } else {\n\ + op.template operator()(idz, idy, idx);\n\ + }\n\ }\n\ }\n\ \n\ @@ -476,14 +482,17 @@ namespace matx { int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ matx::index_t size3_vectors = (size3 + ept - 1) / ept;\n\ + if (size3_vectors == 0) return;\n\ matx::index_t idx = tid % size3_vectors;\n\ matx::index_t idy = tid / size3_vectors;\n\ matx::index_t idz = blockIdx.x;\n\ matx::index_t idw = blockIdx.y;\n\ - if constexpr (cuda::std::is_pointer_v) {\n\ - (*op).template operator()(idw, idz, idy, idx);\n\ - } else {\n\ - op.template operator()(idw, idz, idy, idx);\n\ + if (idw < size0 && idz < size1 && idy < size2 && idx * static_cast(ept) < size3) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idw, idz, idy, idx);\n\ + } else {\n\ + op.template operator()(idw, idz, idy, idx);\n\ + }\n\ }\n\ }\n\ }\n\ diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index ab9f6dbdf..6cca74298 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -845,6 +845,9 @@ namespace matx #if defined(MATX_EN_MATHDX) && defined(__CUDACC__) __MATX_INLINE__ std::string get_jit_class_name() const { std::string symbol_name = "JITFFT2Op_"; + symbol_name += "R"; + symbol_name += std::to_string(jit_rank()); + symbol_name += "_"; symbol_name += std::to_string(dx_fft2_helper_.get_fft_size_y()); symbol_name += "_"; symbol_name += std::to_string(dx_fft2_helper_.get_fft_size_x()); From 14e39b37a4c64ab2ad99cf84080ff4d6732bfc64 Mon Sep 17 00:00:00 2001 From: Cliff Burdick Date: Tue, 26 May 2026 17:38:33 -0700 Subject: [PATCH 6/6] address greptile review feedback (greploop iteration 5) --- include/matx/operators/fft.h | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index 6cca74298..9e994f505 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -857,6 +857,8 @@ namespace matx symbol_name += Direction == detail::FFTDirection::FORWARD ? std::string("F") : std::string("B"); symbol_name += "_SM"; symbol_name += std::to_string(dx_fft2_helper_.get_cc()); + symbol_name += "_N"; + symbol_name += std::to_string(static_cast(norm_)); return symbol_name; } @@ -1066,6 +1068,27 @@ namespace matx MATX_LOG_DEBUG("cuFFTDx 2D JIT_TYPE_QUERY: {}", result); return result; } + else if constexpr (Cap == OperatorCapability::JIT_CACHE_KEY) { +#ifdef MATX_EN_JIT + auto key = detail::MakeJITCacheKeyForType("JITFFT2"); + const int actual_rank = jit_rank(); + detail::HashJITCacheValue(key, actual_rank); + for (int i = 0; i < actual_rank; ++i) { + detail::HashJITCacheValue(key, out_dims_[i]); + } + detail::HashJITCacheValue(key, dx_fft2_helper_.get_fft_size_y()); + detail::HashJITCacheValue(key, dx_fft2_helper_.get_fft_size_x()); + detail::HashJITCacheValue(key, dx_fft2_helper_.get_cc()); + detail::HashJITCacheValue(key, static_cast(Type)); + detail::HashJITCacheValue(key, static_cast(Direction)); + detail::HashJITCacheValue(key, static_cast(norm_)); + detail::HashJITCacheString(key, dx_fft2_helper_.GetXFuncName()); + detail::HashJITCacheString(key, dx_fft2_helper_.GetYFuncName()); + return combine_capabilities(key, detail::get_operator_capability(a_, in)); +#else + return detail::MakeInvalidJITCacheKey(); +#endif + } else { auto self_has_cap = capability_attributes::default_value; auto result = combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in));