|
1 | 1 | """ |
2 | 2 | Image Generation and Sampling Module |
3 | | -Handles the core image generation, decoding, and batch management |
| 3 | +Handles the core image generation, decoding, and batch management. |
| 4 | +
|
| 5 | +============================================================================ |
| 6 | +🚨 CRITICAL: torch.inference_mode() WRAPPING — DO NOT REMOVE 🚨 |
| 7 | +============================================================================ |
| 8 | +Every top-level GPU-inference entry point in this module is decorated with |
| 9 | +`@torch.inference_mode()`. This mirrors ComfyUI's PromptExecutor, which wraps |
| 10 | +every node execution in inference_mode (ComfyUI/execution.py:732). |
| 11 | +
|
| 12 | +WHY THIS MATTERS: |
| 13 | + We are an out-of-graph caller — we invoke ComfyUI's sampling / VAE / upscale |
| 14 | + models directly from a background thread (the dashboard upscale runner) and |
| 15 | + from our orchestrator, bypassing the normal prompt executor. Without |
| 16 | + inference_mode wrapping, autograd's version-counter and reference-keeping |
| 17 | + machinery stays active. For diffusion sampling, beam-search generation, |
| 18 | + VAE encode/decode, and upscale-model inference, that means intermediate |
| 19 | + activations and past_key_values can't be released — measured as a 6× VRAM |
| 20 | + blowup on Florence2 (29GB on a 16GB card vs <5GB in the standalone workflow |
| 21 | + on the same image, May 2026). Same pattern applies to SDXL sampling, |
| 22 | + SeedVR2 diffusion, and ESRGAN-style upscalers, just at a smaller magnitude. |
| 23 | +
|
| 24 | + torch.inference_mode() is STRICTER than torch.no_grad(): |
| 25 | + - no_grad: disables gradient tracking, but version counters and autograd |
| 26 | + machinery stay partly active |
| 27 | + - inference_mode: tensors are entirely outside autograd (no view tracking |
| 28 | + or version-counter bumps) — this is what lets HF generate() actually |
| 29 | + release KV cache between beam steps and what lets ComfyUI's sample loop |
| 30 | + reclaim activations between diffusion steps |
| 31 | +
|
| 32 | +WHERE IT'S APPLIED: |
| 33 | + - @torch.inference_mode() on generate_image, upscale_image, |
| 34 | + decode_latent_with_vae, seedvr2_upscale (this file) |
| 35 | + - Inside _call_node() in ltx_video_generation.py (covers all 25 ComfyUI |
| 36 | + node invocations across florence2_hires.py + ltx_video_generation.py) |
| 37 | +
|
| 38 | +IF YOU ADD A NEW INFERENCE ENTRY POINT: |
| 39 | + Wrap it in @torch.inference_mode() (decorator) or |
| 40 | + `with torch.inference_mode():` (context). Symptom of missing it: GPU OOM |
| 41 | + on workloads that work fine in a standalone ComfyUI workflow, or sudden |
| 42 | + VRAM growth that doesn't match the model size. See |
| 43 | + ComfyUI/execution.py:732 for the canonical pattern. |
| 44 | +============================================================================ |
4 | 45 | """ |
5 | 46 |
|
6 | 47 | import time |
|
13 | 54 | from PIL import Image |
14 | 55 |
|
15 | 56 |
|
| 57 | +# NOTE: torch.inference_mode() decorator — see module docstring above for why |
| 58 | +# this is REQUIRED on all out-of-graph inference entry points. Removing it |
| 59 | +# will reintroduce the same 6× VRAM blowup we fixed in May 2026. |
| 60 | +@torch.inference_mode() |
16 | 61 | def generate_image( |
17 | 62 | patched_model, |
18 | 63 | seed, |
@@ -478,6 +523,9 @@ def calc_tiles(total, tile_size, padding, uniform): |
478 | 523 | return {"samples": result_samples} |
479 | 524 |
|
480 | 525 |
|
| 526 | +# @torch.inference_mode() — see module docstring. Covers KSampler hires fix, |
| 527 | +# ImageUpscaleWithModel (ESRGAN-style), and VAE encode/decode in this function. |
| 528 | +@torch.inference_mode() |
481 | 529 | def upscale_image(result_latent, vae, patched_model, upscaling_config, config, positive_conditioning, negative_conditioning, width, height): |
482 | 530 | """ |
483 | 531 | Apply upscaling to a generated latent based on upscaling settings. |
@@ -668,6 +716,9 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p |
668 | 716 | return result_latent, 0 |
669 | 717 |
|
670 | 718 |
|
| 719 | +# @torch.inference_mode() — see module docstring. VAE decode is a model forward |
| 720 | +# pass and needs inference_mode to avoid pinning intermediate activations. |
| 721 | +@torch.inference_mode() |
671 | 722 | def decode_latent_with_vae(vae, latent_samples): |
672 | 723 | """ |
673 | 724 | Decode latent samples to pixel space using VAE. |
@@ -699,9 +750,10 @@ def decode_latent_with_vae(vae, latent_samples): |
699 | 750 | print(f"[GridTester] ✅ float32 retry succeeded") |
700 | 751 |
|
701 | 752 | # Convert to PIL Image |
702 | | - # .detach() is required because the tensor may have requires_grad=True |
703 | | - # (e.g., when called from distributed worker threads outside ComfyUI's |
704 | | - # normal execution context where autograd state may differ) |
| 753 | + # .detach() is a defensive no-op under @torch.inference_mode() (tensors are |
| 754 | + # already non-grad). Kept for safety in case the decorator is ever removed |
| 755 | + # from this function, but inference_mode is the primary defense — see |
| 756 | + # module docstring for why. |
705 | 757 | img_np = decoded.detach().cpu().float().numpy() |
706 | 758 |
|
707 | 759 | # Remove extra dimensions (handle shapes like (1, 1, H, W, C) or (1, H, W, C)) |
@@ -1027,6 +1079,11 @@ def flush_batch_with_remote_vae(pending_batch, remote_vae_worker, existing_data, |
1027 | 1079 | # Requires ComfyUI-SeedVR2_VideoUpscaler to be installed as a dependency. |
1028 | 1080 | # ============================================================================= |
1029 | 1081 |
|
| 1082 | +# @torch.inference_mode() — see module docstring. SeedVR2 is a diffusion-based |
| 1083 | +# upscaler with its own iterative sampling loop, so it has the same autograd |
| 1084 | +# concerns as the SDXL KSampler. The CurrentNodeContext below sets up V3-API |
| 1085 | +# execution but does NOT include inference_mode — that's our job at the function boundary. |
| 1086 | +@torch.inference_mode() |
1030 | 1087 | def seedvr2_upscale(pil_image, seedvr2_config): |
1031 | 1088 | """ |
1032 | 1089 | Upscale an image using SeedVR2 diffusion-based upscaler. |
|
0 commit comments