[None][feat] VisualGen: enable CUDA graph capture with torch.compile #15603
chang-l wants to merge 1 commit into
The base pipeline's _setup_cuda_graphs() skipped CUDA graph capture entirely whenever torch.compile was enabled, logging "CUDA graphs with torch.compile not yet supported." Because torch.compile defaults on, opting into cuda_graph alongside it silently did nothing. The two compose: the CUDA graph runner wraps the outer transformer forward while torch.compile compiles the inner transformer blocks (the per-block path used by all VisualGen transformers). Graph capture runs during warmup, after the runner's own WARMUP_STEPS eager iterations have already triggered torch.compile's lazy compilation, so the captured graph holds the optimized compiled kernels. LTX2Pipeline already overrides _setup_cuda_graphs() this way; this brings the base class in line. CUDA graph remains opt-in (CudaGraphConfig.enable defaults to False). Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Verified on B200 (Qwen-Image)
Ran on a B200 in the
Stock rc19: both
With this PR, both enabled:
No errors and no illegal memory access. The generated image is byte-identical to the compile-only baseline (matching MD5, PSNR ∞, max pixel diff 0), so CUDA-graph replay of the torch.compiled kernels is numerically exact. This is the same ordering
/bot run
PR_Github #55622 [ run ] triggered by Bot. Commit:
PR_Github #55622 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #55675 [ run ] triggered by Bot. Commit:
PR_Github #55675 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #55759 [ run ] triggered by Bot. Commit:
Description
BasePipeline._setup_cuda_graphs() skipped CUDA graph capture entirely whenever torch.compile was enabled, logging "CUDA graphs with torch.compile not yet supported. Using torch.compile only." Since torch.compile is on by default for VisualGen, opting into cuda_graph_config.enable=True alongside it silently did nothing.
The two actually compose. The CUDA graph runner wraps the outer transformer forward, while torch.compile compiles the inner transformer blocks (the per-block path taken by every VisualGen transformer: WAN, FLUX/FLUX2, Cosmos3, Qwen-Image, LTX-2). Graph capture happens during warmup(), after the runner's own WARMUP_STEPS eager iterations have already triggered torch.compile's lazy compilation, so the captured graph contains the optimized compiled kernels.
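
The ordering described above can be illustrated with a toy simulation that needs neither a GPU nor PyTorch. This is only a sketch of the control flow: `LazyCompiledBlock`, `ToyGraphRunner`, and the `WARMUP_STEPS` value of 2 are hypothetical stand-ins, not the actual runner or its constant. It shows that when eager warmup runs before capture, the recorded work contains only the already-compiled path.

```python
# Toy simulation of the warmup ordering; no GPU or PyTorch required.
# All class names are hypothetical and exist only for illustration.

WARMUP_STEPS = 2  # assumed stand-in value; the real constant lives in the runner


class LazyCompiledBlock:
    """Mimics a lazily compiled block: the first call compiles, later calls reuse."""

    def __init__(self):
        self.compiled = False

    def __call__(self, x, trace):
        if not self.compiled:
            trace.append("compile")      # lazy compilation happens on first call
            self.compiled = True
        trace.append("compiled_kernel")  # the work that runs on every call
        return x + 1


class ToyGraphRunner:
    """Wraps the outer forward: eager warmup first, then record, then replay."""

    def __init__(self, forward):
        self.forward = forward
        self.recorded = None

    def capture(self, x):
        for _ in range(WARMUP_STEPS):
            self.forward(x, [])          # eager iterations trigger compilation
        recorded = []
        self.forward(x, recorded)        # "capture": record what actually runs
        self.recorded = recorded

    def replay(self):
        return list(self.recorded)


block = LazyCompiledBlock()
runner = ToyGraphRunner(block)
runner.capture(0)
print(runner.replay())
```

Because compilation is consumed by the warmup iterations, the recorded trace holds no `"compile"` entry; capturing before any warmup would have recorded it.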
This is exactly the pattern LTX2Pipeline already implements in its _setup_cuda_graphs() override; this PR brings the base class in line by removing the stale early-return.
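
As a rough illustration of what removing the early-return changes, the sketch below models only the gating decision. `ToyConfig`, `captures_graph_before`, and `captures_graph_after` are hypothetical names invented for this example; the real method sets up a graph runner rather than returning a boolean.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the two relevant config switches."""
    torch_compile_enabled: bool = True   # per the description: on by default
    cuda_graph_enabled: bool = False     # per the description: opt-in


def captures_graph_before(cfg: ToyConfig) -> bool:
    """Old behavior: torch.compile short-circuits CUDA graph setup."""
    if cfg.torch_compile_enabled:
        return False                     # the stale early-return
    return cfg.cuda_graph_enabled


def captures_graph_after(cfg: ToyConfig) -> bool:
    """New behavior: only the CUDA graph switch decides."""
    return cfg.cuda_graph_enabled


both = ToyConfig(torch_compile_enabled=True, cuda_graph_enabled=True)
print(captures_graph_before(both), captures_graph_after(both))
```

With the default `ToyConfig()`, both functions return `False`, which mirrors the statement that the default compile-only path is unchanged.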
CUDA graph remains opt-in: CudaGraphConfig.enable still defaults to False, so the default torch.compile-only path is unchanged. Users now get both optimizations together when they set cuda_graph_config.enable=True.
Test Coverage
tests/unittest/_torch/visual_gen/ exercise the torch_compile and cuda_graph paths.
LTX2Pipeline already ships the compile + CUDA-graph composition via its own _setup_cuda_graphs() override, validating the ordering this PR adopts in the base class.
release:1.3.0rc19 container) with Qwen-Image, both cuda_graph_config.enable=True and torch_compile_config.enable=True: the CUDA graph is captured over the torch.compiled blocks, and the generated image is byte-identical to the compile-only baseline (matching MD5, PSNR ∞). See the verification comment below for logs.
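
For readers who want to reproduce a byte-identity check like the one reported, here is a minimal sketch that compares two raw 8-bit pixel buffers by MD5, PSNR, and maximum pixel difference. It is not the script used for the verification above; `compare_pixel_buffers` is a hypothetical helper, and it assumes the images have already been decoded into equally sized `bytes` objects.

```python
import hashlib
import math


def compare_pixel_buffers(baseline: bytes, candidate: bytes) -> dict:
    """Compare two equally sized, non-empty 8-bit pixel buffers."""
    if len(baseline) != len(candidate):
        raise ValueError("buffers must have the same length")
    if not baseline:
        raise ValueError("buffers must be non-empty")

    diffs = [abs(a - b) for a, b in zip(baseline, candidate)]
    mse = sum(d * d for d in diffs) / len(diffs)
    # PSNR for 8-bit data; identical buffers give MSE 0, hence infinite PSNR.
    psnr = math.inf if mse == 0 else 10.0 * math.log10(255.0 ** 2 / mse)
    return {
        "md5_match": hashlib.md5(baseline).digest() == hashlib.md5(candidate).digest(),
        "psnr_db": psnr,
        "max_pixel_diff": max(diffs),
    }


baseline = bytes(range(256))   # stand-in for the compile-only output
identical = bytes(baseline)    # stand-in for the compile + CUDA graph output
print(compare_pixel_buffers(baseline, identical))
```

Identical buffers yield a matching MD5, an infinite PSNR, and a maximum pixel difference of 0, which are the three signals quoted in the verification.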
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why.
PR Follows TRT-LLM CODING GUIDELINES to the best of my knowledge.
No API changes (no api-compatible/api-breaking label needed).
No new dependencies.
No CODEOWNERS / tava diagram changes required.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.
🤖 Generated with Claude Code