diff --git a/docs_input/api/dft/fft/fft2d.rst b/docs_input/api/dft/fft/fft2d.rst index 8b3636446..297322100 100644 --- a/docs_input/api/dft/fft/fft2d.rst +++ b/docs_input/api/dft/fft/fft2d.rst @@ -5,6 +5,10 @@ fft2 Perform a 2D FFT. Batching is supported for any tensor with a rank higher than 2. +FFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square +transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled. +Unsupported 2D FFT sizes and real-valued 2D FFTs use the existing cuFFT execution path. + .. versionadded:: 0.6.0 .. doxygenfunction:: fft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD) @@ -31,4 +35,4 @@ Examples :language: cpp :start-after: example-begin fft2-2 :end-before: example-end fft2-2 - :dedent: \ No newline at end of file + :dedent: diff --git a/docs_input/api/dft/fft/ifft2.rst b/docs_input/api/dft/fft/ifft2.rst index be910ff2d..7129bb32a 100644 --- a/docs_input/api/dft/fft/ifft2.rst +++ b/docs_input/api/dft/fft/ifft2.rst @@ -5,6 +5,10 @@ ifft2 Perform a 2D inverse FFT. Batching is supported for any tensor with a rank higher than 2. +IFFT kernel fusion is supported by cuFFTDx for complex-to-complex power-of-two square +transforms that fit in a single CUDA block when ``-DMATX_EN_MATHDX=ON`` is enabled. +Unsupported 2D IFFT sizes and real-valued inverse 2D FFTs use the existing cuFFT execution path. + .. versionadded:: 0.6.0 .. doxygenfunction:: ifft2(const OpA &a, FFTNorm norm = FFTNorm::BACKWARD) @@ -31,4 +35,4 @@ Examples :language: cpp :start-after: example-begin ifft2-2 :end-before: example-end ifft2-2 - :dedent: \ No newline at end of file + :dedent: diff --git a/docs_input/basics/fusion.rst b/docs_input/basics/fusion.rst index 83eb740b4..4989c09f3 100644 --- a/docs_input/basics/fusion.rst +++ b/docs_input/basics/fusion.rst @@ -61,7 +61,8 @@ CUDA JIT Kernel Fusion MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled for all standard MatX element-wise operators, FFT operations via cuFFTDx, GEMM operations via cuBLASDx, and selected -solver operations via cuSolverDx. To enable fusion with MathDx, the following option must be enabled: +solver operations via cuSolverDx. cuFFTDx supports 1D FFT fusion and single-block complex-to-complex 2D ``fft2``/``ifft2`` +fusion for supported power-of-two square transforms. To enable fusion with MathDx, the following option must be enabled: ``-DMATX_EN_MATHDX=ON``. MathDx support also enables the NVRTC-based JIT support used by ``CUDAJITExecutor``. Once enabled, the ``CUDAJITExecutor`` can be used to perform JIT compilation in supported situations. If the expression cannot be JIT compiled, the ``CUDAJITExecutor`` may throw an error. @@ -129,9 +130,9 @@ MathDx Compatibility cuSolverDx ``inv`` when their block-dimension ranges intersect. * - cuFFTDx - Yes - - Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible FFT fusion paths. cuFFTDx has stricter launch-shape - requirements than the generic element-wise JIT path, and does not generally compose with cuBLASDx/cuSolverDx - operations that require a different block/grid model. + - Enabled via ``-DMATX_EN_MATHDX=ON`` for compatible 1D FFT fusion paths and supported single-block 2D C2C FFT + fusion paths. cuFFTDx has stricter launch-shape requirements than the generic element-wise JIT path, and does not + generally compose with cuBLASDx/cuSolverDx operations that require a different block/grid model. * - cuSolverDx - Partial - Enabled via ``-DMATX_EN_MATHDX=ON`` for ``chol``, ``inv``, and selected projection outputs from diff --git a/include/matx/core/get_grid_dims.h b/include/matx/core/get_grid_dims.h index b75eeed1d..a224a3b93 100644 --- a/include/matx/core/get_grid_dims.h +++ b/include/matx/core/get_grid_dims.h @@ -338,12 +338,17 @@ inline bool get_grid_dims_block_reduce(dim3 &blocks, dim3 &threads, const cuda:: template inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads, const cuda::std::array &sizes, - int block_dim) { + int block_dim, + int groups_per_block = 1) { // Threads are set to block_dim in x, y and z are 1 // All threads cooperate via flattened thread ID in the kernel threads.x = block_dim; - threads.y = 1; + threads.y = groups_per_block; threads.z = 1; + + if (static_cast(threads.x) * static_cast(threads.y) > 1024) { + MATX_THROW(matxInvalidParameter, "Block2D launch exceeds CUDA maximum threads per block"); + } // Grid covers batch dimensions only (dims 0 to RANK-3) blocks.x = 1; @@ -372,7 +377,8 @@ inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads, } } - MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z); + MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}", + blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block); // No stride needed for now - could be extended for very large batches return false; @@ -427,11 +433,12 @@ template inline bool get_grid_dims_block_pass_through(dim3 &blocks, dim3 &threads, const cuda::std::array &sizes, int block_dim, - int inner_rank) { + int inner_rank, + int groups_per_block = 1) { if (inner_rank == 1) { return get_grid_dims_block_1d(blocks, threads, sizes, block_dim); } - return get_grid_dims_block_2d(blocks, threads, sizes, block_dim); + return get_grid_dims_block_2d(blocks, threads, sizes, block_dim, groups_per_block); } } // end namespace detail } // end namespace matx diff --git a/include/matx/executors/jit_cuda.h b/include/matx/executors/jit_cuda.h index cddf00678..4fdee6362 100644 --- a/include/matx/executors/jit_cuda.h +++ b/include/matx/executors/jit_cuda.h @@ -510,15 +510,23 @@ namespace matx // For pass-through operators (e.g., MathDx), block dimensions are constrained by the operator. auto block_dim_range = detail::get_operator_capability(op); block_size = detail::SelectJITPassThroughBlockDim(block_dim_range); - stride = detail::get_grid_dims_block_pass_through(blocks, threads, sizes, block_size, pass_through_inner_rank); - - // EPT is 1 for 2D block operators - the operator handles elements internally - best_ept = detail::ElementsPerThread::ONE; + auto group_range = detail::get_operator_capability(op); + groups_per_block = group_range[0]; + if (groups_per_block == detail::capability_attributes::invalid) { + MATX_THROW(matxInvalidParameter, "No valid JIT groups-per-block value satisfies the fused operator requirements"); + } + stride = detail::get_grid_dims_block_pass_through( + blocks, threads, sizes, block_size, pass_through_inner_rank, groups_per_block); + + // Pass-through operators do not run the normal SET_ELEMENTS_PER_THREAD query path. + // Use the lower bound so existing operators with a default [ONE, MAX] range keep + // the historical EPT=ONE behavior. Operators that require vector lanes, such as + // FFT2, advertise a fixed [N, N] range. + best_ept = jit_ept_bounds[0]; shm_size = detail::get_operator_capability(op); - groups_per_block = 1; - MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}", - static_cast(best_ept), shm_size, block_size); + MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}, Groups per block {}", + static_cast(best_ept), shm_size, block_size, groups_per_block); } else if constexpr (is_dynamic_rank_op_v) { // Dynamic tensor expressions: pre-compiled kernels don't exist for this Op type, // so we cannot query register pressure. Use conservative defaults. diff --git a/include/matx/executors/jit_kernel.h b/include/matx/executors/jit_kernel.h index cec56ba56..4ed516d91 100644 --- a/include/matx/executors/jit_kernel.h +++ b/include/matx/executors/jit_kernel.h @@ -445,39 +445,54 @@ namespace matx { template \n\ __global__ void matxOpT2KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1) {\n\ int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ - matx::index_t idx = tid % size1;\n\ - matx::index_t idy = tid / size1;\n\ - if constexpr (cuda::std::is_pointer_v) {\n\ - (*op).template operator()(idy, idx);\n\ - } else {\n\ - op.template operator()(idy, idx);\n\ + constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ + matx::index_t size1_vectors = (size1 + ept - 1) / ept;\n\ + if (size1_vectors == 0) return;\n\ + matx::index_t idx = tid % size1_vectors;\n\ + matx::index_t idy = tid / size1_vectors;\n\ + if (idy < size0 && idx * static_cast(ept) < size1) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idy, idx);\n\ + } else {\n\ + op.template operator()(idy, idx);\n\ + }\n\ }\n\ }\n\ \n\ template \n\ __global__ void matxOpT3KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ - matx::index_t idx = tid % size2;\n\ - matx::index_t idy = tid / size2;\n\ + constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ + matx::index_t size2_vectors = (size2 + ept - 1) / ept;\n\ + if (size2_vectors == 0) return;\n\ + matx::index_t idx = tid % size2_vectors;\n\ + matx::index_t idy = tid / size2_vectors;\n\ matx::index_t idz = blockIdx.x;\n\ - if constexpr (cuda::std::is_pointer_v) {\n\ - (*op).template operator()(idz, idy, idx);\n\ - } else {\n\ - op.template operator()(idz, idy, idx);\n\ + if (idz < size0 && idy < size1 && idx * static_cast(ept) < size2) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idz, idy, idx);\n\ + } else {\n\ + op.template operator()(idz, idy, idx);\n\ + }\n\ }\n\ }\n\ \n\ template \n\ __global__ void matxOpT4KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ - matx::index_t idx = tid % size3;\n\ - matx::index_t idy = tid / size3;\n\ + constexpr int ept = static_cast(CurrentCapabilities::ept);\n\ + matx::index_t size3_vectors = (size3 + ept - 1) / ept;\n\ + if (size3_vectors == 0) return;\n\ + matx::index_t idx = tid % size3_vectors;\n\ + matx::index_t idy = tid / size3_vectors;\n\ matx::index_t idz = blockIdx.x;\n\ matx::index_t idw = blockIdx.y;\n\ - if constexpr (cuda::std::is_pointer_v) {\n\ - (*op).template operator()(idw, idz, idy, idx);\n\ - } else {\n\ - op.template operator()(idw, idz, idy, idx);\n\ + if (idw < size0 && idz < size1 && idy < size2 && idx * static_cast(ept) < size3) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idw, idz, idy, idx);\n\ + } else {\n\ + op.template operator()(idw, idz, idy, idx);\n\ + }\n\ }\n\ }\n\ }\n\ diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index 15706ba26..9e994f505 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -239,7 +239,7 @@ namespace matx " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const\n" + " {\n" + " return out_dims_[dim];\n " + - " }\n" + + " }\n" + "};\n") ); } @@ -754,12 +754,31 @@ namespace matx mutable detail::tensor_impl_t tmp_out_; mutable ttype *ptr = nullptr; mutable bool prerun_done_ = false; +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + mutable cuFFTDx2DHelper dx_fft2_helper_; + bool jit_axes_supported_ = true; +#endif public: using matxop = bool; using value_type = typename OpA::value_type; using matx_transform_op = bool; using fft2_xform_op = bool; + using self_type = FFT2Op; + + // Propagate dynamic tensor marker through expression tree + using dynamic_tensor_expr = cuda::std::bool_constant< + is_dynamic_tensor_v || is_dynamic_rank_op_v>; + +#ifdef MATX_EN_JIT + struct JIT_Storage { + typename detail::inner_storage_or_self_t> a_; + }; + + JIT_Storage ToJITStorage() const { + return JIT_Storage{detail::to_jit_storage(a_)}; + } +#endif __MATX_INLINE__ std::string str() const { if constexpr (Direction == detail::FFTDirection::FORWARD) { @@ -802,9 +821,88 @@ namespace matx out_dims_[Rank() - 2] = out_dims_[Rank() - 2]; } } + +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + if constexpr (!std::is_same_v) { + for (int32_t i = 0; i < Rank(); i++) { + if (perm_[static_cast(i)] != i) { + jit_axes_supported_ = false; + break; + } + } + } + + const int cc = GetComputeCapability(); + + dx_fft2_helper_.set_fft_size_y(out_dims_[Rank() - 2]); + dx_fft2_helper_.set_fft_size_x(out_dims_[Rank() - 1]); + dx_fft2_helper_.set_fft_type(Type); + dx_fft2_helper_.set_direction(Direction); + dx_fft2_helper_.set_cc(cc); +#endif } - +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + __MATX_INLINE__ std::string get_jit_class_name() const { + std::string symbol_name = "JITFFT2Op_"; + symbol_name += "R"; + symbol_name += std::to_string(jit_rank()); + symbol_name += "_"; + symbol_name += std::to_string(dx_fft2_helper_.get_fft_size_y()); + symbol_name += "_"; + symbol_name += std::to_string(dx_fft2_helper_.get_fft_size_x()); + symbol_name += "_T"; + symbol_name += std::to_string(static_cast(Type)); + symbol_name += "_D"; + symbol_name += Direction == detail::FFTDirection::FORWARD ? std::string("F") : std::string("B"); + symbol_name += "_SM"; + symbol_name += std::to_string(dx_fft2_helper_.get_cc()); + symbol_name += "_N"; + symbol_name += std::to_string(static_cast(norm_)); + return symbol_name; + } + + __MATX_INLINE__ auto get_jit_op_str() const { + const int actual_rank = jit_rank(); + const std::string class_name = get_jit_class_name(); + const std::string fft_x_func_name = dx_fft2_helper_.GetXFuncName(); + const std::string fft_y_func_name = dx_fft2_helper_.GetYFuncName(); + + std::string declarations = + " extern \"C\" __device__ void " + fft_x_func_name + "(" + detail::type_to_string() + "*);\n"; + if (fft_y_func_name != fft_x_func_name) { + declarations += + " extern \"C\" __device__ void " + fft_y_func_name + "(" + detail::type_to_string() + "*);\n"; + } + + return cuda::std::make_tuple( + class_name, + std::string( + declarations + + " template struct " + class_name + " {\n" + + " using input_type = typename OpA::value_type;\n" + + " using matxop = bool;\n" + + " using value_type = input_type;\n" + + " typename detail::inner_storage_or_self_t> a_;\n" + + " constexpr static cuda::std::array out_dims_ = { " + + detail::array_to_string(out_dims_, actual_rank) + " };\n" + + " template \n" + + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const\n" + + " {\n" + + " " + dx_fft2_helper_.GetFuncStr(fft_x_func_name, fft_y_func_name, static_cast(norm_), actual_rank) + "\n" + + " }\n" + + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank()\n" + + " {\n" + + " return " + std::to_string(actual_rank) + ";\n" + + " }\n" + + " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const\n" + + " {\n" + + " return out_dims_[dim];\n " + + " }\n" + + "};\n") + ); + } +#endif template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const @@ -825,6 +923,14 @@ namespace matx return out_dims_[dim]; } + __MATX_INLINE__ __MATX_HOST__ int32_t DynRank() const { + return detail::get_dyn_rank(a_); + } + + __MATX_INLINE__ __MATX_HOST__ int32_t jit_rank() const { + if constexpr (is_dynamic_rank_op_v) return DynRank(); + else return Rank(); + } __MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; } @@ -865,10 +971,143 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType &in) const { - // 1. Determine if the binary operation ITSELF intrinsically has this capability. - auto self_has_cap = capability_attributes::default_value; - auto result = combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in)); - return result; +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + if constexpr (Cap == OperatorCapability::DYN_SHM_SIZE) { + const int self_shm = (jit_axes_supported_ && dx_fft2_helper_.template CheckJITSizeAndTypeRequirements()) ? + dx_fft2_helper_.GetShmRequired() : capability_attributes::default_value; + auto result = combine_capabilities(self_shm, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D DYN_SHM_SIZE: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { + bool supported = jit_axes_supported_; + if (supported && !dx_fft2_helper_.template CheckJITSizeAndTypeRequirements()) { + supported = false; + } + else if (supported) { + supported = dx_fft2_helper_.IsSupported(); + if (supported && dx_fft2_helper_.GetElementsPerThread() == ElementsPerThread::INVALID) { + supported = false; + } + } + + auto result = combine_capabilities(supported, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D SUPPORTS_JIT: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + const auto [key, value] = get_jit_op_str(); + if (in.find(key) == in.end()) { + in[key] = value; + } + detail::get_operator_capability(a_, in); + MATX_LOG_DEBUG("cuFFTDx 2D JIT_CLASS_QUERY: true"); + return true; + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + ElementsPerThread ept = ElementsPerThread::ONE; + if (in.jit) { + ept = (jit_axes_supported_ && dx_fft2_helper_.template CheckJITSizeAndTypeRequirements()) ? + dx_fft2_helper_.GetElementsPerThread() : ElementsPerThread::INVALID; + } + const auto my_cap = cuda::std::array{ept, ept}; + auto result = combine_capabilities(my_cap, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D ELEMENTS_PER_THREAD: [{},{}]", + static_cast(result[0]), static_cast(result[1])); + return result; + } + else if constexpr (Cap == OperatorCapability::MAX_EPT_VEC_LOAD) { + auto result = combine_capabilities(1, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D MAX_EPT_VEC_LOAD: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::GLOBAL_KERNEL) { + MATX_LOG_DEBUG("cuFFTDx 2D GLOBAL_KERNEL: false"); + return false; + } + else if constexpr (Cap == OperatorCapability::PASS_THROUGH_THREADS) { + MATX_LOG_DEBUG("cuFFTDx 2D PASS_THROUGH_THREADS: true"); + return true; + } + else if constexpr (Cap == OperatorCapability::PASS_THROUGH_INNER_RANK) { + MATX_LOG_DEBUG("cuFFTDx 2D PASS_THROUGH_INNER_RANK: 2"); + return 2; + } + else if constexpr (Cap == OperatorCapability::GROUPS_PER_BLOCK) { + const int ffts_per_block = dx_fft2_helper_.GetFFTsPerBlock(); + const auto my_cap = cuda::std::array{ffts_per_block, ffts_per_block}; + auto result = combine_capabilities(my_cap, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D GROUPS_PER_BLOCK: [{},{}]", result[0], result[1]); + return result; + } + else if constexpr (Cap == OperatorCapability::SET_GROUPS_PER_BLOCK || + Cap == OperatorCapability::SET_ELEMENTS_PER_THREAD) { + auto result = combine_capabilities(capability_attributes::default_value, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D SET_GROUPS/EPT: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::BLOCK_DIM) { + const int block_dim = dx_fft2_helper_.GetBlockDim(); + const auto my_block = block_dim > 0 ? + cuda::std::array{block_dim, block_dim} : + cuda::std::array{capability_attributes::invalid, capability_attributes::invalid}; + auto result = combine_capabilities(my_block, detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D BLOCK_DIM: [{},{}]", result[0], result[1]); + return result; + } + else if constexpr (Cap == OperatorCapability::GENERATE_LTOIR) { + auto result = combine_capabilities( + dx_fft2_helper_.GenerateLTOIR(in.ltoir_symbols), + detail::get_operator_capability(a_, in)); + MATX_LOG_DEBUG("cuFFTDx 2D GENERATE_LTOIR: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + const auto inner_op_jit_name = detail::get_operator_capability(a_, in); + auto result = get_jit_class_name() + "<" + inner_op_jit_name + ">"; + MATX_LOG_DEBUG("cuFFTDx 2D JIT_TYPE_QUERY: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_CACHE_KEY) { +#ifdef MATX_EN_JIT + auto key = detail::MakeJITCacheKeyForType("JITFFT2"); + const int actual_rank = jit_rank(); + detail::HashJITCacheValue(key, actual_rank); + for (int i = 0; i < actual_rank; ++i) { + detail::HashJITCacheValue(key, out_dims_[i]); + } + detail::HashJITCacheValue(key, dx_fft2_helper_.get_fft_size_y()); + detail::HashJITCacheValue(key, dx_fft2_helper_.get_fft_size_x()); + detail::HashJITCacheValue(key, dx_fft2_helper_.get_cc()); + detail::HashJITCacheValue(key, static_cast(Type)); + detail::HashJITCacheValue(key, static_cast(Direction)); + detail::HashJITCacheValue(key, static_cast(norm_)); + detail::HashJITCacheString(key, dx_fft2_helper_.GetXFuncName()); + detail::HashJITCacheString(key, dx_fft2_helper_.GetYFuncName()); + return combine_capabilities(key, detail::get_operator_capability(a_, in)); +#else + return detail::MakeInvalidJITCacheKey(); +#endif + } + else { + auto self_has_cap = capability_attributes::default_value; + auto result = combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in)); + return result; + } +#else + if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { + bool supported = false; + return combine_capabilities(supported, detail::get_operator_capability(a_, in)); + } + else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + return ""; + } + else { + auto self_has_cap = capability_attributes::default_value; + auto result = combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in)); + return result; + } +#endif } template diff --git a/include/matx/transforms/fft/fft_cufftdx.h b/include/matx/transforms/fft/fft_cufftdx.h index f1f8bbd94..6abdd0941 100644 --- a/include/matx/transforms/fft/fft_cufftdx.h +++ b/include/matx/transforms/fft/fft_cufftdx.h @@ -38,6 +38,7 @@ #include "matx/core/operator_options.h" #include "matx/core/capabilities.h" #include "matx/core/log.h" +#include #include #define LIBMATHDX_CHECK(ans) \ @@ -435,5 +436,288 @@ namespace matx { #endif }; + template + class cuFFTDx2DHelper { + private: + index_t fft_size_x_ = 0; + index_t fft_size_y_ = 0; + FFTType fft_type_ = FFTType::C2C; + FFTDirection direction_ = FFTDirection::FORWARD; + int cc_ = 0; + mutable cuFFTDxHelper fft_x_helper_; + mutable cuFFTDxHelper fft_y_helper_; + mutable bool helpers_configured_ = false; + + static ElementsPerThread IntToElementsPerThread(int ept) { + switch (ept) { + case 1: return ElementsPerThread::ONE; + case 2: return ElementsPerThread::TWO; + case 4: return ElementsPerThread::FOUR; + case 8: return ElementsPerThread::EIGHT; + case 16: return ElementsPerThread::SIXTEEN; + case 32: return ElementsPerThread::THIRTY_TWO; + default: return ElementsPerThread::INVALID; + } + } + + void Configure1DHelpers() const { + if (helpers_configured_) { + return; + } + + if (fft_size_x_ <= 0 || fft_size_y_ <= 0 || cc_ <= 0) { + return; + } + + fft_x_helper_.set_fft_size(fft_size_x_); + fft_x_helper_.set_fft_type(fft_type_); + fft_x_helper_.set_direction(direction_); + fft_x_helper_.set_ffts_per_block(static_cast(fft_size_y_)); + fft_x_helper_.set_cc(cc_); + fft_x_helper_.set_contiguous_input(false); + fft_x_helper_.set_method(cuFFTDxMethod::SHARED); + + fft_y_helper_.set_fft_size(fft_size_y_); + fft_y_helper_.set_fft_type(fft_type_); + fft_y_helper_.set_direction(direction_); + fft_y_helper_.set_ffts_per_block(static_cast(fft_size_x_)); + fft_y_helper_.set_cc(cc_); + fft_y_helper_.set_contiguous_input(false); + fft_y_helper_.set_method(cuFFTDxMethod::SHARED); + helpers_configured_ = true; + } + + public: + cuFFTDx2DHelper() = default; + + index_t get_fft_size_x() const { return fft_size_x_; } + index_t get_fft_size_y() const { return fft_size_y_; } + FFTType get_fft_type() const { return fft_type_; } + FFTDirection get_direction() const { return direction_; } + int get_cc() const { return cc_; } + + void set_fft_size_x(index_t size) { fft_size_x_ = size; helpers_configured_ = false; } + void set_fft_size_y(index_t size) { fft_size_y_ = size; helpers_configured_ = false; } + void set_fft_type(FFTType type) { fft_type_ = type; helpers_configured_ = false; } + void set_direction(FFTDirection dir) { direction_ = dir; helpers_configured_ = false; } + void set_cc(int cc) { cc_ = cc; helpers_configured_ = false; } + +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + std::string GetSymbolName() const { + std::string symbol_name; + symbol_name += std::to_string(fft_size_x_); + symbol_name += "_"; + symbol_name += std::to_string(fft_size_y_); + symbol_name += "_T"; + symbol_name += std::to_string(static_cast(fft_type_)); + symbol_name += "_D"; + symbol_name += std::to_string(static_cast(direction_)); + symbol_name += "_CC"; + symbol_name += std::to_string(cc_); + +#if defined(CUDA_VERSION) + symbol_name += "_CUDA"; + symbol_name += std::to_string(CUDART_VERSION); +#else + symbol_name += "_CUDAUNKNOWN"; +#endif + + return symbol_name; + } + + template + bool CheckJITSizeAndTypeRequirements() const { + using OpInputType = typename OpType::value_type; + + if (fft_type_ != FFTType::C2C) { + return false; + } + + if ((fft_size_x_ & (fft_size_x_ - 1)) != 0 || fft_size_x_ == 0 || + (fft_size_y_ & (fft_size_y_ - 1)) != 0 || fft_size_y_ == 0) { + return false; + } + + // The single-kernel JIT path uses one CUDA block for a complete 2D tile. + if (fft_size_x_ != fft_size_y_ || (fft_size_x_ * fft_size_y_) > 1024) { + return false; + } + + if constexpr (is_complex_half_v || !is_complex_v) { + return false; + } + + return true; + } + + bool IsSupported() const { + Configure1DHelpers(); + return fft_x_helper_.IsSupported() && fft_y_helper_.IsSupported(); + } + + int GetShmRequired() const { + Configure1DHelpers(); + const auto data_size = static_cast(fft_size_x_) * + static_cast(fft_size_y_) * + static_cast(sizeof(InputType)); + const auto x_shm = static_cast(fft_x_helper_.GetShmRequired()); + const auto y_shm = static_cast(fft_y_helper_.GetShmRequired()); + const auto extra_shm = std::max(0, std::max(x_shm, y_shm) - data_size); + const auto total = data_size * 2 + extra_shm; + MATX_LOG_DEBUG("cuFFTDx 2D shared memory: data={}, scratch={}, extra={}, total={}", + data_size, data_size, extra_shm, total); + return static_cast(total); + } + + int GetBlockDim() const { + Configure1DHelpers(); + const auto block_x = fft_x_helper_.GetBlockDim(); + const auto block_y = fft_y_helper_.GetBlockDim(); + if (block_x != block_y) { + MATX_LOG_DEBUG("cuFFTDx 2D block dims differ: x={}, y={}", block_x, block_y); + return -1; + } + + return block_x; + } + + int GetFFTsPerBlock() const { + return static_cast(fft_size_y_); + } + + ElementsPerThread GetElementsPerThread() const { + const auto block_dim = GetBlockDim(); + if (block_dim <= 0 || fft_size_x_ % block_dim != 0) { + return ElementsPerThread::INVALID; + } + + return IntToElementsPerThread(static_cast(fft_size_x_ / block_dim)); + } + + bool GenerateLTOIR(std::set <oir_symbols) { + Configure1DHelpers(); + return fft_x_helper_.GenerateLTOIR(ltoir_symbols) && + fft_y_helper_.GenerateLTOIR(ltoir_symbols); + } + + std::string GetXFuncName() { + Configure1DHelpers(); + return std::string(FFT_DX_FUNC_PREFIX) + "_" + fft_x_helper_.GetSymbolName(); + } + + std::string GetYFuncName() { + Configure1DHelpers(); + return std::string(FFT_DX_FUNC_PREFIX) + "_" + fft_y_helper_.GetSymbolName(); + } + + std::string GetFuncStr(const std::string &fft_x_func_name, + const std::string &fft_y_func_name, + int fft_norm, + int actual_rank) { + const int fft_forward = (direction_ == FFTDirection::FORWARD) ? 1 : 0; + + std::string result = R"( + using input_type = )"; + result += detail::type_to_string(); + result += R"(; + using input_type_converted = typename detail::convert_matx_type_t; + using precision = typename detail::inner_precision::type; + using ScalarCap = typename CapType::scalar_cap; + [[maybe_unused]] static constexpr int fft_size_x = )"; + result += std::to_string(static_cast(fft_size_x_)); + result += R"(; + [[maybe_unused]] static constexpr int fft_size_y = )"; + result += std::to_string(static_cast(fft_size_y_)); + result += R"(; + [[maybe_unused]] static constexpr int fft_forward = )"; + result += std::to_string(fft_forward); + result += R"(; + [[maybe_unused]] static constexpr int fft_norm = )"; + result += std::to_string(fft_norm); + result += R"(; + [[maybe_unused]] static constexpr int fft_rank = )"; + result += std::to_string(actual_rank); + result += R"(; + [[maybe_unused]] static constexpr int fft_elements = fft_size_x * fft_size_y; + + extern __shared__ __align__(16) unsigned char fft2_smem_raw[]; + auto *fft_data = reinterpret_cast(fft2_smem_raw); + auto *fft_scratch = fft_data + fft_elements; + + const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + cuda::std::array fft_indices{static_cast(indices)...}; + + for (int elem = tid; elem < fft_elements; elem += total_threads) { + const int row = elem / fft_size_x; + const int col = elem - row * fft_size_x; + fft_indices[fft_rank - 2] = row; + fft_indices[fft_rank - 1] = col; + fft_data[elem] = static_cast( + cuda::std::apply([&](auto... args) { + return a_.template operator()(args...); + }, fft_indices)); + } + + __syncthreads(); + )"; + result += fft_x_func_name; + result += R"((reinterpret_cast(fft_data)); + __syncthreads(); + + for (int elem = tid; elem < fft_elements; elem += total_threads) { + const int row = elem / fft_size_x; + const int col = elem - row * fft_size_x; + fft_scratch[col * fft_size_y + row] = fft_data[row * fft_size_x + col]; + } + + __syncthreads(); + )"; + result += fft_y_func_name; + result += R"((reinterpret_cast(fft_scratch)); + __syncthreads(); + + static constexpr int fft_output_ept = static_cast(CapType::ept); + const int out_row = static_cast(cuda::std::get(cuda::std::make_tuple(indices...))); + const int out_col_base = static_cast(cuda::std::get(cuda::std::make_tuple(indices...))) * fft_output_ept; + + if constexpr (CapType::ept == ElementsPerThread::ONE) { + input_type_converted result = fft_scratch[out_col_base * fft_size_y + out_row]; + if constexpr (fft_norm == 2) { + result = result * static_cast(1.f) / static_cast(cuda::std::sqrt(static_cast(fft_elements))); + } + else if constexpr ((fft_norm == 1 && fft_forward) || (fft_norm == 0 && !fft_forward)) { + result = result * static_cast(1.f) / static_cast(fft_elements); + } + + return static_cast(result); + } + else { + Vector result_vec; + #pragma unroll + for (int i = 0; i < fft_output_ept; i++) { + const int out_col = out_col_base + i; + input_type_converted result = input_type_converted{}; + if (out_col < fft_size_x) { + result = fft_scratch[out_col * fft_size_y + out_row]; + if constexpr (fft_norm == 2) { + result = result * static_cast(1.f) / static_cast(cuda::std::sqrt(static_cast(fft_elements))); + } + else if constexpr ((fft_norm == 1 && fft_forward) || (fft_norm == 0 && !fft_forward)) { + result = result * static_cast(1.f) / static_cast(fft_elements); + } + } + result_vec.data[i] = static_cast(result); + } + + return result_vec; + } + )"; + + return result; + } +#endif + }; + } // namespace detail } // namespace matx diff --git a/test/00_transform/FFT.cu b/test/00_transform/FFT.cu index 796218d14..294181aff 100644 --- a/test/00_transform/FFT.cu +++ b/test/00_transform/FFT.cu @@ -919,6 +919,175 @@ TYPED_TEST(FFTTestComplexTypes, FFT2D16R2C) MATX_EXIT_HANDLER(); } +#if defined(MATX_EN_JIT) && defined(MATX_EN_MATHDX) +namespace { +using FFT2JITComplex = cuda::std::complex; + +void FillFFT2JITInput2D(auto &in, index_t fft_dim) +{ + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + in(row, col) = FFT2JITComplex{ + static_cast((row + 1) * (col + 2) % 17), + static_cast((2 * row + col) % 11)}; + } + } +} + +void FillFFT2JITInput3D(auto &in, index_t batches, index_t fft_dim) +{ + for (index_t batch = 0; batch < batches; batch++) { + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + in(batch, row, col) = FFT2JITComplex{ + static_cast((batch + 1) * (row + 2) + col), + static_cast((batch + row + 2 * col) % 13)}; + } + } + } +} + +void AssertFFT2JITClose2D(auto &jit_out, auto &ref_out, index_t fft_dim, float tol) +{ + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + ASSERT_NEAR(jit_out(row, col).real(), ref_out(row, col).real(), tol); + ASSERT_NEAR(jit_out(row, col).imag(), ref_out(row, col).imag(), tol); + } + } +} + +void AssertFFT2JITClose3D(auto &jit_out, auto &ref_out, index_t batches, index_t fft_dim, float tol) +{ + for (index_t batch = 0; batch < batches; batch++) { + for (index_t row = 0; row < fft_dim; row++) { + for (index_t col = 0; col < fft_dim; col++) { + ASSERT_NEAR(jit_out(batch, row, col).real(), ref_out(batch, row, col).real(), tol); + ASSERT_NEAR(jit_out(batch, row, col).imag(), ref_out(batch, row, col).imag(), tol); + } + } + } +} +} // namespace + +TEST(FFTJIT, CuFFTDx2DFFT2Fusion) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t fft_dim = 4; + + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); + + FillFFT2JITInput2D(in, fft_dim); + + auto expr = fft2(in) + FFT2JITComplex{1.0f, -1.0f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); + } + + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = fft2(in) + FFT2JITComplex{1.0f, -1.0f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + AssertFFT2JITClose2D(jit_out, ref_out, fft_dim, 0.01f); + + MATX_EXIT_HANDLER(); +} + +TEST(FFTJIT, CuFFTDx2DIFFT2Fusion) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t fft_dim = 4; + + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); + + FillFFT2JITInput2D(in, fft_dim); + + auto expr = ifft2(in, FFTNorm::BACKWARD) * FFT2JITComplex{0.5f, 0.0f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); + } + + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = ifft2(in, FFTNorm::BACKWARD) * FFT2JITComplex{0.5f, 0.0f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + AssertFFT2JITClose2D(jit_out, ref_out, fft_dim, 0.01f); + + MATX_EXIT_HANDLER(); +} + +TEST(FFTJIT, CuFFTDx2DFFT2OrthoBoundary) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t fft_dim = 32; + + auto in = make_tensor({fft_dim, fft_dim}); + auto jit_out = make_tensor({fft_dim, fft_dim}); + auto ref_out = make_tensor({fft_dim, fft_dim}); + + FillFFT2JITInput2D(in, fft_dim); + + auto expr = fft2(in, FFTNorm::ORTHO) - FFT2JITComplex{0.25f, 0.75f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); + } + + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = fft2(in, FFTNorm::ORTHO) - FFT2JITComplex{0.25f, 0.75f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + AssertFFT2JITClose2D(jit_out, ref_out, fft_dim, 0.05f); + + MATX_EXIT_HANDLER(); +} + +TEST(FFTJIT, CuFFTDx2DFFT2BatchedForwardNorm) +{ + MATX_ENTER_HANDLER(); + + constexpr index_t batches = 3; + constexpr index_t fft_dim = 8; + + auto in = make_tensor({batches, fft_dim, fft_dim}); + auto jit_out = make_tensor({batches, fft_dim, fft_dim}); + auto ref_out = make_tensor({batches, fft_dim, fft_dim}); + + FillFFT2JITInput3D(in, batches, fft_dim); + + auto expr = fft2(in, FFTNorm::FORWARD) + FFT2JITComplex{0.125f, -0.5f}; + if (!jit_supported(expr)) { + GTEST_SKIP(); + } + + CUDAJITExecutor jit_exec{}; + cudaExecutor cuda_exec{}; + (jit_out = expr).run(jit_exec); + (ref_out = fft2(in, FFTNorm::FORWARD) + FFT2JITComplex{0.125f, -0.5f}).run(cuda_exec); + jit_exec.sync(); + cuda_exec.sync(); + + AssertFFT2JITClose3D(jit_out, ref_out, batches, fft_dim, 0.01f); + + MATX_EXIT_HANDLER(); +} +#endif + TYPED_TEST(FFTTestComplexTypes, FFT2D16x32R2C) { MATX_ENTER_HANDLER(); @@ -1023,4 +1192,4 @@ TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, IFFT1D1024C2CShort) MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); MATX_EXIT_HANDLER(); -} \ No newline at end of file +}