diff --git a/.codex b/.codex new file mode 100644 index 00000000..e69de29b diff --git a/inference_cli.py b/inference_cli.py index 2d4fff18..5e1fec4a 100644 --- a/inference_cli.py +++ b/inference_cli.py @@ -130,8 +130,19 @@ decode_all_batches, postprocess_all_batches ) +from src.core.model_configuration import ( + _evict_claimed_cached_models, + _finalize_claimed_cached_models_for_reuse, +) from src.utils.debug import Debug -from src.optimization.memory_manager import clear_memory, get_gpu_backend, is_cuda_available +from src.optimization.memory_manager import ( + cleanup_text_embeddings, + clear_memory, + complete_cleanup, + get_gpu_backend, + is_cuda_available, + set_model_cache_claimed_state, +) debug = Debug(enabled=False) # Will be enabled via --debug CLI flag @@ -913,103 +924,220 @@ def _process_frames_core( dit_id = "cli_dit" if cache_dit else None vae_id = "cli_vae" if cache_vae else None - runner, cache_context = prepare_runner( - dit_model=args.dit_model, - vae_model=DEFAULT_VAE, - model_dir=model_dir, - debug=debug, - ctx=ctx, - dit_cache=cache_dit, - vae_cache=cache_vae, - dit_id=dit_id, - vae_id=vae_id, - block_swap_config={ - 'blocks_to_swap': args.blocks_to_swap, - 'swap_io_components': args.swap_io_components, - 'offload_device': dit_offload, - }, - encode_tiled=args.vae_encode_tiled, - encode_tile_size=(args.vae_encode_tile_size, args.vae_encode_tile_size), - encode_tile_overlap=(args.vae_encode_tile_overlap, args.vae_encode_tile_overlap), - decode_tiled=args.vae_decode_tiled, - decode_tile_size=(args.vae_decode_tile_size, args.vae_decode_tile_size), - decode_tile_overlap=(args.vae_decode_tile_overlap, args.vae_decode_tile_overlap), - tile_debug=args.tile_debug.lower() if args.tile_debug else "false", - attention_mode=args.attention_mode, - torch_compile_args_dit=torch_compile_args_dit, - torch_compile_args_vae=torch_compile_args_vae - ) - - ctx['cache_context'] = cache_context - if runner_cache is not None: - runner_cache['runner'] = runner - - # Preload text embeddings before Phase 1 to avoid sync stall in Phase 2 - ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug) - debug.log("Loaded text embeddings for DiT", category="dit") - - # Compute generation info and log start (handles prepending internally) - frames_tensor, gen_info = compute_generation_info( - ctx=ctx, - images=frames_tensor, - resolution=args.resolution, - max_resolution=args.max_resolution, - batch_size=args.batch_size, - uniform_batch_size=args.uniform_batch_size, - seed=args.seed, - prepend_frames=args.prepend_frames, - temporal_overlap=args.temporal_overlap, - debug=debug - ) - log_generation_start(gen_info, debug) - - # Phase 1: Encode - ctx = encode_all_batches( - runner, ctx=ctx, images=frames_tensor, - debug=debug, - batch_size=args.batch_size, - uniform_batch_size=args.uniform_batch_size, - seed=args.seed, - progress_callback=None, - temporal_overlap=args.temporal_overlap, - resolution=args.resolution, - max_resolution=args.max_resolution, - input_noise_scale=args.input_noise_scale, - color_correction=args.color_correction - ) - - # Phase 2: Upscale - ctx = upscale_all_batches( - runner, ctx=ctx, debug=debug, progress_callback=None, - seed=args.seed, - latent_noise_scale=args.latent_noise_scale, - cache_model=cache_dit - ) - - # Phase 3: Decode - ctx = decode_all_batches( - runner, ctx=ctx, debug=debug, progress_callback=None, - cache_model=cache_vae - ) - - # Phase 4: Post-process - ctx = postprocess_all_batches( - ctx=ctx, debug=debug, progress_callback=None, - color_correction=args.color_correction, - prepend_frames=0, # Worker mode handles this in main process - temporal_overlap=args.temporal_overlap, - batch_size=args.batch_size - ) - - result_tensor = ctx['final_video'] - - # Convert to CPU and compatible dtype - if result_tensor.is_cuda or result_tensor.is_mps: - result_tensor = result_tensor.cpu() - if result_tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): - result_tensor = result_tensor.to(torch.float32) - - return result_tensor + runner = None + cache_context = None + + def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: + nonlocal runner, ctx + + if runner is not None: + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None + refreshed_dit = None + refreshed_vae = None + try: + try: + complete_cleanup( + runner=runner, + debug=debug, + dit_cache=dit_cache_flag, + vae_cache=vae_cache_flag, + ) + if dit_cache_flag or vae_cache_flag: + refreshed_dit, refreshed_vae = _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) + except Exception: + try: + _evict_claimed_cached_models(cache_context, runner, debug) + except Exception as evict_error: + if debug is not None: + debug.log( + f"Failed to evict claimed cached models while handling prior cleanup/finalize exception: {evict_error}", + level="WARNING", + category="cleanup", + force=True, + ) + raise + finally: + if dit_cache_flag and claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if dit_cache_flag and refreshed_dit is not None and refreshed_dit is not claimed_dit: + set_model_cache_claimed_state(refreshed_dit, False) + if vae_cache_flag and claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) + if vae_cache_flag and refreshed_vae is not None and refreshed_vae is not claimed_vae: + set_model_cache_claimed_state(refreshed_vae, False) + runner._seedvr2_execution_active = False + + if not (dit_cache_flag or vae_cache_flag): + runner = None + if runner_cache is not None: + runner_cache.pop('runner', None) + + if ctx is not None: + cleanup_text_embeddings(ctx, debug) + if not (dit_cache_flag or vae_cache_flag): + ctx = None + if runner_cache is not None: + runner_cache.pop('ctx', None) + + try: + runner, cache_context = prepare_runner( + dit_model=args.dit_model, + vae_model=DEFAULT_VAE, + model_dir=model_dir, + debug=debug, + ctx=ctx, + dit_cache=cache_dit, + vae_cache=cache_vae, + dit_id=dit_id, + vae_id=vae_id, + block_swap_config={ + 'blocks_to_swap': args.blocks_to_swap, + 'swap_io_components': args.swap_io_components, + 'offload_device': dit_offload, + }, + encode_tiled=args.vae_encode_tiled, + encode_tile_size=(args.vae_encode_tile_size, args.vae_encode_tile_size), + encode_tile_overlap=(args.vae_encode_tile_overlap, args.vae_encode_tile_overlap), + decode_tiled=args.vae_decode_tiled, + decode_tile_size=(args.vae_decode_tile_size, args.vae_decode_tile_size), + decode_tile_overlap=(args.vae_decode_tile_overlap, args.vae_decode_tile_overlap), + tile_debug=args.tile_debug.lower() if args.tile_debug else "false", + attention_mode=args.attention_mode, + torch_compile_args_dit=torch_compile_args_dit, + torch_compile_args_vae=torch_compile_args_vae + ) + + runner._seedvr2_execution_active = True + runner._seedvr2_runner_tainted = False + runner._seedvr2_dit_phase_cleaned = False + runner._seedvr2_vae_phase_cleaned = False + + if ( + cache_context is not None + and not cache_context.get('reusing_runner', False) + and cache_context.get('cached_dit') is not None + and cache_context.get('cached_vae') is not None + ): + cache_context['global_cache'].set_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + runner, + debug, + ) + + ctx['cache_context'] = cache_context + if runner_cache is not None: + runner_cache['runner'] = runner + + # Preload text embeddings before Phase 1 to avoid sync stall in Phase 2 + ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug) + debug.log("Loaded text embeddings for DiT", category="dit") + + # Compute generation info and log start (handles prepending internally) + frames_tensor, gen_info = compute_generation_info( + ctx=ctx, + images=frames_tensor, + resolution=args.resolution, + max_resolution=args.max_resolution, + batch_size=args.batch_size, + uniform_batch_size=args.uniform_batch_size, + seed=args.seed, + prepend_frames=args.prepend_frames, + temporal_overlap=args.temporal_overlap, + debug=debug + ) + log_generation_start(gen_info, debug) + + # Phase 1: Encode + ctx = encode_all_batches( + runner, ctx=ctx, images=frames_tensor, + debug=debug, + batch_size=args.batch_size, + uniform_batch_size=args.uniform_batch_size, + seed=args.seed, + progress_callback=None, + temporal_overlap=args.temporal_overlap, + resolution=args.resolution, + max_resolution=args.max_resolution, + input_noise_scale=args.input_noise_scale, + color_correction=args.color_correction + ) + + # Phase 2: Upscale + ctx = upscale_all_batches( + runner, ctx=ctx, debug=debug, progress_callback=None, + seed=args.seed, + latent_noise_scale=args.latent_noise_scale, + cache_model=cache_dit + ) + + # Phase 3: Decode + ctx = decode_all_batches( + runner, ctx=ctx, debug=debug, progress_callback=None, + cache_model=cache_vae + ) + + # Phase 4: Post-process + ctx = postprocess_all_batches( + ctx=ctx, debug=debug, progress_callback=None, + color_correction=args.color_correction, + prepend_frames=0, # Worker mode handles this in main process + temporal_overlap=args.temporal_overlap, + batch_size=args.batch_size + ) + + result_tensor = ctx['final_video'] + + # Convert to CPU and compatible dtype + if result_tensor.is_cuda or result_tensor.is_mps: + result_tensor = result_tensor.cpu() + if result_tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): + result_tensor = result_tensor.to(torch.float32) + + cleanup(dit_cache_flag=cache_dit, vae_cache_flag=cache_vae) + return result_tensor + except BaseException: + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None + if cache_context is not None: + _evict_claimed_cached_models(cache_context, runner, debug) + if runner is not None and cache_context.get('reusing_runner', False): + try: + cache_context['global_cache'].taint_and_remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + expected_runner=runner, + ) + except BaseException as cache_error: + if debug is not None: + debug.log( + f"Failed to evict cached runner while handling prior exception " + f"(runner={id(runner)}, dit_id={cache_context.get('dit_id')}, vae_id={cache_context.get('vae_id')}): {cache_error}", + level="WARNING", + category="cleanup", + force=True, + ) + + try: + try: + cleanup(dit_cache_flag=False, vae_cache_flag=False) + except BaseException as cleanup_error: + if debug is not None: + debug.log( + f"Cleanup failed while handling prior exception " + f"(runner={id(runner) if runner is not None else 'none'}, dit_id={cache_context.get('dit_id') if cache_context is not None else 'none'}, vae_id={cache_context.get('vae_id') if cache_context is not None else 'none'}): {cleanup_error}", + level="WARNING", + category="cleanup", + force=True, + ) + finally: + if claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) + raise def _worker_process( @@ -1709,4 +1837,4 @@ def main() -> None: debug.print_footer() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index 3b7e6ea0..971365b3 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -47,9 +47,11 @@ cleanup_dit, cleanup_vae, cleanup_text_embeddings, + is_model_cache_cold, manage_tensor, manage_model_device, release_tensor_memory, + synchronize_device, release_tensor_collection ) from ..optimization.performance import ( @@ -141,7 +143,7 @@ def _reconstruct_and_transform_batch( Transformed video in CTHW format, ready for color correction """ start_idx, end_idx, uniform_padding = ctx['batch_metadata'][batch_idx] - + # Prepare video batch video = _prepare_video_batch( images=ctx['input_images'], @@ -151,6 +153,17 @@ def _reconstruct_and_transform_batch( debug=None, log_info=False ) + + # Mirror Phase 1 ordering: move to the VAE device before padding and transform. + video = manage_tensor( + tensor=video, + target_device=ctx['vae_device'], + tensor_name=f"reconstructed_video_batch_{batch_idx+1}", + dtype=ctx['compute_dtype'], + debug=debug, + reason="Phase 4 input reconstruction", + indent_level=1, + ) # Apply 4n+1 padding using shared helper video = _apply_4n1_padding(video) @@ -292,14 +305,23 @@ def encode_all_batches( ctx['batch_metadata'] = [None] * num_encode_batches encode_idx = 0 + validated_computed_target_dims = False try: + vae_needs_reactivation = runner.vae is not None and is_model_cache_cold(runner.vae) + # Materialize VAE if still on meta device if runner.vae and next(runner.vae.parameters()).device.type == 'meta': materialize_model(runner, "vae", ctx['vae_device'], runner.config, debug) else: + # Cold cached models keep weights/config, but execution state is rebuilt each run. + if vae_needs_reactivation: + debug.log("Rebuilding VAE execution state from cold cache", category="vae", force=True) + manage_model_device(model=runner.vae, target_device=ctx['vae_device'], + model_name="VAE", debug=debug, reason="cold-cache activation", runner=runner) + apply_model_specific_config(runner.vae, runner, runner.config, False, debug) # Model already materialized (cached) - apply any pending configs if needed - if getattr(runner, '_vae_config_needs_application', False): + elif getattr(runner, '_vae_config_needs_application', False): debug.log("Applying updated VAE configuration", category="vae", force=True) apply_model_specific_config(runner.vae, runner, runner.config, False, debug) @@ -309,11 +331,13 @@ def encode_all_batches( # Cache VAE now that it's fully configured and ready for inference if ctx['cache_context']['vae_cache'] and not ctx['cache_context']['cached_vae']: runner.vae._model_name = ctx['cache_context']['vae_model'] - ctx['cache_context']['global_cache'].set_vae( + cached_vae_id = ctx['cache_context']['global_cache'].set_vae( {'node_id': ctx['cache_context']['vae_id'], 'cache_model': True}, runner.vae, ctx['cache_context']['vae_model'], debug ) - ctx['cache_context']['vae_newly_cached'] = True + if cached_vae_id is not None: + ctx['cache_context']['vae_newly_cached'] = True + ctx['cache_context']['cached_vae'] = runner.vae # If both models now cached, cache runner template dit_is_cached = ctx['cache_context']['cached_dit'] or ctx['cache_context']['dit_newly_cached'] @@ -412,6 +436,34 @@ def encode_all_batches( transformed_video = ctx['video_transform'](rgb_video) + if ( + not validated_computed_target_dims + and 'padded_target_dims' in ctx + and 'true_target_dims' in ctx + ): + actual_padded_h, actual_padded_w = transformed_video.shape[-2:] + expected_padded_h, expected_padded_w = ctx['padded_target_dims'] + expected_true_h, expected_true_w = ctx['true_target_dims'] + + if (actual_padded_h, actual_padded_w) != (expected_padded_h, expected_padded_w): + msg = ( + "Computed target dims mismatch: " + f"expected padded {expected_padded_w}x{expected_padded_h}, " + f"actual transform output {actual_padded_w}x{actual_padded_h}, " + f"cached true target {expected_true_w}x{expected_true_h}" + ) + if debug is not None: + debug.log(msg, level="ERROR", category="setup", force=True) + raise RuntimeError(msg) + + if getattr(debug, "enabled", False): + debug.log( + f"Validated computed target dims against actual transform output: padded {expected_padded_w}x{expected_padded_h}, true {expected_true_w}x{expected_true_h}", + category="setup", + force=True, + ) + validated_computed_target_dims = True + # Apply input noise if requested (to reduce artifacts at high resolutions) if input_noise_scale > 0: debug.log(f"Applying input noise (scale: {input_noise_scale:.2f})", category="video", indent_level=1) @@ -616,12 +668,20 @@ def upscale_all_batches( upscale_idx = 0 try: + dit_needs_reactivation = runner.dit is not None and is_model_cache_cold(runner.dit) + # Materialize DiT if still on meta device if runner.dit and next(runner.dit.parameters()).device.type == 'meta': materialize_model(runner, "dit", ctx['dit_device'], runner.config, debug) else: + # Cold cached models keep weights/config, but execution state is rebuilt each run. + if dit_needs_reactivation: + debug.log("Rebuilding DiT execution state from cold cache", category="dit", force=True) + manage_model_device(model=runner.dit, target_device=ctx['dit_device'], + model_name="DiT", debug=debug, reason="cold-cache activation", runner=runner) + apply_model_specific_config(runner.dit, runner, runner.config, True, debug) # Model already materialized (cached) - apply any pending configs if needed - if getattr(runner, '_dit_config_needs_application', False): + elif getattr(runner, '_dit_config_needs_application', False): debug.log("Applying updated DiT configuration", category="dit", force=True) apply_model_specific_config(runner.dit, runner, runner.config, True, debug) @@ -631,11 +691,13 @@ def upscale_all_batches( # Cache DiT now that it's fully configured and ready for inference if ctx['cache_context']['dit_cache'] and not ctx['cache_context']['cached_dit']: runner.dit._model_name = ctx['cache_context']['dit_model'] - ctx['cache_context']['global_cache'].set_dit( + cached_dit_id = ctx['cache_context']['global_cache'].set_dit( {'node_id': ctx['cache_context']['dit_id'], 'cache_model': True}, runner.dit, ctx['cache_context']['dit_model'], debug ) - ctx['cache_context']['dit_newly_cached'] = True + if cached_dit_id is not None: + ctx['cache_context']['dit_newly_cached'] = True + ctx['cache_context']['cached_dit'] = runner.dit # If both models now cached, cache runner template vae_is_cached = ctx['cache_context']['cached_vae'] or ctx['cache_context']['vae_newly_cached'] @@ -1236,6 +1298,7 @@ def postprocess_all_batches( sample = sample_thwc.permute(0, 3, 1, 2) # [T, H, W, C] → [T, C, H, W] # Move to VAE device for processing + phase4_probe_enabled = getattr(debug, "enabled", False) sample = manage_tensor( tensor=sample, target_device=ctx['vae_device'], @@ -1268,7 +1331,23 @@ def postprocess_all_batches( # Trim spatial dimensions to true target size if 'true_target_dims' in ctx: true_h, true_w = ctx['true_target_dims'] - if input_video.shape[-2] != true_h or input_video.shape[-1] != true_w: + current_h, current_w = input_video.shape[-2:] + if current_h != true_h or current_w != true_w: + if current_h < true_h or current_w < true_w: + msg = ( + "Reconstructed input spatial dims smaller than expected true target dims: " + f"{current_w}x{current_h} < {true_w}x{true_h}" + ) + if debug: + debug.log(msg, level="ERROR", category="video", force=True) + raise RuntimeError(msg) + + if debug: + debug.log( + f"Trimming reconstructed input spatial padding: {current_w}x{current_h} → {true_w}x{true_h}", + category="video", + indent_level=1, + ) input_video = input_video[:, :, :true_h, :true_w] # Apply color correction if enabled (RGB only) @@ -1297,24 +1376,33 @@ def postprocess_all_batches( # Apply selected color correction method debug.start_timer(f"color_correction_{color_correction}") + color_correction_applied = False if color_correction == "lab": debug.log("Applying LAB perceptual color transfer", category="video", force=True, indent_level=1) sample = lab_color_transfer(sample, input_video, debug, luminance_weight=0.8) + color_correction_applied = True elif color_correction == "wavelet_adaptive": debug.log("Applying wavelet with adaptive saturation correction", category="video", force=True, indent_level=1) sample = wavelet_adaptive_color_correction(sample, input_video, debug) + color_correction_applied = True elif color_correction == "wavelet": debug.log("Applying wavelet color reconstruction", category="video", force=True, indent_level=1) sample = wavelet_reconstruction(sample, input_video, debug) + color_correction_applied = True elif color_correction == "hsv": debug.log("Applying HSV hue-conditional saturation matching", category="video", force=True, indent_level=1) sample = hsv_saturation_histogram_match(sample, input_video, debug) + color_correction_applied = True elif color_correction == "adain": debug.log("Applying AdaIN color correction", category="video", force=True, indent_level=1) sample = adaptive_instance_normalization(sample, input_video) + color_correction_applied = True else: debug.log(f"Unknown color correction method: {color_correction}", level="WARNING", category="video", force=True, indent_level=1) + + if phase4_probe_enabled and color_correction_applied and sample.device.type != "cpu": + synchronize_device(sample.device, debug=debug, reason=f"phase4 batch {info_idx+1} after {color_correction}") debug.end_timer(f"color_correction_{color_correction}", f"Color correction ({color_correction})") @@ -1466,6 +1554,8 @@ def postprocess_all_batches( del ctx['all_ori_lengths'] if 'true_target_dims' in ctx: del ctx['true_target_dims'] + if 'padded_target_dims' in ctx: + del ctx['padded_target_dims'] if 'batch_metadata' in ctx: del ctx['batch_metadata'] if 'input_images' in ctx: diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index 9a4cb3c5..e7c69c68 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -29,9 +29,12 @@ import os import torch -from typing import Dict, List, Optional, Tuple, Any, Callable, Union +from typing import Dict, List, Optional, Tuple, Any, Callable, Union, TYPE_CHECKING from torchvision.transforms import Compose, Lambda, Normalize +if TYPE_CHECKING: + from ..utils.debug import Debug + from .model_configuration import configure_runner from .infer import VideoDiffusionInfer from ..data.image.transforms.divisible_crop import DivisiblePad @@ -44,6 +47,56 @@ script_directory = get_script_directory() +def _compute_side_resize_output_dims( + input_height: int, + input_width: int, + resolution: int, + max_resolution: int = 0, +) -> Tuple[int, int]: + """Compute output dims for the target-dimension probe path. + + Matches SideResize with downsample_only=False and the current max_size + second pass, without executing the resize kernel. + """ + short, long = (input_width, input_height) if input_width <= input_height else (input_height, input_width) + + # Match torchvision Resize(int) behavior used by SideResize: + # shortest edge = resolution, longest edge scaled with floor(int(...)) + resized_short = resolution + resized_long = int(resolution * long / short) + + if input_width <= input_height: + resized_w, resized_h = resized_short, resized_long + else: + resized_w, resized_h = resized_long, resized_short + + # Match SideResize's second-pass max_size handling exactly. + if max_resolution > 0 and max(resized_h, resized_w) > max_resolution: + scale = max_resolution / max(resized_h, resized_w) + resized_h = round(resized_h * scale) + resized_w = round(resized_w * scale) + + return resized_h, resized_w + + +def _compute_sample_frame_target_dims( + sample_frame: torch.Tensor, + resolution: int, + max_resolution: int = 0, + debug: Optional['Debug'] = None, +) -> Tuple[int, int]: + """Compute probe target dimensions without running a real resize.""" + input_h, input_w = sample_frame.shape[-2:] + resized_h, resized_w = _compute_side_resize_output_dims( + input_height=input_h, + input_width=input_w, + resolution=resolution, + max_resolution=max_resolution, + ) + + return resized_h, resized_w + + def prepare_video_transforms(resolution: int, max_resolution: int = 0, debug: Optional['Debug'] = None) -> Compose: """ Prepare optimized video transformation pipeline @@ -105,15 +158,42 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: existing_transform = ctx.get('video_transform') if existing_transform is not None: - # Transform exists - check if we need to compute dimensions - if 'true_target_dims' in ctx and sample_frame is not None: - # Return cached dimensions + recompute padded from sample + # Transform exists - return cached dimensions without re-running the pipeline + if 'true_target_dims' in ctx and 'padded_target_dims' in ctx: true_h, true_w = ctx['true_target_dims'] - transformed = existing_transform(sample_frame) - padded_h, padded_w = transformed.shape[-2:] + padded_h, padded_w = ctx['padded_target_dims'] if debug: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") return true_h, true_w, padded_h, padded_w + if sample_frame is not None: + resized_h, resized_w = _compute_sample_frame_target_dims( + sample_frame, + resolution, + max_resolution, + debug, + ) + + # Round to even numbers for video codec compatibility (libx264 requirement) + true_h = (resized_h // 2) * 2 + true_w = (resized_w // 2) * 2 + + # Cache for later use in trimming + ctx['true_target_dims'] = (true_h, true_w) + + # Compute padded dimensions from the resized shape before even-rounding + padded_h = ((resized_h + 15) // 16) * 16 + padded_w = ((resized_w + 15) // 16) * 16 + ctx['padded_target_dims'] = (padded_h, padded_w) + + if debug: + if true_h == padded_h and true_w == padded_w: + debug.log(f"Target dimensions: {true_w}x{true_h} (no padding needed)", + category="setup", indent_level=1) + else: + debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", + category="setup", indent_level=1) + + return true_h, true_w, padded_h, padded_w elif debug: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") return 0, 0, 0, 0 @@ -124,23 +204,24 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: # Compute dimensions if sample frame provided if sample_frame is not None: # Get true target size (after resize, before padding) - temp_transform = Compose([ - NaResize(resolution=resolution, mode="side", downsample_only=False, max_resolution=max_resolution), - Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - ]) - resized_sample = temp_transform(sample_frame) - true_h, true_w = resized_sample.shape[-2:] + resized_h, resized_w = _compute_sample_frame_target_dims( + sample_frame, + resolution, + max_resolution, + debug, + ) # Round to even numbers for video codec compatibility (libx264 requirement) - true_h = (true_h // 2) * 2 - true_w = (true_w // 2) * 2 + true_h = (resized_h // 2) * 2 + true_w = (resized_w // 2) * 2 # Cache for later use in trimming ctx['true_target_dims'] = (true_h, true_w) - - # Get padded dimensions - transformed_sample = ctx['video_transform'](sample_frame) - padded_h, padded_w = transformed_sample.shape[-2:] + + # Compute padded dimensions from the resized shape before even-rounding + padded_h = ((resized_h + 15) // 16) * 16 + padded_w = ((resized_w + 15) // 16) * 16 + ctx['padded_target_dims'] = (padded_h, padded_w) if debug: if true_h == padded_h and true_w == padded_w: @@ -149,8 +230,7 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: else: debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", category="setup", indent_level=1) - - del temp_transform, resized_sample, transformed_sample + return true_h, true_w, padded_h, padded_w return 0, 0, 0, 0 @@ -436,6 +516,9 @@ def prepare_runner( decode_tile_size: Optional[Tuple[int, int]] = None, decode_tile_overlap: Optional[Tuple[int, int]] = None, tile_debug: str = "false", + dit_tiled: bool = False, + dit_tile_size: Optional[Tuple[int, int]] = None, + dit_tile_overlap: Optional[Tuple[int, int]] = None, attention_mode: str = 'sdpa', torch_compile_args_dit: Optional[Dict[str, Any]] = None, torch_compile_args_vae: Optional[Dict[str, Any]] = None @@ -462,6 +545,9 @@ def prepare_runner( decode_tile_size: Tile size for decoding (height, width) decode_tile_overlap: Tile overlap for decoding (height, width) tile_debug: Tile visualization mode (false/encode/decode) + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size (height, width) in latent-space pixels + dit_tile_overlap: Spatial overlap (height, width) between DiT tiles in latent-space pixels attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args_dit: Optional torch.compile configuration for DiT model torch_compile_args_vae: Optional torch.compile configuration for VAE model @@ -506,6 +592,9 @@ def prepare_runner( decode_tile_size=decode_tile_size, decode_tile_overlap=decode_tile_overlap, tile_debug=tile_debug, + dit_tiled=dit_tiled, + dit_tile_size=dit_tile_size, + dit_tile_overlap=dit_tile_overlap, attention_mode=attention_mode, torch_compile_args_dit=torch_compile_args_dit, torch_compile_args_vae=torch_compile_args_vae @@ -824,4 +913,4 @@ def ensure_precision_initialized( debug.log(f"Model precision: {', '.join(parts)}", category="precision") except Exception as e: - debug.log(f"Could not log model dtypes: {e}", level="WARNING", category="precision", force=True) \ No newline at end of file + debug.log(f"Could not log model dtypes: {e}", level="WARNING", category="precision", force=True) diff --git a/src/core/infer.py b/src/core/infer.py index a0869cae..0431b05b 100644 --- a/src/core/infer.py +++ b/src/core/infer.py @@ -39,7 +39,9 @@ def __init__(self, config: DictConfig, debug: 'Debug', encode_tile_overlap: Tuple[int, int] = (64, 64), decode_tiled: bool = False, decode_tile_size: Tuple[int, int] = (512, 512), decode_tile_overlap: Tuple[int, int] = (64, 64), - tile_debug: str = "false"): + tile_debug: str = "false", + dit_tiled: bool = False, dit_tile_size: Tuple[int, int] = (128, 128), + dit_tile_overlap: Tuple[int, int] = (16, 16)): self.config = config self.debug = debug # Store separate encode and decode tiling parameters @@ -50,6 +52,9 @@ def __init__(self, config: DictConfig, debug: 'Debug', self.decode_tile_size = decode_tile_size self.decode_tile_overlap = decode_tile_overlap self.tile_debug = tile_debug + self.dit_tiled = dit_tiled + self.dit_tile_size = dit_tile_size + self.dit_tile_overlap = dit_tile_overlap def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor: t, h, w, c = latent.shape @@ -310,9 +315,68 @@ def get_lin_function(x1, y1, x2, y2): timesteps = timesteps * self.schedule.T return timesteps - - @torch.no_grad() - def inference( + @staticmethod + def _tile_axis_starts(length: int, tile: int, overlap: int) -> List[int]: + if length <= tile: + return [0] + + stride = max(1, tile - overlap) + starts: List[int] = [] + start = 0 + while True: + starts.append(start) + if start + tile >= length: + break + next_start = min(start + stride, length - tile) + if next_start <= start: + break + start = next_start + return starts + + @staticmethod + def _tile_blend_vector( + length: int, + overlap: int, + is_start_edge: bool, + is_end_edge: bool, + device: torch.device, + dtype: torch.dtype, + ) -> Tensor: + weight = torch.ones((length,), device=device, dtype=dtype) + if overlap <= 0 or length <= 1: + return weight + + ramp_extent = min(overlap, length - 1) + if ramp_extent <= 0: + return weight + + ramp = torch.linspace(1.0 / (ramp_extent + 1), 1.0, steps=ramp_extent, device=device, dtype=dtype) + if not is_start_edge: + weight[:ramp_extent] = ramp + if not is_end_edge: + weight[-ramp_extent:] = torch.minimum(weight[-ramp_extent:], torch.flip(ramp, dims=[0])) + return weight + + def _dit_blend_mask( + self, + tile_h: int, + tile_w: int, + y0: int, + y1: int, + x0: int, + x1: int, + full_h: int, + full_w: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tensor: + overlap_h = max(0, min(self.dit_tile_overlap[0], tile_h - 1)) + overlap_w = max(0, min(self.dit_tile_overlap[1], tile_w - 1)) + weight_y = self._tile_blend_vector(tile_h, overlap_h, y0 == 0, y1 >= full_h, device, dtype) + weight_x = self._tile_blend_vector(tile_w, overlap_w, x0 == 0, x1 >= full_w, device, dtype) + return (weight_y[:, None] * weight_x[None, :]).view(1, tile_h, tile_w, 1) + + def _inference_flat( self, noises: List[Tensor], conditions: List[Tensor], @@ -323,15 +387,12 @@ def inference( assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg) batch_size = len(noises) - # Return if empty. if batch_size == 0: return [] - - # Set cfg scale + if cfg_scale is None: cfg_scale = self.config.diffusion.cfg.scale - - # Text embeddings. + assert type(texts_pos[0]) is type(texts_neg[0]) if isinstance(texts_pos[0], str): text_pos_embeds, text_pos_shapes = self.text_encode(texts_pos) @@ -350,11 +411,10 @@ def inference( else: text_pos_embeds, text_pos_shapes = na.flatten(texts_pos) text_neg_embeds, text_neg_shapes = na.flatten(texts_neg) - - # Flatten. + latents, latents_shapes = na.flatten(noises) latents_cond, _ = na.flatten(conditions) - + latents = self.sampler.sample( x=latents, f=lambda args: classifier_free_guidance_dispatcher( @@ -384,12 +444,115 @@ def inference( latents = na.unflatten(latents, latents_shapes) - # Clean up temporary tensors del latents_cond del latents_shapes del text_pos_embeds del text_neg_embeds del text_pos_shapes del text_neg_shapes - - return latents \ No newline at end of file + + return latents + + def _inference_tiled_single( + self, + noise: Tensor, + condition: Tensor, + texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + cfg_scale: Optional[float] = None, + ) -> Tensor: + if noise.ndim != 4 or condition.ndim != 4: + return self._inference_flat([noise], [condition], texts_pos, texts_neg, cfg_scale=cfg_scale)[0] + + _, full_h, full_w, _ = noise.shape + tile_h = max(1, min(self.dit_tile_size[0], full_h)) + tile_w = max(1, min(self.dit_tile_size[1], full_w)) + + if full_h <= tile_h and full_w <= tile_w: + return self._inference_flat([noise], [condition], texts_pos, texts_neg, cfg_scale=cfg_scale)[0] + + overlap_h = max(0, min(self.dit_tile_overlap[0], tile_h - 1)) + overlap_w = max(0, min(self.dit_tile_overlap[1], tile_w - 1)) + y_starts = self._tile_axis_starts(full_h, tile_h, overlap_h) + x_starts = self._tile_axis_starts(full_w, tile_w, overlap_w) + tile_count = len(y_starts) * len(x_starts) + + if self.debug is not None: + self.debug.log( + f"Using DiT tiled inference ({tile_count} tiles, size {tile_h}x{tile_w}, overlap {overlap_h}x{overlap_w})", + category="dit", + force=True, + indent_level=1, + ) + + result = None + weight_sum = None + tile_index = 0 + + for y0 in y_starts: + y1 = min(y0 + tile_h, full_h) + for x0 in x_starts: + x1 = min(x0 + tile_w, full_w) + tile_index += 1 + if self.debug is not None and (tile_index == 1 or tile_index == tile_count or tile_index % 4 == 0): + self.debug.log( + f"DiT tile {tile_index}/{tile_count}: y={y0}:{y1}, x={x0}:{x1}", + category="dit", + indent_level=2, + ) + + noise_tile = noise[:, y0:y1, x0:x1, :] + condition_tile = condition[:, y0:y1, x0:x1, :] + tile_result = self._inference_flat([noise_tile], [condition_tile], texts_pos, texts_neg, cfg_scale=cfg_scale)[0] + + if result is None: + result = torch.zeros_like(noise) + weight_sum = torch.zeros((*noise.shape[:-1], 1), device=noise.device, dtype=tile_result.dtype) + + blend_mask = self._dit_blend_mask( + tile_h=y1 - y0, + tile_w=x1 - x0, + y0=y0, + y1=y1, + x0=x0, + x1=x1, + full_h=full_h, + full_w=full_w, + device=tile_result.device, + dtype=tile_result.dtype, + ) + + result[:, y0:y1, x0:x1, :] += tile_result * blend_mask + weight_sum[:, y0:y1, x0:x1, :] += blend_mask + + del noise_tile, condition_tile, tile_result, blend_mask + + weight_sum = torch.clamp(weight_sum, min=torch.finfo(weight_sum.dtype).eps) + result = result / weight_sum + del weight_sum + return result + + @torch.no_grad() + def inference( + self, + noises: List[Tensor], + conditions: List[Tensor], + texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + cfg_scale: Optional[float] = None, + ) -> List[Tensor]: + if len(noises) == 0: + return [] + + if not self.dit_tiled or len(noises) != 1: + return self._inference_flat(noises, conditions, texts_pos, texts_neg, cfg_scale=cfg_scale) + + return [ + self._inference_tiled_single( + noise=noises[0], + condition=conditions[0], + texts_pos=texts_pos, + texts_neg=texts_neg, + cfg_scale=cfg_scale, + ) + ] diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 2c54ea51..eb11fdfb 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -3,8 +3,18 @@ Enables independent DiT and VAE model sharing across multiple upscaler node instances """ -from typing import Dict, Any, Optional, Tuple -from ..optimization.memory_manager import release_model_memory +import threading +from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING +from ..optimization.memory_manager import ( + is_model_cache_claimed, + iter_model_wrapper_chain, + release_model_memory, + set_model_cache_claimed_state, + set_model_cache_cold_state, +) + +if TYPE_CHECKING: + from ..utils.debug import Debug class GlobalModelCache: @@ -21,6 +31,19 @@ def __init__(self): self._vae_models: Dict[str, Tuple[Any, Dict]] = {} # Storage for runner templates: "dit_id+vae_id" -> runner self._runner_templates: Dict[str, Any] = {} + # Synchronizes DiT/VAE model cache claim/set/replace/remove operations + self._model_cache_lock = threading.RLock() + # Synchronizes runner-template claim/set/remove operations + self._runner_templates_lock = threading.RLock() + + def _models_share_identity(self, cached_model: Any, expected_model: Any) -> bool: + """Return True when two model references point into the same wrapper/base chain.""" + if cached_model is None or expected_model is None: + return False + + cached_ids = {id(model) for model in iter_model_wrapper_chain(cached_model)} + expected_ids = {id(model) for model in iter_model_wrapper_chain(expected_model)} + return bool(cached_ids & expected_ids) def get_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) -> Optional[Any]: """ @@ -37,10 +60,30 @@ def get_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) - return None node_id = dit_config.get('node_id') - if node_id in self._dit_models: - model, stored_config = self._dit_models[node_id] - return model + with self._model_cache_lock: + if node_id in self._dit_models: + model, stored_config = self._dit_models[node_id] + if is_model_cache_claimed(model): + if debug: + debug.log( + f"Cached DiT is already claimed by another execution; skipping reuse (node {node_id})", + category="cache", + force=True, + ) + return None + set_model_cache_claimed_state(model, True) + return model return None + + def peek_dit(self, dit_config: Dict[str, Any]) -> Optional[Any]: + """Return the cached DiT model without claiming it.""" + node_id = dit_config.get('node_id') + if node_id is None: + return None + + with self._model_cache_lock: + entry = self._dit_models.get(node_id) + return None if entry is None else entry[0] def get_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None) -> Optional[Any]: """ @@ -57,10 +100,30 @@ def get_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None) - return None node_id = vae_config.get('node_id') - if node_id in self._vae_models: - model, stored_config = self._vae_models[node_id] - return model + with self._model_cache_lock: + if node_id in self._vae_models: + model, stored_config = self._vae_models[node_id] + if is_model_cache_claimed(model): + if debug: + debug.log( + f"Cached VAE is already claimed by another execution; skipping reuse (node {node_id})", + category="cache", + force=True, + ) + return None + set_model_cache_claimed_state(model, True) + return model return None + + def peek_vae(self, vae_config: Dict[str, Any]) -> Optional[Any]: + """Return the cached VAE model without claiming it.""" + node_id = vae_config.get('node_id') + if node_id is None: + return None + + with self._model_cache_lock: + entry = self._vae_models.get(node_id) + return None if entry is None else entry[0] def get_runner(self, dit_id: Optional[int], vae_id: Optional[int], debug: Optional['Debug'] = None) -> Optional[Any]: @@ -79,9 +142,44 @@ def get_runner(self, dit_id: Optional[int], vae_id: Optional[int], return None runner_key = f"{dit_id}+{vae_id}" - if runner_key in self._runner_templates: - return self._runner_templates[runner_key] - return None + with self._runner_templates_lock: + return self._runner_templates.get(runner_key) + + def claim_runner(self, dit_id: Optional[int], vae_id: Optional[int], + dit_model: str, vae_model: str) -> Tuple[Optional[Any], str]: + """ + Atomically inspect and claim a cached runner template for exclusive reuse. + + Returns: + (template, status) where status is one of: + - "missing": no cached template exists + - "active": template exists but is already in use + - "tainted": template exists but was marked failed/interrupted + - "mismatch": template exists but was built for different DiT/VAE names + - "claimed": template was successfully claimed for reuse + """ + if dit_id is None or vae_id is None: + return None, "missing" + + runner_key = f"{dit_id}+{vae_id}" + with self._runner_templates_lock: + template = self._runner_templates.get(runner_key) + if template is None: + return None, "missing" + + if getattr(template, '_seedvr2_execution_active', False): + return template, "active" + + if getattr(template, '_seedvr2_runner_tainted', False): + return template, "tainted" + + current_dit = getattr(template, '_dit_model_name', None) + current_vae = getattr(template, '_vae_model_name', None) + if current_dit != dit_model or current_vae != vae_model: + return template, "mismatch" + + template._seedvr2_execution_active = True + return template, "claimed" def set_dit(self, dit_config: Dict[str, Any], model: Any, model_name: str, debug: Optional['Debug'] = None) -> Optional[str]: """ @@ -100,7 +198,22 @@ def set_dit(self, dit_config: Dict[str, Any], model: Any, model_name: str, debug return None node_id = dit_config.get('node_id') - self._dit_models[node_id] = (model, dit_config) + with self._model_cache_lock: + existing = self._dit_models.get(node_id) + if existing is not None: + existing_model, _ = existing + if not self._models_share_identity(existing_model, model) and is_model_cache_claimed(existing_model): + if debug: + debug.log( + f"Skipped caching DiT model for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return None + set_model_cache_cold_state(model, False) + set_model_cache_claimed_state(model, True) + self._dit_models[node_id] = (model, dit_config) if debug: debug.log(f"DiT model cached in memory (node {node_id}): {model_name}", @@ -125,13 +238,86 @@ def set_vae(self, vae_config: Dict[str, Any], model: Any, model_name: str, debug return None node_id = vae_config.get('node_id') - self._vae_models[node_id] = (model, vae_config) + with self._model_cache_lock: + existing = self._vae_models.get(node_id) + if existing is not None: + existing_model, _ = existing + if not self._models_share_identity(existing_model, model) and is_model_cache_claimed(existing_model): + if debug: + debug.log( + f"Skipped caching VAE model for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return None + set_model_cache_cold_state(model, False) + set_model_cache_claimed_state(model, True) + self._vae_models[node_id] = (model, vae_config) if debug: debug.log(f"VAE model cached in memory (node {node_id}): {model_name}", category="cache", force=True) return node_id + + def replace_dit( + self, + dit_config: Dict[str, Any], + model: Any, + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: + """Rewrite a cached DiT entry to the latest claimed model object.""" + node_id = dit_config.get('node_id') + with self._model_cache_lock: + if node_id not in self._dit_models: + return False + + cached_model, stored_config = self._dit_models[node_id] + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached DiT rewrite for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + + self._dit_models[node_id] = (model, stored_config) + if debug: + debug.log(f"Refreshed cached DiT entry to the latest claimed model object (node {node_id})", category="cache", force=True) + return True + + def replace_vae( + self, + vae_config: Dict[str, Any], + model: Any, + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: + """Rewrite a cached VAE entry to the latest claimed model object.""" + node_id = vae_config.get('node_id') + with self._model_cache_lock: + if node_id not in self._vae_models: + return False + + cached_model, stored_config = self._vae_models[node_id] + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached VAE rewrite for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + + self._vae_models[node_id] = (model, stored_config) + if debug: + debug.log(f"Refreshed cached VAE entry to the latest claimed model object (node {node_id})", category="cache", force=True) + return True def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], runner: Any, debug: Optional['Debug'] = None) -> Optional[str]: @@ -145,22 +331,103 @@ def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], debug: Optional debug instance for logging Returns: - Runner key string (format: "dit_id+vae_id") if cached successfully, - None if either ID is None or runner already cached + Runner key string (format: "dit_id+vae_id") if this call cached or + replaced the template, None if either ID is None or an existing + non-tainted template is intentionally kept. """ if dit_id is None or vae_id is None: return None runner_key = f"{dit_id}+{vae_id}" - if runner_key not in self._runner_templates: - self._runner_templates[runner_key] = runner - if debug: - debug.log(f"Runner template cached in memory: nodes {runner_key}", category="cache", force=True) - return runner_key + with self._runner_templates_lock: + existing = self._runner_templates.get(runner_key) + if existing is runner: + return runner_key + + replace_existing = False + if existing is not None: + replace_existing = getattr(existing, '_seedvr2_runner_tainted', False) + + if existing is None or replace_existing: + self._runner_templates[runner_key] = runner + if debug: + action = "replaced" if replace_existing else "cached" + debug.log(f"Runner template {action} in memory: nodes {runner_key}", category="cache", force=True) + return runner_key return None + + def remove_runner(self, dit_id: Optional[int], vae_id: Optional[int], + debug: Optional['Debug'] = None, + expected_runner: Optional[Any] = None) -> bool: + """Remove a cached runner template for the given DiT/VAE node pair. + + If expected_runner is provided, only remove the cache entry when the + currently stored runner is that exact object. + """ + if dit_id is None or vae_id is None: + return False + + with self._runner_templates_lock: + runner_key = f"{dit_id}+{vae_id}" + cached_runner = self._runner_templates.get(runner_key) + if cached_runner is None: + return False + + if expected_runner is not None and cached_runner is not expected_runner: + if debug: + debug.log( + f"Skipped cached runner removal for nodes {runner_key}: cache entry no longer matches expected runner", + level="WARNING", + category="cache", + force=True, + ) + return False + + del self._runner_templates[runner_key] + if debug: + debug.log(f"Removed cached runner template: nodes {runner_key}", category="cache", force=True) + return True + + def taint_and_remove_runner(self, + dit_id: Optional[int], + vae_id: Optional[int], + debug: Optional['Debug'] = None, + expected_runner: Optional[Any] = None) -> bool: + """Atomically mark a cached runner template tainted/inactive and remove it.""" + if dit_id is None or vae_id is None: + return False + + runner_key = f"{dit_id}+{vae_id}" + with self._runner_templates_lock: + cached_runner = self._runner_templates.get(runner_key) + if cached_runner is None: + return False + + if expected_runner is not None and cached_runner is not expected_runner: + if debug: + debug.log( + f"Skipped taint+remove for cached runner nodes {runner_key}: cache entry no longer matches expected runner", + level="WARNING", + category="cache", + force=True, + ) + return False + + cached_runner._seedvr2_runner_tainted = True + cached_runner._seedvr2_execution_active = False + del self._runner_templates[runner_key] + + if debug: + debug.log(f"Tainted and removed cached runner template: nodes {runner_key}", category="cache", force=True) + return True - def remove_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) -> bool: + def remove_dit( + self, + dit_config: Dict[str, Any], + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: """ Remove DiT model from cache if it exists. @@ -175,27 +442,52 @@ def remove_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None Also removes any runner templates that used this DiT model """ node_id = dit_config.get('node_id') - if node_id in self._dit_models: + with self._model_cache_lock: + if node_id not in self._dit_models: + return False + + cached_model, stored_config = self._dit_models[node_id] + if expected_model is None and is_model_cache_claimed(cached_model): + if debug: + debug.log( + f"Skipped cached DiT removal for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return False + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached DiT removal for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + if debug: debug.log(f"Removing cached DiT: {node_id}", category="cache", force=True) - model, stored_config = self._dit_models[node_id] - - # Release model memory - if model is not None: - release_model_memory(model=model, debug=debug) - + model = cached_model del self._dit_models[node_id] - - # Remove any runner templates that used this DiT + + if model is not None: + release_model_memory(model=model, debug=debug) + + with self._runner_templates_lock: templates_to_remove = [k for k in self._runner_templates.keys() if k.startswith(str(node_id) + "+")] for template_key in templates_to_remove: - del self._runner_templates[template_key] - - return True - return False + del self._runner_templates[template_key] + + return True - def remove_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None) -> bool: + def remove_vae( + self, + vae_config: Dict[str, Any], + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: """ Remove VAE model from cache if it exists. @@ -210,25 +502,45 @@ def remove_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None Also removes any runner templates that used this VAE model """ node_id = vae_config.get('node_id') - if node_id in self._vae_models: + with self._model_cache_lock: + if node_id not in self._vae_models: + return False + + cached_model, stored_config = self._vae_models[node_id] + if expected_model is None and is_model_cache_claimed(cached_model): + if debug: + debug.log( + f"Skipped cached VAE removal for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return False + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached VAE removal for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + if debug: debug.log(f"Removing cached VAE: {node_id}", category="cache", force=True) - model, stored_config = self._vae_models[node_id] - - # Release model memory directly - if model is not None: - release_model_memory(model=model, debug=debug) - + model = cached_model del self._vae_models[node_id] - - # Remove any runner templates that used this VAE + + if model is not None: + release_model_memory(model=model, debug=debug) + + with self._runner_templates_lock: templates_to_remove = [k for k in self._runner_templates.keys() if k.endswith("+" + str(node_id))] for template_key in templates_to_remove: del self._runner_templates[template_key] - - return True - return False + + return True # Global singleton instance @@ -236,4 +548,4 @@ def remove_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None def get_global_cache() -> GlobalModelCache: """Get the global model cache instance.""" - return _global_cache \ No newline at end of file + return _global_cache diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 61297627..28c2c9dc 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -75,7 +75,13 @@ validate_attention_mode ) from ..optimization.blockswap import is_blockswap_enabled, validate_blockswap_config, apply_block_swap_to_dit, cleanup_blockswap -from ..optimization.memory_manager import cleanup_dit, cleanup_vae +from ..optimization.memory_manager import ( + cleanup_dit, + cleanup_vae, + is_model_cache_claimed, + set_model_cache_claimed_state, + set_model_cache_cold_state, +) from ..utils.constants import find_model_file @@ -320,7 +326,7 @@ def _update_model_config( elif config_name == 'attention_mode': display_name = 'Attention Mode' - config_changes.append(f"{display_name}: {old_desc} → {new_desc}") + config_changes.append(f"{display_name}: {old_desc} -> {new_desc}") # If nothing changed, reuse model as-is if not any(changes_detected.values()): @@ -592,41 +598,89 @@ def _initialize_cache_context( # Check for cached DiT model with model name validation # Model name validation prevents stale cache when user switches models in UI if dit_cache and dit_model and dit_id is not None: - cached_model = global_cache.get_dit({'node_id': dit_id, 'cache_model': True}, debug) - if cached_model: + cached_model = global_cache.peek_dit({'node_id': dit_id}) + if cached_model is not None: + cached_claimed = is_model_cache_claimed(cached_model) # Verify cached model matches requested model by checking _model_name attribute cached_model_name = getattr(cached_model, '_model_name', None) if cached_model_name == dit_model: # Cache hit with valid model - reuse it - context['cached_dit'] = cached_model + claimed_model = global_cache.get_dit({'node_id': dit_id, 'cache_model': True}, debug) + if claimed_model is not None: + claimed_model_name = getattr(claimed_model, '_model_name', None) + if claimed_model_name == dit_model: + context['cached_dit'] = claimed_model + else: + if claimed_model_name: + debug.log( + f"Claimed DiT no longer matches requested model ({claimed_model_name} -> {dit_model}), " + f"evicting claimed cache entry", + category="cache", + force=True, + ) + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_model) else: # Model changed - remove stale cache and log the change if cached_model_name: - debug.log(f"DiT model changed in cache ({cached_model_name} → {dit_model}), " + debug.log(f"DiT model changed in cache ({cached_model_name} -> {dit_model}), " f"removing stale cached model", category="cache", force=True) - global_cache.remove_dit({'node_id': dit_id}, debug) + if cached_claimed: + debug.log( + f"Cached DiT for node {dit_id} is stale but currently claimed; leaving it in cache until the owning execution releases it", + level="WARNING", + category="cache", + force=True, + ) + else: + global_cache.remove_dit({'node_id': dit_id}, debug) else: # Caching disabled or no ID - clean up any existing cache for this node if dit_id is not None: - global_cache.remove_dit({'node_id': dit_id}, debug) + cached_model = global_cache.peek_dit({'node_id': dit_id}) + if cached_model is not None and not is_model_cache_claimed(cached_model): + global_cache.remove_dit({'node_id': dit_id}, debug) # Check for cached VAE model with model name validation if vae_cache and vae_model and vae_id is not None: - cached_model = global_cache.get_vae({'node_id': vae_id, 'cache_model': True}, debug) - if cached_model: + cached_model = global_cache.peek_vae({'node_id': vae_id}) + if cached_model is not None: + cached_claimed = is_model_cache_claimed(cached_model) # Verify cached model matches requested model by checking _model_name attribute cached_model_name = getattr(cached_model, '_model_name', None) if cached_model_name == vae_model: - context['cached_vae'] = cached_model + claimed_model = global_cache.get_vae({'node_id': vae_id, 'cache_model': True}, debug) + if claimed_model is not None: + claimed_model_name = getattr(claimed_model, '_model_name', None) + if claimed_model_name == vae_model: + context['cached_vae'] = claimed_model + else: + if claimed_model_name: + debug.log( + f"Claimed VAE no longer matches requested model ({claimed_model_name} -> {vae_model}), " + f"evicting claimed cache entry", + category="cache", + force=True, + ) + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_model) else: # Model changed - remove stale cache and log the change if cached_model_name: - debug.log(f"VAE model changed in cache ({cached_model_name} → {vae_model}), " + debug.log(f"VAE model changed in cache ({cached_model_name} -> {vae_model}), " f"removing stale cached model", category="cache", force=True) - global_cache.remove_vae({'node_id': vae_id}, debug) + if cached_claimed: + debug.log( + f"Cached VAE for node {vae_id} is stale but currently claimed; leaving it in cache until the owning execution releases it", + level="WARNING", + category="cache", + force=True, + ) + else: + global_cache.remove_vae({'node_id': vae_id}, debug) else: if vae_id is not None: - global_cache.remove_vae({'node_id': vae_id}, debug) + cached_model = global_cache.peek_vae({'node_id': vae_id}) + if cached_model is not None and not is_model_cache_claimed(cached_model): + global_cache.remove_vae({'node_id': vae_id}, debug) return context @@ -658,26 +712,82 @@ def _acquire_runner( Returns: VideoDiffusionInfer: Runner instance (cached template or newly created) """ - # Try to get runner template from global cache - template = cache_context['global_cache'].get_runner( - cache_context['dit_id'], cache_context['vae_id'], debug + # Try to atomically claim a reusable runner template from global cache + template, template_status = cache_context['global_cache'].claim_runner( + cache_context['dit_id'], + cache_context['vae_id'], + dit_model, + vae_model, ) if template: - # We have a template - check if we can use it + runner_key = f"{cache_context['dit_id']}+{cache_context['vae_id']}" + + if template_status == "active": + debug.log( + f"Cached runner template still marked active: nodes {runner_key}; creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) + + if template_status == "tainted": + debug.log( + f"Cached runner template was tainted by a prior failed/interrupted run: nodes {runner_key}; creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + cache_context['global_cache'].remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + expected_runner=template, + ) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) + + if template_status == "claimed": + need_dit = bool(cache_context.get('dit_cache') and cache_context.get('dit_id') is not None) + need_vae = bool(cache_context.get('vae_cache') and cache_context.get('vae_id') is not None) + have_dit = (not need_dit) or (cache_context.get('cached_dit') is not None) + have_vae = (not need_vae) or (cache_context.get('cached_vae') is not None) + + if have_dit and have_vae: + debug.log(f"Reusing cached runner template: nodes {runner_key}", category="reuse", force=True) + cache_context['reusing_runner'] = True + return template + + debug.log( + "Runner template matched, but required claimed cached models were not acquired; creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + cache_context['global_cache'].taint_and_remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + expected_runner=template, + ) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) + current_dit = getattr(template, '_dit_model_name', None) current_vae = getattr(template, '_vae_model_name', None) - models_match = (current_dit == dit_model and current_vae == vae_model) - - if models_match: - # Perfect match - reuse template directly - runner_key = f"{cache_context['dit_id']}+{cache_context['vae_id']}" - debug.log(f"Reusing cached runner template: nodes {runner_key}", category="reuse", force=True) - cache_context['reusing_runner'] = True - return template - else: - # Template exists but models changed and no cached models - create new - return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) + debug.log( + f"Cached runner template models no longer match: nodes {runner_key} " + f"({current_dit}/{current_vae} -> {dit_model}/{vae_model}); creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + cache_context['global_cache'].remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + expected_runner=template, + ) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) else: # No template - create new runner return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) @@ -698,8 +808,8 @@ def _create_new_runner( Args: dit_model: DiT model filename (determines config selection) - - Contains "7b" → loads configs_7b/main.yaml - - Otherwise → loads configs_3b/main.yaml + - Contains "7b" -> loads configs_7b/main.yaml + - Otherwise -> loads configs_3b/main.yaml vae_model: VAE model filename (stored for reference, not used in config selection) base_cache_dir: Base directory for model files (not used directly but passed for context) debug: Debug instance for logging and timing @@ -747,6 +857,9 @@ def configure_runner( decode_tile_size: Optional[Tuple[int, int]] = None, decode_tile_overlap: Optional[Tuple[int, int]] = None, tile_debug: str = "false", + dit_tiled: bool = False, + dit_tile_size: Optional[Tuple[int, int]] = None, + dit_tile_overlap: Optional[Tuple[int, int]] = None, attention_mode: str = 'sdpa', torch_compile_args_dit: Optional[Dict[str, Any]] = None, torch_compile_args_vae: Optional[Dict[str, Any]] = None @@ -774,6 +887,9 @@ def configure_runner( decode_tile_size: Tile size for decoding (height, width) decode_tile_overlap: Tile overlap for decoding (height, width) tile_debug: Tile visualization mode (false/encode/decode) + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size (height, width) in latent-space pixels + dit_tile_overlap: Spatial overlap (height, width) between DiT tiles in latent-space pixels attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args_dit: Optional torch.compile configuration for DiT model torch_compile_args_vae: Optional torch.compile configuration for VAE model @@ -810,33 +926,139 @@ def configure_runner( ) # Phase 2: Get or create runner + runner = None runner = _acquire_runner( cache_context, dit_model, vae_model, base_cache_dir, debug ) - # Phase 3: Configure runner settings - _configure_runner_settings( - runner, ctx, - encode_tiled, encode_tile_size, encode_tile_overlap, - decode_tiled, decode_tile_size, decode_tile_overlap, - tile_debug, attention_mode, - torch_compile_args_dit, torch_compile_args_vae, - block_swap_config, debug - ) - - # Phase 4: Setup models (load from cache or create new) - _setup_models( - runner, cache_context, dit_model, vae_model, - base_cache_dir, block_swap_config, debug - ) + try: + # Phase 3: Configure runner settings + _configure_runner_settings( + runner, ctx, + cache_context.get('dit_id') if dit_cache else None, + cache_context.get('vae_id') if vae_cache else None, + encode_tiled, encode_tile_size, encode_tile_overlap, + decode_tiled, decode_tile_size, decode_tile_overlap, + tile_debug, dit_tiled, dit_tile_size, dit_tile_overlap, attention_mode, + torch_compile_args_dit, torch_compile_args_vae, + block_swap_config, debug + ) + + # Phase 4: Setup models (load from cache or create new) + _setup_models( + runner, cache_context, dit_model, vae_model, + base_cache_dir, block_swap_config, debug + ) + except BaseException: + _evict_claimed_cached_models(cache_context, runner, debug) + if runner is not None and cache_context.get('reusing_runner', False): + try: + cache_context['global_cache'].taint_and_remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + expected_runner=runner, + ) + except Exception as cache_error: + if debug is not None: + debug.log( + f"Failed to evict claimed runner after setup failure: {cache_error}", + level="WARNING", + category="cleanup", + force=True, + ) + raise return runner, cache_context +def _evict_claimed_cached_models( + cache_context: Dict[str, Any], + runner: Optional[VideoDiffusionInfer], + debug: Optional['Debug'] = None, +) -> None: + """ + Evict claimed cached models after activation/setup failure. + + Claimed cached models may be partially materialized or partially reconfigured + when an exception interrupts setup. In that case they must be removed from the + global cache rather than merely unclaimed. + """ + if not cache_context: + return + + global_cache = cache_context.get('global_cache') + if global_cache is None: + return + + claimed_dit = cache_context.get('cached_dit') + claimed_vae = cache_context.get('cached_vae') + + dit_id = cache_context.get('dit_id') + if cache_context.get('dit_cache') and dit_id is not None and claimed_dit is not None: + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_dit) + + vae_id = cache_context.get('vae_id') + if cache_context.get('vae_cache') and vae_id is not None and claimed_vae is not None: + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) + + +def _finalize_claimed_cached_models_for_reuse( + cache_context: Dict[str, Any], + runner: Optional[VideoDiffusionInfer], + debug: Optional['Debug'] = None, +) -> Tuple[Optional[Any], Optional[Any]]: + """Refresh or evict claimed cache entries using the released runner-held model refs.""" + refreshed_dit = None + refreshed_vae = None + + if not cache_context or runner is None: + return refreshed_dit, refreshed_vae + + global_cache = cache_context.get('global_cache') + if global_cache is None: + return refreshed_dit, refreshed_vae + + claimed_dit = cache_context.get('cached_dit') + claimed_vae = cache_context.get('cached_vae') + + dit_id = cache_context.get('dit_id') + if cache_context.get('dit_cache') and dit_id is not None and claimed_dit is not None: + released_dit = getattr(runner, 'dit', None) + if released_dit is not None: + if global_cache.replace_dit({'node_id': dit_id}, released_dit, debug, expected_model=claimed_dit): + refreshed_dit = released_dit + runner.dit = released_dit + else: + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_dit) + runner.dit = None + else: + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_dit) + runner.dit = None + + vae_id = cache_context.get('vae_id') + if cache_context.get('vae_cache') and vae_id is not None and claimed_vae is not None: + released_vae = getattr(runner, 'vae', None) + if released_vae is not None: + if global_cache.replace_vae({'node_id': vae_id}, released_vae, debug, expected_model=claimed_vae): + refreshed_vae = released_vae + runner.vae = released_vae + else: + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) + runner.vae = None + else: + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) + runner.vae = None + + return refreshed_dit, refreshed_vae + + def _configure_runner_settings( runner: VideoDiffusionInfer, ctx: Dict[str, Any], + dit_cache_node_id: Optional[int], + vae_cache_node_id: Optional[int], encode_tiled: bool, encode_tile_size: Optional[Tuple[int, int]], encode_tile_overlap: Optional[Tuple[int, int]], @@ -844,6 +1066,9 @@ def _configure_runner_settings( decode_tile_size: Optional[Tuple[int, int]], decode_tile_overlap: Optional[Tuple[int, int]], tile_debug: str, + dit_tiled: bool, + dit_tile_size: Optional[Tuple[int, int]], + dit_tile_overlap: Optional[Tuple[int, int]], attention_mode: str, torch_compile_args_dit: Optional[Dict[str, Any]], torch_compile_args_vae: Optional[Dict[str, Any]], @@ -868,6 +1093,9 @@ def _configure_runner_settings( decode_tile_size: Tile dimensions (height, width) for decoding in pixels decode_tile_overlap: Overlap dimensions (height, width) between decoding tiles tile_debug: Tile visualization mode (false/encode/decode) + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size (height, width) in latent-space pixels + dit_tile_overlap: Spatial overlap (height, width) between DiT tiles in latent-space pixels attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args_dit: torch.compile configuration for DiT model or None torch_compile_args_vae: torch.compile configuration for VAE model or None @@ -882,6 +1110,9 @@ def _configure_runner_settings( runner.decode_tile_size = decode_tile_size runner.decode_tile_overlap = decode_tile_overlap runner.tile_debug = tile_debug + runner.dit_tiled = dit_tiled + runner.dit_tile_size = dit_tile_size + runner.dit_tile_overlap = dit_tile_overlap # Store the new configs temporarily for later comparison # Don't set them as attributes yet - let the update functions handle that @@ -905,6 +1136,8 @@ def _configure_runner_settings( runner._vae_offload_device = ctx['vae_offload_device'] runner._tensor_offload_device = ctx['tensor_offload_device'] runner._compute_dtype = ctx['compute_dtype'] + runner._dit_cache_node_id = dit_cache_node_id + runner._vae_cache_node_id = vae_cache_node_id runner.debug = debug @@ -1021,7 +1254,7 @@ def _setup_dit_model( current_dit_name = getattr(runner, '_dit_model_name', None) if current_dit_name and current_dit_name != dit_model: if hasattr(runner, 'dit') and runner.dit is not None: - debug.log(f"DiT model changed ({current_dit_name} → {dit_model}), cleaning old model", + debug.log(f"DiT model changed ({current_dit_name} -> {dit_model}), cleaning old model", category="cache", force=True) cleanup_dit(runner=runner, debug=debug, cache_model=False) @@ -1094,7 +1327,7 @@ def _setup_vae_model( current_vae_name = getattr(runner, '_vae_model_name', None) if current_vae_name and current_vae_name != vae_model: if hasattr(runner, 'vae') and runner.vae is not None: - debug.log(f"VAE model changed ({current_vae_name} → {vae_model}), cleaning old model", + debug.log(f"VAE model changed ({current_vae_name} -> {vae_model}), cleaning old model", category="cache", force=True) cleanup_vae(runner=runner, debug=debug, cache_model=False) @@ -1278,7 +1511,9 @@ def apply_model_specific_config(model: torch.nn.Module, runner: VideoDiffusionIn # Clear the config application flag after successful application if hasattr(runner, '_vae_config_needs_application'): runner._vae_config_needs_application = False - + + set_model_cache_cold_state(model, False) + set_model_cache_claimed_state(model, True) return model @@ -1476,5 +1711,4 @@ def _propagate_debug_to_modules(module: torch.nn.Module, debug: 'Debug') -> None for name, submodule in module.named_modules(): if submodule.__class__.__name__ in target_modules: - if not hasattr(submodule, 'debug'): # Only set if not already present - submodule.debug = debug \ No newline at end of file + submodule.debug = debug diff --git a/src/interfaces/dit_model_loader.py b/src/interfaces/dit_model_loader.py index 9d8204e8..2ca5a441 100644 --- a/src/interfaces/dit_model_loader.py +++ b/src/interfaces/dit_model_loader.py @@ -5,7 +5,7 @@ from comfy_api.latest import io from comfy_execution.utils import get_executing_context -from typing import Dict, Any, Tuple +from typing import Dict, Any from ..utils.model_registry import get_available_dit_models, DEFAULT_DIT from ..optimization.memory_manager import get_device_list @@ -124,6 +124,39 @@ def define_schema(cls) -> io.Schema: "Provides 20-40% speedup with compatible PyTorch 2.0+ and Triton installation." ) ), + io.Boolean.Input("dit_tiled", + default=False, + optional=True, + tooltip=( + "Enable spatial tiling for the DiT upscaling phase.\n" + "Reduces peak VRAM during final SeedVR2 diffusion inference by processing latent tiles with overlap blending.\n" + "Slower than full-frame DiT inference, but can prevent VRAM overflow on large crops." + ) + ), + io.Int.Input("dit_tile_size", + default=128, + min=32, + max=2048, + step=8, + optional=True, + tooltip=( + "Spatial tile size for DiT inference in latent-space pixels (default: 128).\n" + "Smaller tiles reduce VRAM further but increase runtime and may reduce global consistency.\n" + "Only used when dit_tiled is enabled." + ) + ), + io.Int.Input("dit_tile_overlap", + default=16, + min=0, + max=512, + step=1, + optional=True, + tooltip=( + "Overlap between DiT latent tiles in pixels (default: 16).\n" + "Higher overlap reduces visible seams but increases compute.\n" + "Only used when dit_tiled is enabled." + ) + ), ], outputs=[ io.Custom("SEEDVR2_DIT").Output( @@ -136,7 +169,8 @@ def define_schema(cls) -> io.Schema: def execute(cls, model: str, device: str, offload_device: str = "none", cache_model: bool = False, blocks_to_swap: int = 0, swap_io_components: bool = False, attention_mode: str = "sdpa", - torch_compile_args: Dict[str, Any] = None) -> io.NodeOutput: + torch_compile_args: Dict[str, Any] = None, dit_tiled: bool = False, + dit_tile_size: int = 128, dit_tile_overlap: int = 16) -> io.NodeOutput: """ Create DiT model configuration for SeedVR2 main node @@ -149,6 +183,9 @@ def execute(cls, model: str, device: str, offload_device: str = "none", swap_io_components: Whether to offload I/O components (requires offload_device != device) attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args: Optional torch.compile configuration from settings node + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size in latent-space pixels + dit_tile_overlap: Spatial overlap between DiT tiles in latent-space pixels Returns: NodeOutput containing configuration dictionary for SeedVR2 main node @@ -174,7 +211,10 @@ def execute(cls, model: str, device: str, offload_device: str = "none", "swap_io_components": swap_io_components, "attention_mode": attention_mode, "torch_compile_args": torch_compile_args, + "dit_tiled": dit_tiled, + "dit_tile_size": dit_tile_size, + "dit_tile_overlap": dit_tile_overlap, "node_id": get_executing_context().node_id, } - return io.NodeOutput(config) \ No newline at end of file + return io.NodeOutput(config) diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 159ca2dc..8592dcc4 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -23,10 +23,15 @@ load_text_embeddings, script_directory ) +from ..core.model_configuration import ( + _evict_claimed_cached_models, + _finalize_claimed_cached_models_for_reuse, +) from ..optimization.memory_manager import ( cleanup_text_embeddings, complete_cleanup, - get_device_list + get_device_list, + set_model_cache_claimed_state, ) # Import ComfyUI progress reporting @@ -267,6 +272,7 @@ def execute(cls, image: torch.Tensor, dit: Dict[str, Any], vae: Dict[str, Any], # Track execution state in local variables (not instance) runner = None ctx = None + cache_context = None pbar = None # Define progress callback as local closure @@ -319,8 +325,42 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # Use complete_cleanup for all cleanup operations if runner: - complete_cleanup(runner=runner, debug=debug, - dit_cache=dit_cache, vae_cache=vae_cache) + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None + refreshed_dit = None + refreshed_vae = None + try: + try: + complete_cleanup( + runner=runner, + debug=debug, + dit_cache=dit_cache, + vae_cache=vae_cache, + ) + if dit_cache or vae_cache: + refreshed_dit, refreshed_vae = _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) + except Exception: + try: + _evict_claimed_cached_models(cache_context, runner, debug) + except Exception as evict_error: + if debug is not None: + debug.log( + f"Failed to evict claimed cached models while handling prior cleanup/finalize exception: {evict_error}", + level="WARNING", + category="cleanup", + force=True, + ) + raise + finally: + if dit_cache and claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if dit_cache and refreshed_dit is not None and refreshed_dit is not claimed_dit: + set_model_cache_claimed_state(refreshed_dit, False) + if vae_cache and claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) + if vae_cache and refreshed_vae is not None and refreshed_vae is not claimed_vae: + set_model_cache_claimed_state(refreshed_vae, False) + runner._seedvr2_execution_active = False # Delete runner only if neither model is cached if not (dit_cache or vae_cache): @@ -375,6 +415,9 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # TorchCompile args (optional connection, can be None) dit_torch_compile_args = dit.get("torch_compile_args") + dit_tiled = dit.get("dit_tiled", False) + dit_tile_size = max(1, int(dit.get("dit_tile_size", 128))) + dit_tile_overlap = max(0, int(dit.get("dit_tile_overlap", 16))) vae_torch_compile_args = vae.get("torch_compile_args") # Print header @@ -431,11 +474,34 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: decode_tile_size=(decode_tile_size, decode_tile_size), decode_tile_overlap=(decode_tile_overlap, decode_tile_overlap), tile_debug=tile_debug, + dit_tiled=dit_tiled, + dit_tile_size=(dit_tile_size, dit_tile_size), + dit_tile_overlap=(dit_tile_overlap, dit_tile_overlap), attention_mode=attention_mode, torch_compile_args_dit=dit_torch_compile_args, torch_compile_args_vae=vae_torch_compile_args ) + runner._seedvr2_execution_active = True + runner._seedvr2_runner_tainted = False + runner._seedvr2_dit_phase_cleaned = False + runner._seedvr2_vae_phase_cleaned = False + + # If both models were already cached but the runner template had been + # invalidated or missing, cache this freshly configured runner now. + if ( + cache_context is not None + and not cache_context.get('reusing_runner', False) + and cache_context.get('cached_dit') is not None + and cache_context.get('cached_vae') is not None + ): + cache_context['global_cache'].set_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + runner, + debug, + ) + # Store cache context in ctx for use in generation phases ctx['cache_context'] = cache_context @@ -444,6 +510,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: debug.log("Loaded text embeddings for DiT", category="dit") debug.log_memory_state("After model preparation", show_tensors=False, detailed_tensors=False) + debug.end_timer("model_preparation", "Model preparation", force=True, show_breakdown=True) # Compute generation info and log start (handles prepending internally) @@ -568,6 +635,9 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # Print footer debug.print_footer() + if runner is not None: + runner._seedvr2_runner_tainted = False + debug.clear_history() pbar = None ctx = None @@ -575,6 +645,42 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # V3-compatible return with optional UI preview return io.NodeOutput(sample) - except Exception as e: - cleanup(dit_cache=dit_cache, vae_cache=vae_cache) - raise e \ No newline at end of file + except BaseException: + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None + if cache_context is not None: + _evict_claimed_cached_models(cache_context, runner, debug) + if runner is not None and cache_context.get('reusing_runner', False): + try: + cache_context['global_cache'].taint_and_remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + expected_runner=runner, + ) + except Exception as cache_error: + if debug is not None: + debug.log( + f"Failed to evict cached runner while handling prior exception: {cache_error}", + level="WARNING", + category="cleanup", + force=True, + ) + + try: + try: + cleanup(dit_cache=False, vae_cache=False) + except BaseException as cleanup_error: + if debug is not None: + debug.log( + f"Cleanup failed while handling prior exception: {cleanup_error}", + level="WARNING", + category="cleanup", + force=True, + ) + finally: + if claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) + raise diff --git a/src/optimization/compatibility.py b/src/optimization/compatibility.py index c462022b..147b8b55 100644 --- a/src/optimization/compatibility.py +++ b/src/optimization/compatibility.py @@ -681,21 +681,86 @@ def _check_conv3d_memory_bug(): # Bfloat16 CUBLAS support +# +# Original behavior performed a CUDA BF16 matmul probe at import time and +# re-raised most CUDA errors. Under WSL / Blackwell / newer PyTorch builds this +# can fail with ``CUDA driver error: unknown error`` while ComfyUI is merely +# importing custom nodes, which prevents the whole SeedVR2 node pack from +# loading. Import-time CUDA probes must be best-effort only. +_SEEDVR2_BFLOAT16_PATCH = "wsl-safe-import-probe-2026-06-21" + + +def _env_flag(name: str, default: bool = False) -> bool: + value = os.environ.get(name) + if value is None: + return default + return value.strip().lower() in ("1", "true", "yes", "y", "on") + + +def _env_flag_set(name: str): + value = os.environ.get(name) + if value is None: + return None + value = value.strip().lower() + if value in ("1", "true", "yes", "y", "on"): + return True + if value in ("0", "false", "no", "n", "off"): + return False + return None + + def _probe_bfloat16_support() -> bool: - if not torch.cuda.is_available(): + """ + Import-safe BF16 capability selection. + + Defaults to float16 without touching CUDA at import time. This keeps the + custom node importable even if the CUDA context/driver is temporarily in a + bad state during ComfyUI startup. + + Environment controls: + SEEDVR2_FORCE_BFLOAT16=1 -> force BF16 on, no probe + SEEDVR2_FORCE_BFLOAT16=0 -> force BF16 off, no probe + SEEDVR2_IMPORT_BFLOAT16_PROBE=1 -> run best-effort CUDA probe at import + """ + forced = _env_flag_set("SEEDVR2_FORCE_BFLOAT16") + if forced is True: + print("[SeedVR2] BF16 forced on via SEEDVR2_FORCE_BFLOAT16=1", flush=True) return True + if forced is False: + print("[SeedVR2] BF16 forced off via SEEDVR2_FORCE_BFLOAT16=0; using float16", flush=True) + return False + + # Safer default: do not allocate CUDA tensors while ComfyUI is importing + # custom nodes. Users who want automatic probing can opt in explicitly. + if not _env_flag("SEEDVR2_IMPORT_BFLOAT16_PROBE", default=False): + print("[SeedVR2] Import-time BF16 CUDA probe skipped; using float16. Set SEEDVR2_IMPORT_BFLOAT16_PROBE=1 to probe.", flush=True) + return False + try: - a = torch.randn(8, 8, dtype=torch.bfloat16, device='cuda:0') - _ = torch.matmul(a, a) - del a - return True - except RuntimeError as e: - if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e): + if not torch.cuda.is_available(): return False - raise + with torch.no_grad(): + a = torch.empty((8, 8), dtype=torch.bfloat16, device="cuda:0") + b = torch.matmul(a, a) + torch.cuda.synchronize() + del a, b + return True + except BaseException as e: + # Never let an optional import-time feature probe kill the node pack. + try: + print( + f"[SeedVR2] BF16 CUDA import probe failed; disabling BF16 for this session: " + f"{type(e).__name__}: {e}", + flush=True, + ) + except Exception: + pass + return False + BFLOAT16_SUPPORTED = _probe_bfloat16_support() COMPUTE_DTYPE = torch.bfloat16 if BFLOAT16_SUPPORTED else torch.float16 +print(f"[SeedVR2] compute dtype selected at import: {COMPUTE_DTYPE}", flush=True) def call_rope_with_stability(method, *args, **kwargs): diff --git a/src/optimization/memory_manager.py b/src/optimization/memory_manager.py index 780c6909..4354b307 100644 --- a/src/optimization/memory_manager.py +++ b/src/optimization/memory_manager.py @@ -11,7 +11,10 @@ import time import psutil import platform -from typing import Tuple, Dict, Any, Optional, List, Union +from typing import Tuple, Dict, Any, Optional, List, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from ..utils.debug import Debug def _device_str(device: Union[torch.device, str]) -> str: @@ -20,6 +23,156 @@ def _device_str(device: Union[torch.device, str]) -> str: return 'MPS' if s.startswith('MPS') else s +def _normalize_device(device: Optional[Union[torch.device, str]]) -> Optional[torch.device]: + """Normalize an optional device spec to torch.device.""" + if device is None: + return None + if isinstance(device, torch.device): + return device + return torch.device(device) + + +def synchronize_device(device: Optional[Union[torch.device, str]], + debug: Optional['Debug'] = None, + reason: Optional[str] = None) -> bool: + """Synchronize a single accelerator device before cleanup or device moves.""" + device = _normalize_device(device) + if device is None or device.type in ('cpu', 'meta'): + return False + + try: + if device.type == 'cuda' and is_cuda_available(): + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronizing {_device_str(device)}{why}", category="cleanup") + torch.cuda.synchronize(device) + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronized {_device_str(device)}{why}", category="cleanup") + return True + if device.type == 'mps' and is_mps_available(): + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronizing MPS{why}", category="cleanup") + torch.mps.synchronize() + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronized MPS{why}", category="cleanup") + return True + except Exception as e: + if debug: + why = f" ({reason})" if reason else "" + debug.log( + f"Device synchronization failed for {device}{why}: {e}", + level="WARNING", + category="cleanup", + force=True, + ) + return False + + +def _iter_runtime_tensors(value: Any): + """Yield tensors stored in ad-hoc runtime containers like module.memory.""" + stack = [value] + seen = set() + + while stack: + current = stack.pop() + + if torch.is_tensor(current): + yield current + continue + + if isinstance(current, dict): + obj_id = id(current) + if obj_id in seen: + continue + seen.add(obj_id) + stack.extend(current.values()) + continue + + if isinstance(current, (list, tuple, set, frozenset)): + obj_id = id(current) + if obj_id in seen: + continue + seen.add(obj_id) + stack.extend(current) + continue + + +def _clear_runtime_memory_attr(module: Any) -> int: + """Drop a module.memory attribute when it contains runtime tensors.""" + if not hasattr(module, 'memory'): + return 0 + + tensor_count = sum(1 for _ in _iter_runtime_tensors(getattr(module, 'memory', None))) + if tensor_count <= 0: + return 0 + + module.memory = None + return tensor_count + + +def synchronize_model(model: Optional[torch.nn.Module], + debug: Optional['Debug'] = None, + reason: Optional[str] = None) -> int: + """Synchronize all non-CPU/non-meta devices touched by a model.""" + if model is None: + return 0 + + devices = set() + try: + for tensor in model.parameters(): + if tensor is None or not torch.is_tensor(tensor): + continue + device = tensor.device + if device.type not in ('cpu', 'meta'): + devices.add(str(device)) + + for tensor in model.buffers(): + if tensor is None or not torch.is_tensor(tensor): + continue + device = tensor.device + if device.type not in ('cpu', 'meta'): + devices.add(str(device)) + + for module in model.modules(): + for tensor in _iter_runtime_tensors(getattr(module, 'memory', None)): + device = tensor.device + if device.type not in ('cpu', 'meta'): + devices.add(str(device)) + except Exception as e: + if debug: + why = f" ({reason})" if reason else "" + debug.log( + f"Failed to inspect model devices{why}: {e}", + level="WARNING", + category="cleanup", + force=True, + ) + return 0 + + synced = 0 + for device_str in sorted(devices): + if synchronize_device(torch.device(device_str), debug=debug, reason=reason): + synced += 1 + return synced + + +def synchronize_visible_accelerators(debug: Optional['Debug'] = None, + reason: Optional[str] = None) -> int: + """Synchronize all visible accelerator devices before allocator/cache operations.""" + synced = 0 + if is_cuda_available(): + for idx in range(torch.cuda.device_count()): + if synchronize_device(torch.device(f"cuda:{idx}"), debug=debug, reason=reason): + synced += 1 + elif is_mps_available(): + if synchronize_device(torch.device('mps'), debug=debug, reason=reason): + synced += 1 + return synced + + def is_mps_available() -> bool: """Check if MPS (Apple Metal) backend is available.""" return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() @@ -291,7 +444,10 @@ def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: boo debug.log(f"Clearing memory caches ({cleanup_mode})...", category="cleanup") # ===== MINIMAL OPERATIONS (Always performed) ===== - # Step 1: Clear GPU caches - Fast operations (~1-5ms) + # Step 1: Synchronize devices before touching allocators / caches + synchronize_visible_accelerators(debug=debug, reason="before allocator cache clearing") + + # Step 2: Clear GPU caches - Fast operations (~1-5ms) if debug: debug.start_timer(gpu_timer) @@ -456,11 +612,8 @@ def clear_rope_lru_caches(model: Optional[torch.nn.Module], debug: Optional['Deb def release_tensor_memory(tensor: Optional[torch.Tensor]) -> None: - """Release tensor memory from any device (CPU/CUDA/MPS)""" + """Release tensor references without invalidating accelerator-backed storage.""" if tensor is not None and torch.is_tensor(tensor): - # Release storage for all devices (CPU, CUDA, MPS) - if tensor.numel() > 0: - tensor.data.set_() tensor.grad = None @@ -533,7 +686,7 @@ def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None names.append(key) if embeddings: - release_text_embeddings(embeddings, names, debug) + release_text_embeddings(*embeddings, debug=debug, names=names) if debug: debug.log(f"Cleaned up text embeddings: {', '.join(names)}", category="cleanup") @@ -543,44 +696,148 @@ def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> None: """ - Release all GPU/MPS memory from model in-place without CPU transfer. - - Args: - model: PyTorch model to release memory from - debug: Optional debug instance for logging + Release model-owned references without force-invalidating accelerator storage. """ if model is None: return try: - # Clear gradients first + synchronize_model(model=model, debug=debug, reason="before releasing model references") model.zero_grad(set_to_none=True) - - # Release GPU memory directly without CPU transfer - released_params = 0 - released_buffers = 0 - - for param in model.parameters(): - if param.is_cuda or param.is_mps: - if param.numel() > 0: - param.data.set_() - released_params += 1 - param.grad = None - - for buffer in model.buffers(): - if buffer.is_cuda or buffer.is_mps: - if buffer.numel() > 0: - buffer.data.set_() - released_buffers += 1 - - if debug and (released_params > 0 or released_buffers > 0): - debug.log(f"Released memory from {released_params} params and {released_buffers} buffers", category="success") + + cleared_memory_buffers = 0 + cleared_runtime_tensors = 0 + for module in model.modules(): + cleared_tensors = _clear_runtime_memory_attr(module) + if cleared_tensors > 0: + cleared_memory_buffers += 1 + cleared_runtime_tensors += cleared_tensors + + if debug: + debug.log( + f"Released model references and cleared {cleared_memory_buffers} runtime memory buffers " + f"({cleared_runtime_tensors} tensors)", + category="success", + ) except (AttributeError, RuntimeError) as e: if debug: debug.log(f"Failed to release model memory: {e}", level="WARNING", category="memory", force=True) +def iter_model_wrapper_chain(model: Optional[torch.nn.Module]): + """Yield a model plus any known wrappers/base modules reachable through unwrap attributes.""" + if model is None: + return + + stack = [model] + seen = set() + + while stack: + current = stack.pop() + if current is None: + continue + + obj_id = id(current) + if obj_id in seen: + continue + seen.add(obj_id) + yield current + + for attr in ('_orig_mod', 'dit_model'): + child = getattr(current, attr, None) + if child is not None: + stack.append(child) + + +def set_model_cache_cold_state(model: Optional[torch.nn.Module], is_cold: bool) -> None: + """Mark a cached model and all known wrappers/base objects as cold or hot.""" + for wrapped_model in iter_model_wrapper_chain(model): + setattr(wrapped_model, '_seedvr2_cold_cache', is_cold) + + +def set_model_cache_claimed_state(model: Optional[torch.nn.Module], is_claimed: bool) -> None: + """Mark a cached model and all known wrappers/base objects as claimed or free.""" + for wrapped_model in iter_model_wrapper_chain(model): + setattr(wrapped_model, '_seedvr2_cache_claimed', is_claimed) + + +def is_model_cache_cold(model: Optional[torch.nn.Module]) -> bool: + """Return True when a cached model is in its cold reusable canonical form.""" + return any(getattr(wrapped_model, '_seedvr2_cold_cache', False) for wrapped_model in iter_model_wrapper_chain(model)) + + +def is_model_cache_claimed(model: Optional[torch.nn.Module]) -> bool: + """Return True when a cached model is already leased to an in-flight execution.""" + return any(getattr(wrapped_model, '_seedvr2_cache_claimed', False) for wrapped_model in iter_model_wrapper_chain(model)) + + +def _copy_model_cache_metadata(source: Any, target: Any, attrs: Tuple[str, ...]) -> None: + """Preserve cache metadata when normalizing wrapped models back to their base form.""" + for attr in attrs: + if hasattr(source, attr): + setattr(target, attr, getattr(source, attr)) + + +def _normalize_cached_dit_model(model: torch.nn.Module, debug: Optional['Debug'] = None) -> torch.nn.Module: + """Return a cold canonical DiT model with compile/wrapper state removed.""" + while True: + changed = False + + if hasattr(model, '_orig_mod'): + if debug: + debug.log("Removing torch.compile wrapper from DiT for cold cache storage", category="cleanup") + base_model = model._orig_mod + _copy_model_cache_metadata( + model, + base_model, + ('_model_name', '_config_compile', '_config_swap', '_config_attn'), + ) + model = base_model + changed = True + + if hasattr(model, 'dit_model'): + if debug: + debug.log("Removing DiT compatibility wrapper for cold cache storage", category="cleanup") + base_model = model.dit_model + _copy_model_cache_metadata( + model, + base_model, + ('_model_name', '_config_compile', '_config_swap', '_config_attn'), + ) + model = base_model + changed = True + + if not changed: + break + + release_model_memory(model=model, debug=debug) + set_model_cache_cold_state(model, True) + return model + + +def _normalize_cached_vae_model(model: torch.nn.Module, debug: Optional['Debug'] = None) -> torch.nn.Module: + """Return a cold canonical VAE model with compiled submodules removed.""" + if hasattr(model, 'encoder') and hasattr(model.encoder, '_orig_mod'): + if debug: + debug.log("Removing torch.compile wrapper from VAE encoder for cold cache storage", category="cleanup") + model.encoder = model.encoder._orig_mod + + if hasattr(model, 'decoder') and hasattr(model.decoder, '_orig_mod'): + if debug: + debug.log("Removing torch.compile wrapper from VAE decoder for cold cache storage", category="cleanup") + model.decoder = model.decoder._orig_mod + + if hasattr(model, 'debug'): + model.debug = None + if hasattr(model, 'tensor_offload_device'): + model.tensor_offload_device = None + + release_model_memory(model=model, debug=debug) + set_model_cache_cold_state(model, True) + return model + + def manage_tensor( tensor: torch.Tensor, target_device: torch.device, @@ -694,17 +951,21 @@ def manage_model_device(model: torch.nn.Module, target_device: torch.device, mod if runner and model_name == "DiT": # Import here to avoid circular dependency from .blockswap import is_blockswap_enabled + actual_model = getattr(model, "dit_model", model) # Check if BlockSwap config exists and is enabled has_blockswap_config = ( hasattr(runner, '_dit_block_swap_config') and is_blockswap_enabled(runner._dit_block_swap_config) ) + has_blockswap_runtime_state = ( + getattr(runner, '_blockswap_active', False) or + hasattr(actual_model, '_block_swap_config') or + hasattr(actual_model, '_original_to') or + getattr(actual_model, '_blockswap_bypass_protection', False) + ) - if has_blockswap_config: + if has_blockswap_config and has_blockswap_runtime_state: is_blockswap_model = True - # Get the actual model (handle CompatibleDiT wrapper) - if hasattr(model, "dit_model"): - actual_model = model.dit_model # Get current device try: @@ -783,6 +1044,8 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module, if debug: debug.start_timer(timer_name) + synchronize_model(model=model, debug=debug, reason=f"before moving {model_name} to {_device_str(target_device)}") + # Move entire model to target offload device model.to(target_device) model.zero_grad(set_to_none=True) @@ -817,6 +1080,8 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module, if debug: debug.start_timer(timer_name) + synchronize_model(model=model, debug=debug, reason=f"before restoring {model_name} BlockSwap placement") + # Restore blocks to their configured devices if hasattr(model, "blocks") and hasattr(model, "blocks_to_swap"): # Use configured offload_device from BlockSwap config @@ -907,6 +1172,8 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic if debug: debug.start_timer(timer_name) + synchronize_model(model=model, debug=debug, reason=f"before moving {model_name} to {_device_str(target_device)}") + # Move model and clear gradients model.to(target_device) model.zero_grad(set_to_none=True) @@ -914,13 +1181,17 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic # Clear VAE memory buffers when moving to CPU if target_type == 'cpu' and model_name == "VAE": cleared_count = 0 + cleared_tensor_count = 0 for module in model.modules(): - if hasattr(module, 'memory') and module.memory is not None: - if torch.is_tensor(module.memory) and (module.memory.is_cuda or module.memory.is_mps): - module.memory = None - cleared_count += 1 + cleared_tensors = _clear_runtime_memory_attr(module) + if cleared_tensors > 0: + cleared_count += 1 + cleared_tensor_count += cleared_tensors if cleared_count > 0 and debug: - debug.log(f"Cleared {cleared_count} VAE memory buffers", category="success") + debug.log( + f"Cleared {cleared_count} VAE memory buffers ({cleared_tensor_count} tensors)", + category="success", + ) # End timer if debug: @@ -1023,6 +1294,8 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("Cleaning up DiT components", category="cleanup") + + synchronize_model(getattr(runner, 'dit', None), debug=debug, reason="before DiT cleanup") # 1. Clear DiT-specific runtime caches first if hasattr(runner, 'dit'): @@ -1071,6 +1344,10 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool # Import here to avoid circular dependency from .blockswap import cleanup_blockswap cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model) + + if cache_model and runner.dit is not None: + set_model_cache_cold_state(runner.dit, False) + runner._seedvr2_dit_phase_cleaned = True # 4. Complete cleanup if not caching if not cache_model: @@ -1087,6 +1364,9 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if hasattr(runner, '_dit_attention_mode'): delattr(runner, '_dit_attention_mode') + if not cache_model: + runner._seedvr2_dit_phase_cleaned = True + # 5. Clear DiT temporary attributes (should be already cleared in materialize_model) runner._dit_checkpoint = None runner._dit_dtype_override = None @@ -1112,6 +1392,8 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("Cleaning up VAE components", category="cleanup") + + synchronize_model(getattr(runner, 'vae', None), debug=debug, reason="before VAE cleanup") # 1. Clear VAE-specific temporary attributes if hasattr(runner, 'vae'): @@ -1143,6 +1425,10 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool debug.log("VAE on meta device - keeping structure for cache", category="cleanup") except StopIteration: pass + + if cache_model and runner.vae is not None: + set_model_cache_cold_state(runner.vae, False) + runner._seedvr2_vae_phase_cleaned = True # 3. Complete cleanup if not caching if not cache_model: @@ -1157,6 +1443,9 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if hasattr(runner, '_vae_tiling_config'): delattr(runner, '_vae_tiling_config') + if not cache_model: + runner._seedvr2_vae_phase_cleaned = True + # 3. Clear VAE temporary attributes (should be already cleared in materialize_model) runner._vae_checkpoint = None runner._vae_dtype_override = None @@ -1194,10 +1483,12 @@ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bo # 1. Cleanup any remaining models if they still exist # (This handles cases where phases were skipped or errored) if hasattr(runner, 'dit') and runner.dit is not None: - cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache) + if not (dit_cache and getattr(runner, '_seedvr2_dit_phase_cleaned', False)): + cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache) if hasattr(runner, 'vae') and runner.vae is not None: - cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache) + if not (vae_cache and getattr(runner, '_seedvr2_vae_phase_cleaned', False)): + cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache) # 2. Clear remaining runtime caches clear_runtime_caches(runner=runner, debug=debug) @@ -1228,4 +1519,4 @@ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bo debug.log(f"Models cached for next run: {models_str}", category="cache", force=True) if debug: - debug.log(f"Completed {cleanup_type}", category="success") \ No newline at end of file + debug.log(f"Completed {cleanup_type}", category="success")