Add cuFFTDx-backed FFT2 JIT support #1189

Open: cliffburdick wants to merge 6 commits into main from cburdick/fft2-jit-cufftdx

@cliffburdick
Collaborator

Generate JIT classes and LTO IR for single-block C2C fft2/ifft2 fusions, including shared-memory tiling through cuFFTDx 1D passes.

Teach the JIT launcher about grouped 2D blocks and vectorized EPT indexing so FFT2 operators can return multiple columns per thread.

Document the supported FFT2 JIT shape/type limits and add forward/inverse FFT2 JIT fusion coverage.

@copy-pr-bot

copy-pr-bot Bot commented May 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick force-pushed the cburdick/fft2-jit-cufftdx branch from 49912be to 36d180a on May 26, 2026 at 20:41
@greptile-apps
Contributor

greptile-apps Bot commented May 26, 2026

Greptile Summary

This PR adds single-block complex-to-complex 2D FFT JIT fusion backed by cuFFTDx, restricted to square power-of-two transforms whose total element count fits in one CUDA block. The implementation decomposes the 2D FFT into two cooperative 1D cuFFTDx passes with a shared-memory transpose in between, and extends the JIT executor and Block2D kernels to support vectorized EPT.

  • cuFFTDx2DHelper lazily configures two cuFFTDxHelper instances, queries the block dimensions, elements per thread (EPT), and shared-memory size from cuFFTDx, generates both LTO-IR functions, and emits the JIT string for the cooperative load / x-FFT / transpose / y-FFT / output flow.
  • FFT2Op::get_capability wires in all required JIT capabilities and falls back to the cuFFT path for unsupported shapes; PASS_THROUGH_INNER_RANK now explicitly returns 2.
  • Block2D kernels (ranks 2-4) move to vector column indexing, and jit_cuda.h selects jit_ept_bounds[0] so existing 1D cuFFTDx operators retain EPT=ONE.
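The two-pass decomposition described in this summary can be illustrated on the CPU. The following is only a minimal sketch, not the PR's implementation: `dft_row`, `transpose_square`, and `dft2_square` are hypothetical names, a naive O(n²) DFT stands in for each cooperative 1D cuFFTDx pass, and a plain in-place transpose stands in for the shared-memory transpose. The final transpose back to row-major order is a choice made for this sketch; the PR's output step may index differently.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <utility>
#include <vector>

using cf = std::complex<double>;

// Naive 1D DFT of one contiguous row of length n.
// Hypothetical stand-in for a single 1D cuFFTDx pass.
void dft_row(cf* row, std::size_t n) {
  const double pi = std::acos(-1.0);
  std::vector<cf> out(n);
  for (std::size_t k = 0; k < n; ++k) {
    cf acc(0.0, 0.0);
    for (std::size_t j = 0; j < n; ++j) {
      const double ang =
          -2.0 * pi * static_cast<double>(k * j) / static_cast<double>(n);
      acc += row[j] * cf(std::cos(ang), std::sin(ang));
    }
    out[k] = acc;
  }
  for (std::size_t k = 0; k < n; ++k) row[k] = out[k];
}

// In-place transpose of a square n x n row-major tile.
// Hypothetical stand-in for the shared-memory transpose.
void transpose_square(std::vector<cf>& a, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = i + 1; j < n; ++j)
      std::swap(a[i * n + j], a[j * n + i]);
}

// 2D DFT of a square tile: x-pass, transpose, y-pass, transpose back.
void dft2_square(std::vector<cf>& a, std::size_t n) {
  assert(a.size() == n * n);
  for (std::size_t r = 0; r < n; ++r) dft_row(&a[r * n], n);  // x-pass
  transpose_square(a, n);
  for (std::size_t r = 0; r < n; ++r) dft_row(&a[r * n], n);  // y-pass
  transpose_square(a, n);  // restore row-major layout (sketch-only choice)
}
```

Because the 2D DFT is separable, running the same 1D routine over rows before and after a transpose is equivalent to transforming rows and then columns, which is why a square tile lets both passes share one layout.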

Confidence Score: 4/5

The change is mechanically sound and the cooperative FFT algorithm is correct, but it introduces implicit invariants in the Block2D kernel that are not enforced in code.

The core 2D FFT algorithm is correct for all supported square power-of-two sizes. The JIT-compiled operator body contains __syncthreads() calls inside operator(), invoked once per thread from within the if-guard in the outer Block2D kernels. The block is sized to exactly tile the output so all threads satisfy the guard today, but this invariant is undocumented and fragile under future changes.
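One way to make this invariant explicit is a host-side check that the bounds guard evaluates identically for every thread in the block, since a block-wide barrier inside a conditional is only safe when the condition is uniform across the block. The sketch below is an assumption-laden illustration: `BlockShape`, `in_bounds`, and `guard_is_block_uniform` are hypothetical names, and the guard is modeled as "the thread's row and first column lie inside the output", which may not match the real Block2D kernels.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical launch shape: block_x * block_y threads,
// each thread owning `ept` consecutive columns.
struct BlockShape {
  std::size_t block_x;
  std::size_t block_y;
  std::size_t ept;
};

// Hypothetical model of the kernel's bounds guard for thread (tx, ty).
bool in_bounds(const BlockShape& s, std::size_t tx, std::size_t ty,
               std::size_t rows, std::size_t cols) {
  return ty < rows && tx * s.ept < cols;
}

// True when every thread of the block takes the same side of the guard,
// so a barrier reached inside the guarded region is hit by all or by none.
bool guard_is_block_uniform(const BlockShape& s, std::size_t rows,
                            std::size_t cols) {
  assert(s.block_x > 0 && s.block_y > 0 && s.ept > 0);
  const bool first = in_bounds(s, 0, 0, rows, cols);
  for (std::size_t ty = 0; ty < s.block_y; ++ty)
    for (std::size_t tx = 0; tx < s.block_x; ++tx)
      if (in_bounds(s, tx, ty, rows, cols) != first) return false;
  return true;
}
```

A check of this kind could run when the launch parameters are computed, turning the undocumented "block exactly tiles the output" assumption into a failure that is caught before the kernel runs.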

include/matx/transforms/fft/fft_cufftdx.h: the cooperative execution contract, x-pass scratch aliasing, GetShmRequired cast, and square-only assumption in GetFFTsPerBlock are all implicit.
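For the cast concern, a guarded narrowing conversion is one possible shape for the fix. This is a minimal sketch under assumptions: `checked_shm_bytes` is a hypothetical helper, not a function from the PR, and the real GetShmRequired signature is not shown in this conversation.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>

// Hypothetical helper: computes elems * bytes_per_elem and stores it in
// `out` only if the product fits in an int; returns false otherwise.
bool checked_shm_bytes(std::size_t elems, std::size_t bytes_per_elem,
                       int& out) {
  const std::size_t int_max =
      static_cast<std::size_t>(std::numeric_limits<int>::max());
  // Divide instead of multiplying first so the check itself cannot overflow.
  if (bytes_per_elem != 0 && elems > int_max / bytes_per_elem) return false;
  out = static_cast<int>(elems * bytes_per_elem);
  return true;
}
```

Returning a flag rather than silently truncating lets the caller fall back to the non-JIT path when the request cannot be represented.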

Important Files Changed

| Filename | Overview |
| --- | --- |
| include/matx/transforms/fft/fft_cufftdx.h | Adds cuFFTDx2DHelper with lazy-init 1D sub-helpers, shm allocation, block-dim/EPT queries, LTO-IR generation, and cooperative 2D FFT JIT codegen. Core algorithm is correct; implicit invariants and int truncation need guards. |
| include/matx/operators/fft.h | Adds full JIT capability protocol to FFT2Op including get_jit_op_str, get_jit_class_name, all get_capability overrides, JIT_Storage, and dynamic_tensor_expr. Normalization and output indexing are correct. |
| include/matx/executors/jit_cuda.h | Changes pass-through EPT from hardcoded ONE to jit_ept_bounds[0] and queries GROUPS_PER_BLOCK. Lower bound preserves ONE for default-range operators while picking the fixed EPT advertised by FFT2. |
| include/matx/executors/jit_kernel.h | Block2D kernels updated for EPT>1 via vectorized column indexing and corrected bounds guards. Safe for FFT2 since all threads stay within bounds. |
| include/matx/core/get_grid_dims.h | Adds groups_per_block to get_grid_dims_block_2d and get_grid_dims_block_pass_through with a 1024-thread guard. Default behavior unchanged. |
| test/00_transform/FFT.cu | Adds four JIT 2D FFT tests covering forward/inverse, all three normalization modes, and batched 3D tensors. |
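The 1024-thread guard mentioned for get_grid_dims.h can be pictured as follows. This is a sketch only: `block_threads_for_groups` and its parameters are hypothetical, and the actual functions in the PR take different arguments. The limit of 1024 threads per block is the figure the summary itself cites.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical guard: a block holding `groups_per_block` groups of
// `threads_per_group` threads is accepted only if it is non-empty and
// stays within the 1024-thread limit.
bool block_threads_for_groups(std::uint32_t threads_per_group,
                              std::uint32_t groups_per_block,
                              std::uint32_t& threads_out) {
  constexpr std::uint64_t kMaxThreadsPerBlock = 1024;
  // Widen before multiplying so the product cannot wrap around.
  const std::uint64_t total =
      static_cast<std::uint64_t>(threads_per_group) * groups_per_block;
  if (total == 0 || total > kMaxThreadsPerBlock) return false;
  threads_out = static_cast<std::uint32_t>(total);
  return true;
}
```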

Reviews (12): Last reviewed commit: "address greptile review feedback (greplo..."

Comment thread include/matx/executors/jit_cuda.h Outdated
Comment thread include/matx/operators/fft.h
Comment thread include/matx/transforms/fft/fft_cufftdx.h Outdated
@cliffburdick
Collaborator Author

@greptile review

@cliffburdick
Collaborator Author

@greptile review latest commit 325d394

@cliffburdick
Collaborator Author

@greptile review latest commit 14e39b37d

@cliffburdick
Collaborator Author

@greptile review latest commit 14e39b3. The previous summary says get_jit_class_name() omits normalization, but this commit includes symbol_name += "N" and symbol_name += std::to_string(static_cast(norm)) in include/matx/operators/fft.h, and JIT_CACHE_KEY hashes norm_. Please re-evaluate.
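To illustrate why the normalization mode belongs in the generated name, here is a minimal sketch. The `Norm` enum, the `jit_symbol_name` function, and the `int` cast target are assumptions made for this example; only the `"N"` suffix followed by the stringified mode comes from the comment above.

```cpp
#include <cassert>
#include <string>

// Hypothetical normalization modes; the real enum lives in the library.
enum class Norm { Backward = 0, Forward = 1, Ortho = 2 };

// Appends "N<mode>" to a base symbol name so that kernels compiled for
// different normalization modes never share a cache entry.
std::string jit_symbol_name(const std::string& base, Norm norm) {
  std::string symbol_name = base;
  symbol_name += "N";
  symbol_name += std::to_string(static_cast<int>(norm));  // cast type assumed
  return symbol_name;
}
```

If the mode were left out of the name, two operators differing only in normalization would map to the same symbol and could reuse each other's compiled code.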

@cliffburdick
Collaborator Author

/build

@coveralls

Coverage Status

Coverage is 93.523% for cburdick/fft2-jit-cufftdx into main. No base build found for main.
