Add cuFFTDx-backed FFT2 JIT support#1189
Conversation
Generate JIT classes and LTO IR for single-block C2C fft2/ifft2 fusions, including shared-memory tiling through cuFFTDx 1D passes. Teach the JIT launcher about grouped 2D blocks and vectorized EPT indexing so FFT2 operators can return multiple columns per thread. Document the supported FFT2 JIT shape/type limits and add forward/inverse FFT2 JIT fusion coverage.
49912be to
36d180a
Compare
Greptile SummaryThis PR adds single-block complex-to-complex 2D FFT JIT fusion backed by cuFFTDx, restricted to square power-of-two transforms whose total element count fits in one CUDA block. The implementation decomposes the 2D FFT into two cooperative 1D cuFFTDx passes with a shared-memory transpose in between, and extends the JIT executor and Block2D kernels to support vectorized EPT.
Confidence Score: 4/5The change is mechanically sound and the cooperative FFT algorithm is correct, but it introduces implicit invariants in the Block2D kernel that are not enforced in code. The core 2D FFT algorithm is correct for all supported square power-of-two sizes. The JIT-compiled operator body contains __syncthreads() calls inside operator(), invoked once per thread from within the if-guard in the outer Block2D kernels. The block is sized to exactly tile the output so all threads satisfy the guard today, but this invariant is undocumented and fragile under future changes. include/matx/transforms/fft/fft_cufftdx.h: the cooperative execution contract, x-pass scratch aliasing, GetShmRequired cast, and square-only assumption in GetFFTsPerBlock are all implicit. Important Files Changed
Reviews (12): Last reviewed commit: "address greptile review feedback (greplo..." | Re-trigger Greptile |
|
@greptile review |
1 similar comment
|
@greptile review |
|
@greptile review |
|
@greptile review |
|
@greptile review |
|
@greptile review latest commit 325d394 |
|
@greptile review latest commit 14e39b37d |
|
@greptile review latest commit 14e39b3. The previous summary says get_jit_class_name() omits normalization, but this commit includes symbol_name += "N" and symbol_name += std::to_string(static_cast(norm)) in include/matx/operators/fft.h, and JIT_CACHE_KEY hashes norm_. Please re-evaluate. |
|
/build |
Generate JIT classes and LTO IR for single-block C2C fft2/ifft2 fusions, including shared-memory tiling through cuFFTDx 1D passes.
Teach the JIT launcher about grouped 2D blocks and vectorized EPT indexing so FFT2 operators can return multiple columns per thread.
Document the supported FFT2 JIT shape/type limits and add forward/inverse FFT2 JIT fusion coverage.