Commit 55ac08a

perf(inference_mode): wrap all diffusion/VAE/upscale entry points + prominent docs
Extends the May 2026 Florence2 fix (29GB -> <5GB) to every other out-of-graph inference entry point in the project. Same root cause everywhere: we call ComfyUI models from background threads / our orchestrator, bypassing ComfyUI/execution.py:732, which wraps every node in torch.inference_mode().

Decorated with @torch.inference_mode() in image_generation.py:
- generate_image(): SDXL/diffusion sampling (nodes.common_ksampler)
- upscale_image(): KSampler hires fix + ImageUpscaleWithModel (ESRGAN-style) + VAE encode/decode
- decode_latent_with_vae(): VAE decode (with NaN float32 retry)
- seedvr2_upscale(): SeedVR2 diffusion-based video upscaler (DiT + VAE + sampling loop)

Plus prominent module-docstring warning blocks in:
- image_generation.py: full "why this matters / where / what to do if you add a new entry point" explanation
- ltx_video_generation.py: warning that _call_node()'s inference_mode wrap is load-bearing and must not be removed

Inline comment above each decorated function points back to the module docstring so anyone editing one function still sees the rationale.

Updated the stale comment in decode_latent_with_vae explaining .detach() (was 'tensor may have requires_grad=True'; under inference_mode it can't, so .detach() is now a defensive no-op, kept for double-safety).

Tests: 118/118 pass.
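
For readers unfamiliar with the pattern, here is a minimal CPU-only sketch of the two wrapping styles the message refers to (decorator and context manager). The names `run_forward`, `run_forward_ctx`, and the toy `torch.nn.Linear` model are hypothetical stand-ins, not code from this repository. The sketch only shows that outputs produced under `torch.inference_mode()` carry no autograd graph; it does not reproduce the VRAM measurements.

```python
import torch


@torch.inference_mode()
def run_forward(model, x):
    """Hypothetical entry point: decorator form, as used in this commit."""
    out = model(x)
    # .detach() changes nothing here: the tensor is already outside autograd.
    return out.detach()


def run_forward_ctx(model, x):
    """Same idea using the context-manager form."""
    with torch.inference_mode():
        return model(x)


model = torch.nn.Linear(4, 2)   # parameters have requires_grad=True by default
x = torch.randn(3, 4)

plain = model(x)                # unwrapped call: autograd records a graph
wrapped = run_forward(model, x)
wrapped_ctx = run_forward_ctx(model, x)

print("plain:   requires_grad =", plain.requires_grad)
print("wrapped: requires_grad =", wrapped.requires_grad,
      "| is_inference =", wrapped.is_inference())
```

The unwrapped call returns a tensor with a `grad_fn`, because the layer's parameters require gradients. Both wrapped calls return inference tensors with no graph attached.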
1 parent d0729c4 commit 55ac08a

2 files changed

Lines changed: 86 additions & 4 deletions

image_generation.py

Lines changed: 61 additions & 4 deletions
@@ -1,6 +1,47 @@
 """
 Image Generation and Sampling Module
-Handles the core image generation, decoding, and batch management
+Handles the core image generation, decoding, and batch management.
+
+============================================================================
+🚨 CRITICAL: torch.inference_mode() WRAPPING — DO NOT REMOVE 🚨
+============================================================================
+Every top-level GPU-inference entry point in this module is decorated with
+`@torch.inference_mode()`. This mirrors ComfyUI's PromptExecutor which wraps
+every node execution in inference_mode (ComfyUI/execution.py:732).
+
+WHY THIS MATTERS:
+We are an out-of-graph caller — we invoke ComfyUI's sampling / VAE / upscale
+models directly from a background thread (the dashboard upscale runner) and
+from our orchestrator, bypassing the normal prompt executor. Without
+inference_mode wrapping, autograd's version-counter and reference-keeping
+machinery stays active. For diffusion sampling, beam-search generation,
+VAE encode/decode, and upscale-model inference, that means intermediate
+activations and past_key_values can't be released — measured as a 6× VRAM
+blowup on Florence2 (29GB on a 16GB card vs <5GB in the standalone workflow
+on the same image, May 2026). Same pattern applies to SDXL sampling,
+SeedVR2 diffusion, and ESRGAN-style upscalers, just at a smaller magnitude.
+
+torch.inference_mode() is STRICTER than torch.no_grad():
+- no_grad: disables gradient tracking, but version counters and autograd
+machinery stay partly active
+- inference_mode: tensors are entirely outside autograd, true read-only —
+this is what lets HF generate() actually release KV cache between beam
+steps and what lets ComfyUI's sample loop reclaim activations between
+diffusion steps
+
+WHERE IT'S APPLIED:
+- @torch.inference_mode() on generate_image, upscale_image,
+decode_latent_with_vae, seedvr2_upscale (this file)
+- Inside _call_node() in ltx_video_generation.py (covers all 25 ComfyUI
+node invocations across florence2_hires.py + ltx_video_generation.py)
+
+IF YOU ADD A NEW INFERENCE ENTRY POINT:
+Wrap it in @torch.inference_mode() (decorator) or
+`with torch.inference_mode():` (context). Symptom of missing it: GPU OOM
+on workloads that work fine in a standalone ComfyUI workflow, or sudden
+VRAM growth that doesn't match the model size. See
+ComfyUI/execution.py:732 for the canonical pattern.
+============================================================================
 """

 import time
@@ -13,6 +54,10 @@
 from PIL import Image


+# NOTE: torch.inference_mode() decorator — see module docstring above for why
+# this is REQUIRED on all out-of-graph inference entry points. Removing it
+# will reintroduce the same 6× VRAM blowup we fixed in May 2026.
+@torch.inference_mode()
 def generate_image(
 patched_model,
 seed,
@@ -478,6 +523,9 @@ def calc_tiles(total, tile_size, padding, uniform):
 return {"samples": result_samples}


+# @torch.inference_mode() — see module docstring. Covers KSampler hires fix,
+# ImageUpscaleWithModel (ESRGAN-style), and VAE encode/decode in this function.
+@torch.inference_mode()
 def upscale_image(result_latent, vae, patched_model, upscaling_config, config, positive_conditioning, negative_conditioning, width, height):
 """
 Apply upscaling to a generated latent based on upscaling settings.
@@ -668,6 +716,9 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
 return result_latent, 0


+# @torch.inference_mode() — see module docstring. VAE decode is a model forward
+# pass and needs inference_mode to avoid pinning intermediate activations.
+@torch.inference_mode()
 def decode_latent_with_vae(vae, latent_samples):
 """
 Decode latent samples to pixel space using VAE.
@@ -699,9 +750,10 @@ def decode_latent_with_vae(vae, latent_samples):
 print(f"[GridTester] ✅ float32 retry succeeded")

 # Convert to PIL Image
-# .detach() is required because the tensor may have requires_grad=True
-# (e.g., when called from distributed worker threads outside ComfyUI's
-# normal execution context where autograd state may differ)
+# .detach() is a defensive no-op under @torch.inference_mode() (tensors are
+# already non-grad). Kept for safety in case this function is ever called
+# without the decorator, but inference_mode is the primary defense — see
+# module docstring for why.
 img_np = decoded.detach().cpu().float().numpy()

 # Remove extra dimensions (handle shapes like (1, 1, H, W, C) or (1, H, W, C))
@@ -1027,6 +1079,11 @@ def flush_batch_with_remote_vae(pending_batch, remote_vae_worker, existing_data,
 # Requires ComfyUI-SeedVR2_VideoUpscaler to be installed as a dependency.
 # =============================================================================

+# @torch.inference_mode() — see module docstring. SeedVR2 is a diffusion-based
+# upscaler with its own iterative sampling loop, identical autograd concerns
+# to SDXL KSampler. The CurrentNodeContext below sets up V3-API execution but
+# does NOT include inference_mode — that's our job at the function boundary.
+@torch.inference_mode()
 def seedvr2_upscale(pil_image, seedvr2_config):
 """
 Upscale an image using SeedVR2 diffusion-based upscaler.

ltx_video_generation.py

Lines changed: 25 additions & 0 deletions
@@ -2,6 +2,31 @@
 LTX 2.3 Video Generation Module
 Two-stage SamplerCustomAdvanced pipeline with parallel audio rail.

+============================================================================
+🚨 CRITICAL: _call_node() WRAPS EVERY INVOCATION IN torch.inference_mode() 🚨
+============================================================================
+The _call_node() helper in THIS file is what every ComfyUI node call in USCG
+goes through (florence2_hires.py and this file's 21 LTX node invocations).
+That helper applies `torch.inference_mode()` around the underlying execute()
+or FUNCTION call — mirroring ComfyUI/execution.py:732 which wraps every node
+in the prompt executor's inference_mode block.
+
+DO NOT remove the inference_mode wrap from _call_node. Without it the
+SamplerCustomAdvanced stage 1+2, VAEDecodeTiled, LTXVLatentUpsampler, CLIP
+encode calls, and all other model-forward nodes leak intermediate
+activations across diffusion steps. We measured a 6× VRAM blowup on
+Florence2 from this exact missing wrapper (29GB on a 16GB card vs <5GB
+in the standalone workflow on the same image, May 2026). Same root cause
+applies to every other model-inference path that doesn't go through
+ComfyUI's prompt executor.
+
+If you add new node invocations that DON'T go through _call_node (e.g.,
+direct `.execute()` or `instance.FUNCTION()` calls), wrap them yourself
+in `with torch.inference_mode():` — or better, route them through
+_call_node so the wrapper is automatic.
+============================================================================
+
+
 Pinned LTX node pack version: TBD — set during first smoke test.
 Required nodes (looked up via nodes.NODE_CLASS_MAPPINGS):
 - DiffusionModelLoaderKJ
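
The body of `_call_node()` is not part of this diff. The sketch below is only a guess at the general shape such a helper might take, based on the docstring's description of wrapping the `execute()` or `FUNCTION` call. `call_node` and `ScaleNode` are hypothetical names, and the dispatch logic (classic nodes naming their method in a `FUNCTION` attribute, otherwise falling back to `execute`) is an assumption, not the project's actual code.

```python
import torch


def call_node(node_cls, **inputs):
    """Hypothetical stand-in for a _call_node-style helper (not the real one)."""
    instance = node_cls()
    # Assumption: classic nodes store the method name in FUNCTION;
    # anything else is expected to expose execute().
    if hasattr(instance, "FUNCTION"):
        fn = getattr(instance, instance.FUNCTION)
    else:
        fn = instance.execute
    # The wrap the docstring calls load-bearing: every node call runs
    # outside autograd, whichever thread or caller invoked the helper.
    with torch.inference_mode():
        return fn(**inputs)


class ScaleNode:
    """Toy node with a trainable parameter, standing in for a model-forward node."""
    FUNCTION = "scale"

    def __init__(self):
        self.weight = torch.nn.Parameter(torch.ones(3))

    def scale(self, image):
        # Classic nodes return a tuple of outputs.
        return (image * self.weight,)


(out,) = call_node(ScaleNode, image=torch.ones(2, 3))
print("requires_grad =", out.requires_grad, "| is_inference =", out.is_inference())
```

Routing calls through one helper like this means a newly added node invocation gets the wrapper without the caller having to remember it, which is the design the docstring recommends.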
