From b9a3816e8dd9eebc7595ab27313e02f59871a55a Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 09:54:35 +0100 Subject: [PATCH 01/29] Investigate CUDA cleanup hang --- src/core/model_cache.py | 33 ++++++- src/core/model_configuration.py | 25 ++++- src/interfaces/video_upscaler.py | 33 ++++++- src/optimization/memory_manager.py | 154 ++++++++++++++++++++++------- 4 files changed, 201 insertions(+), 44 deletions(-) diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 2c54ea51..ad502304 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -145,20 +145,43 @@ def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], debug: Optional debug instance for logging Returns: - Runner key string (format: "dit_id+vae_id") if cached successfully, - None if either ID is None or runner already cached + Runner key string (format: "dit_id+vae_id") if cached successfully, + None if either ID is None """ if dit_id is None or vae_id is None: return None runner_key = f"{dit_id}+{vae_id}" - if runner_key not in self._runner_templates: + existing = self._runner_templates.get(runner_key) + if existing is runner: + return runner_key + + replace_existing = False + if existing is not None: + replace_existing = getattr(existing, '_seedvr2_runner_tainted', False) + + if existing is None or replace_existing: self._runner_templates[runner_key] = runner if debug: - debug.log(f"Runner template cached in memory: nodes {runner_key}", category="cache", force=True) + action = "replaced" if replace_existing else "cached" + debug.log(f"Runner template {action} in memory: nodes {runner_key}", category="cache", force=True) return runner_key return None + + def remove_runner(self, dit_id: Optional[int], vae_id: Optional[int], + debug: Optional['Debug'] = None) -> bool: + """Remove a cached runner template for the given DiT/VAE node pair.""" + if dit_id is None or vae_id is None: + return False + + runner_key = f"{dit_id}+{vae_id}" + if runner_key in self._runner_templates: + del self._runner_templates[runner_key] + if debug: + debug.log(f"Removed cached runner template: nodes {runner_key}", category="cache", force=True) + return True + return False def remove_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) -> bool: """ @@ -236,4 +259,4 @@ def remove_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None def get_global_cache() -> GlobalModelCache: """Get the global model cache instance.""" - return _global_cache \ No newline at end of file + return _global_cache diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 61297627..cff8bc24 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -664,6 +664,28 @@ def _acquire_runner( ) if template: + runner_key = f"{cache_context['dit_id']}+{cache_context['vae_id']}" + + if getattr(template, '_seedvr2_execution_active', False): + debug.log( + f"Cached runner template still marked active: nodes {runner_key}; creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + cache_context['global_cache'].remove_runner(cache_context['dit_id'], cache_context['vae_id'], debug) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) + + if getattr(template, '_seedvr2_runner_tainted', False): + debug.log( + f"Cached runner template was tainted by a prior failed/interrupted run: nodes {runner_key}; creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + cache_context['global_cache'].remove_runner(cache_context['dit_id'], cache_context['vae_id'], debug) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) + # We have a template - check if we can use it current_dit = getattr(template, '_dit_model_name', None) current_vae = getattr(template, '_vae_model_name', None) @@ -671,7 +693,6 @@ def _acquire_runner( if models_match: # Perfect match - reuse template directly - runner_key = f"{cache_context['dit_id']}+{cache_context['vae_id']}" debug.log(f"Reusing cached runner template: nodes {runner_key}", category="reuse", force=True) cache_context['reusing_runner'] = True return template @@ -1477,4 +1498,4 @@ def _propagate_debug_to_modules(module: torch.nn.Module, debug: 'Debug') -> None for name, submodule in module.named_modules(): if submodule.__class__.__name__ in target_modules: if not hasattr(submodule, 'debug'): # Only set if not already present - submodule.debug = debug \ No newline at end of file + submodule.debug = debug diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 159ca2dc..52ca7545 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -267,6 +267,7 @@ def execute(cls, image: torch.Tensor, dit: Dict[str, Any], vae: Dict[str, Any], # Track execution state in local variables (not instance) runner = None ctx = None + cache_context = None pbar = None # Define progress callback as local closure @@ -319,8 +320,15 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # Use complete_cleanup for all cleanup operations if runner: - complete_cleanup(runner=runner, debug=debug, - dit_cache=dit_cache, vae_cache=vae_cache) + try: + complete_cleanup( + runner=runner, + debug=debug, + dit_cache=dit_cache, + vae_cache=vae_cache, + ) + finally: + runner._seedvr2_execution_active = False # Delete runner only if neither model is cached if not (dit_cache or vae_cache): @@ -436,6 +444,9 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: torch_compile_args_vae=vae_torch_compile_args ) + runner._seedvr2_execution_active = True + runner._seedvr2_runner_tainted = False + # Store cache context in ctx for use in generation phases ctx['cache_context'] = cache_context @@ -568,6 +579,9 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # Print footer debug.print_footer() + if runner is not None: + runner._seedvr2_runner_tainted = False + debug.clear_history() pbar = None ctx = None @@ -575,6 +589,17 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # V3-compatible return with optional UI preview return io.NodeOutput(sample) - except Exception as e: + except BaseException: + if runner is not None: + runner._seedvr2_runner_tainted = True + runner._seedvr2_execution_active = False + + if cache_context is not None: + cache_context['global_cache'].remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + ) + cleanup(dit_cache=dit_cache, vae_cache=vae_cache) - raise e \ No newline at end of file + raise diff --git a/src/optimization/memory_manager.py b/src/optimization/memory_manager.py index 780c6909..2e1821fc 100644 --- a/src/optimization/memory_manager.py +++ b/src/optimization/memory_manager.py @@ -20,6 +20,95 @@ def _device_str(device: Union[torch.device, str]) -> str: return 'MPS' if s.startswith('MPS') else s +def _normalize_device(device: Optional[Union[torch.device, str]]) -> Optional[torch.device]: + """Normalize an optional device spec to torch.device.""" + if device is None: + return None + if isinstance(device, torch.device): + return device + return torch.device(device) + + +def synchronize_device(device: Optional[Union[torch.device, str]], + debug: Optional['Debug'] = None, + reason: Optional[str] = None) -> bool: + """Synchronize a single accelerator device before cleanup or device moves.""" + device = _normalize_device(device) + if device is None or device.type in ('cpu', 'meta'): + return False + + try: + if device.type == 'cuda' and is_cuda_available(): + torch.cuda.synchronize(device) + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronized {_device_str(device)}{why}", category="cleanup") + return True + if device.type == 'mps' and is_mps_available(): + torch.mps.synchronize() + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronized MPS{why}", category="cleanup") + return True + except Exception as e: + if debug: + why = f" ({reason})" if reason else "" + debug.log( + f"Device synchronization failed for {device}{why}: {e}", + level="WARNING", + category="cleanup", + force=True, + ) + return False + + +def synchronize_model(model: Optional[torch.nn.Module], + debug: Optional['Debug'] = None, + reason: Optional[str] = None) -> int: + """Synchronize all non-CPU/non-meta devices touched by a model.""" + if model is None: + return 0 + + devices = set() + try: + for tensor in list(model.parameters()) + list(model.buffers()): + if tensor is None or not torch.is_tensor(tensor): + continue + device = tensor.device + if device.type not in ('cpu', 'meta'): + devices.add(str(device)) + except Exception as e: + if debug: + why = f" ({reason})" if reason else "" + debug.log( + f"Failed to inspect model devices{why}: {e}", + level="WARNING", + category="cleanup", + force=True, + ) + return 0 + + synced = 0 + for device_str in sorted(devices): + if synchronize_device(torch.device(device_str), debug=debug, reason=reason): + synced += 1 + return synced + + +def synchronize_visible_accelerators(debug: Optional['Debug'] = None, + reason: Optional[str] = None) -> int: + """Synchronize all visible accelerator devices before allocator/cache operations.""" + synced = 0 + if is_cuda_available(): + for idx in range(torch.cuda.device_count()): + if synchronize_device(torch.device(f"cuda:{idx}"), debug=debug, reason=reason): + synced += 1 + elif is_mps_available(): + if synchronize_device(torch.device('mps'), debug=debug, reason=reason): + synced += 1 + return synced + + def is_mps_available() -> bool: """Check if MPS (Apple Metal) backend is available.""" return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() @@ -291,7 +380,10 @@ def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: boo debug.log(f"Clearing memory caches ({cleanup_mode})...", category="cleanup") # ===== MINIMAL OPERATIONS (Always performed) ===== - # Step 1: Clear GPU caches - Fast operations (~1-5ms) + # Step 1: Synchronize devices before touching allocators / caches + synchronize_visible_accelerators(debug=debug, reason="before allocator cache clearing") + + # Step 2: Clear GPU caches - Fast operations (~1-5ms) if debug: debug.start_timer(gpu_timer) @@ -456,11 +548,8 @@ def clear_rope_lru_caches(model: Optional[torch.nn.Module], debug: Optional['Deb def release_tensor_memory(tensor: Optional[torch.Tensor]) -> None: - """Release tensor memory from any device (CPU/CUDA/MPS)""" + """Release tensor references without invalidating accelerator-backed storage.""" if tensor is not None and torch.is_tensor(tensor): - # Release storage for all devices (CPU, CUDA, MPS) - if tensor.numel() > 0: - tensor.data.set_() tensor.grad = None @@ -533,7 +622,7 @@ def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None names.append(key) if embeddings: - release_text_embeddings(embeddings, names, debug) + release_text_embeddings(*embeddings, debug=debug, names=names) if debug: debug.log(f"Cleaned up text embeddings: {', '.join(names)}", category="cleanup") @@ -543,38 +632,26 @@ def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> None: """ - Release all GPU/MPS memory from model in-place without CPU transfer. - - Args: - model: PyTorch model to release memory from - debug: Optional debug instance for logging + Release model-owned references without force-invalidating accelerator storage. """ if model is None: return try: - # Clear gradients first + synchronize_model(model=model, debug=debug, reason="before releasing model references") model.zero_grad(set_to_none=True) - - # Release GPU memory directly without CPU transfer - released_params = 0 - released_buffers = 0 - - for param in model.parameters(): - if param.is_cuda or param.is_mps: - if param.numel() > 0: - param.data.set_() - released_params += 1 - param.grad = None - - for buffer in model.buffers(): - if buffer.is_cuda or buffer.is_mps: - if buffer.numel() > 0: - buffer.data.set_() - released_buffers += 1 - - if debug and (released_params > 0 or released_buffers > 0): - debug.log(f"Released memory from {released_params} params and {released_buffers} buffers", category="success") + + cleared_memory_buffers = 0 + for module in model.modules(): + if hasattr(module, 'memory') and torch.is_tensor(getattr(module, 'memory')): + module.memory = None + cleared_memory_buffers += 1 + + if debug: + debug.log( + f"Released model references and cleared {cleared_memory_buffers} runtime memory buffers", + category="success", + ) except (AttributeError, RuntimeError) as e: if debug: @@ -783,6 +860,8 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module, if debug: debug.start_timer(timer_name) + synchronize_model(model=model, debug=debug, reason=f"before moving {model_name} to {_device_str(target_device)}") + # Move entire model to target offload device model.to(target_device) model.zero_grad(set_to_none=True) @@ -817,6 +896,8 @@ def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module, if debug: debug.start_timer(timer_name) + synchronize_model(model=model, debug=debug, reason=f"before restoring {model_name} BlockSwap placement") + # Restore blocks to their configured devices if hasattr(model, "blocks") and hasattr(model, "blocks_to_swap"): # Use configured offload_device from BlockSwap config @@ -907,6 +988,8 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic if debug: debug.start_timer(timer_name) + synchronize_model(model=model, debug=debug, reason=f"before moving {model_name} to {_device_str(target_device)}") + # Move model and clear gradients model.to(target_device) model.zero_grad(set_to_none=True) @@ -1023,6 +1106,8 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("Cleaning up DiT components", category="cleanup") + + synchronize_model(getattr(runner, 'dit', None), debug=debug, reason="before DiT cleanup") # 1. Clear DiT-specific runtime caches first if hasattr(runner, 'dit'): @@ -1112,6 +1197,8 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("Cleaning up VAE components", category="cleanup") + + synchronize_model(getattr(runner, 'vae', None), debug=debug, reason="before VAE cleanup") # 1. Clear VAE-specific temporary attributes if hasattr(runner, 'vae'): @@ -1210,6 +1297,7 @@ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bo runner._vae_model_name = None # 4. Final memory cleanup + synchronize_visible_accelerators(debug=debug, reason="before final allocator cleanup") clear_memory(debug=debug, deep=True, force=True, timer_name="complete_cleanup") # 5. Clear cuBLAS workspaces @@ -1228,4 +1316,4 @@ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bo debug.log(f"Models cached for next run: {models_str}", category="cache", force=True) if debug: - debug.log(f"Completed {cleanup_type}", category="success") \ No newline at end of file + debug.log(f"Completed {cleanup_type}", category="success") From 3a5411d92ca72b9343835c036eafd3acdce3b8a2 Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 10:36:36 +0100 Subject: [PATCH 02/29] Fix CUDA cleanup reuse hangs --- src/core/model_cache.py | 5 +++- src/core/model_configuration.py | 13 +++++++- src/interfaces/video_upscaler.py | 48 +++++++++++++++++++++++++----- src/optimization/memory_manager.py | 11 ++++++- 4 files changed, 66 insertions(+), 11 deletions(-) diff --git a/src/core/model_cache.py b/src/core/model_cache.py index ad502304..3c8c05d0 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -3,9 +3,12 @@ Enables independent DiT and VAE model sharing across multiple upscaler node instances """ -from typing import Dict, Any, Optional, Tuple +from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING from ..optimization.memory_manager import release_model_memory +if TYPE_CHECKING: + from ..utils.debug import Debug + class GlobalModelCache: """ diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index cff8bc24..fd92aa84 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -697,7 +697,18 @@ def _acquire_runner( cache_context['reusing_runner'] = True return template else: - # Template exists but models changed and no cached models - create new + debug.log( + f"Cached runner template models no longer match: nodes {runner_key} " + f"({current_dit}/{current_vae} -> {dit_model}/{vae_model}); creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + cache_context['global_cache'].remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + ) return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) else: # No template - create new runner diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 52ca7545..10a4c32e 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -447,6 +447,21 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: runner._seedvr2_execution_active = True runner._seedvr2_runner_tainted = False + # If both models were already cached but the runner template had been + # invalidated or missing, cache this freshly configured runner now. + if ( + cache_context is not None + and not cache_context.get('reusing_runner', False) + and cache_context.get('cached_dit') is not None + and cache_context.get('cached_vae') is not None + ): + cache_context['global_cache'].set_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + runner, + debug, + ) + # Store cache context in ctx for use in generation phases ctx['cache_context'] = cache_context @@ -592,14 +607,31 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: except BaseException: if runner is not None: runner._seedvr2_runner_tainted = True - runner._seedvr2_execution_active = False if cache_context is not None: - cache_context['global_cache'].remove_runner( - cache_context.get('dit_id'), - cache_context.get('vae_id'), - debug, - ) - - cleanup(dit_cache=dit_cache, vae_cache=vae_cache) + try: + cache_context['global_cache'].remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + ) + except Exception as cache_error: + if debug is not None: + debug.log( + f"Failed to evict cached runner while handling prior exception: {cache_error}", + level="WARNING", + category="cleanup", + force=True, + ) + + try: + cleanup(dit_cache=dit_cache, vae_cache=vae_cache) + except BaseException as cleanup_error: + if debug is not None: + debug.log( + f"Cleanup failed while handling prior exception: {cleanup_error}", + level="WARNING", + category="cleanup", + force=True, + ) raise diff --git a/src/optimization/memory_manager.py b/src/optimization/memory_manager.py index 2e1821fc..29138435 100644 --- a/src/optimization/memory_manager.py +++ b/src/optimization/memory_manager.py @@ -11,7 +11,10 @@ import time import psutil import platform -from typing import Tuple, Dict, Any, Optional, List, Union +from typing import Tuple, Dict, Any, Optional, List, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from ..utils.debug import Debug def _device_str(device: Union[torch.device, str]) -> str: @@ -39,12 +42,18 @@ def synchronize_device(device: Optional[Union[torch.device, str]], try: if device.type == 'cuda' and is_cuda_available(): + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronizing {_device_str(device)}{why}", category="cleanup") torch.cuda.synchronize(device) if debug: why = f" ({reason})" if reason else "" debug.log(f"Synchronized {_device_str(device)}{why}", category="cleanup") return True if device.type == 'mps' and is_mps_available(): + if debug: + why = f" ({reason})" if reason else "" + debug.log(f"Synchronizing MPS{why}", category="cleanup") torch.mps.synchronize() if debug: why = f" ({reason})" if reason else "" From 516da73a704440e506fb3e14deb4e1d91e1daa0b Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 10:59:38 +0100 Subject: [PATCH 03/29] Fix runner template cache handling --- src/core/model_cache.py | 31 +++++++++--- src/core/model_configuration.py | 18 +++++-- src/interfaces/video_upscaler.py | 1 + src/optimization/memory_manager.py | 80 ++++++++++++++++++++++++++---- 4 files changed, 110 insertions(+), 20 deletions(-) diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 3c8c05d0..33e586a5 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -173,18 +173,35 @@ def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], return None def remove_runner(self, dit_id: Optional[int], vae_id: Optional[int], - debug: Optional['Debug'] = None) -> bool: - """Remove a cached runner template for the given DiT/VAE node pair.""" + debug: Optional['Debug'] = None, + expected_runner: Optional[Any] = None) -> bool: + """Remove a cached runner template for the given DiT/VAE node pair. + + If expected_runner is provided, only remove the cache entry when the + currently stored runner is that exact object. + """ if dit_id is None or vae_id is None: return False runner_key = f"{dit_id}+{vae_id}" - if runner_key in self._runner_templates: - del self._runner_templates[runner_key] + cached_runner = self._runner_templates.get(runner_key) + if cached_runner is None: + return False + + if expected_runner is not None and cached_runner is not expected_runner: if debug: - debug.log(f"Removed cached runner template: nodes {runner_key}", category="cache", force=True) - return True - return False + debug.log( + f"Skipped cached runner removal for nodes {runner_key}: cache entry no longer matches expected runner", + level="WARNING", + category="cache", + force=True, + ) + return False + + del self._runner_templates[runner_key] + if debug: + debug.log(f"Removed cached runner template: nodes {runner_key}", category="cache", force=True) + return True def remove_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) -> bool: """ diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index fd92aa84..0462bd8c 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -673,7 +673,12 @@ def _acquire_runner( category="cache", force=True, ) - cache_context['global_cache'].remove_runner(cache_context['dit_id'], cache_context['vae_id'], debug) + cache_context['global_cache'].remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + expected_runner=template, + ) return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) if getattr(template, '_seedvr2_runner_tainted', False): @@ -683,7 +688,12 @@ def _acquire_runner( category="cache", force=True, ) - cache_context['global_cache'].remove_runner(cache_context['dit_id'], cache_context['vae_id'], debug) + cache_context['global_cache'].remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + expected_runner=template, + ) return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) # We have a template - check if we can use it @@ -708,6 +718,7 @@ def _acquire_runner( cache_context['dit_id'], cache_context['vae_id'], debug, + expected_runner=template, ) return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) else: @@ -1508,5 +1519,4 @@ def _propagate_debug_to_modules(module: torch.nn.Module, debug: 'Debug') -> None for name, submodule in module.named_modules(): if submodule.__class__.__name__ in target_modules: - if not hasattr(submodule, 'debug'): # Only set if not already present - submodule.debug = debug + submodule.debug = debug diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 10a4c32e..5685a622 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -614,6 +614,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: cache_context.get('dit_id'), cache_context.get('vae_id'), debug, + expected_runner=runner, ) except Exception as cache_error: if debug is not None: diff --git a/src/optimization/memory_manager.py b/src/optimization/memory_manager.py index 29138435..c822886b 100644 --- a/src/optimization/memory_manager.py +++ b/src/optimization/memory_manager.py @@ -71,6 +71,48 @@ def synchronize_device(device: Optional[Union[torch.device, str]], return False +def _iter_runtime_tensors(value: Any): + """Yield tensors stored in ad-hoc runtime containers like module.memory.""" + stack = [value] + seen = set() + + while stack: + current = stack.pop() + + if torch.is_tensor(current): + yield current + continue + + if isinstance(current, dict): + obj_id = id(current) + if obj_id in seen: + continue + seen.add(obj_id) + stack.extend(current.values()) + continue + + if isinstance(current, (list, tuple, set, frozenset)): + obj_id = id(current) + if obj_id in seen: + continue + seen.add(obj_id) + stack.extend(current) + continue + + +def _clear_runtime_memory_attr(module: Any) -> int: + """Drop a module.memory attribute when it contains runtime tensors.""" + if not hasattr(module, 'memory'): + return 0 + + tensor_count = sum(1 for _ in _iter_runtime_tensors(getattr(module, 'memory', None))) + if tensor_count <= 0: + return 0 + + module.memory = None + return tensor_count + + def synchronize_model(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None, reason: Optional[str] = None) -> int: @@ -80,12 +122,25 @@ def synchronize_model(model: Optional[torch.nn.Module], devices = set() try: - for tensor in list(model.parameters()) + list(model.buffers()): + for tensor in model.parameters(): + if tensor is None or not torch.is_tensor(tensor): + continue + device = tensor.device + if device.type not in ('cpu', 'meta'): + devices.add(str(device)) + + for tensor in model.buffers(): if tensor is None or not torch.is_tensor(tensor): continue device = tensor.device if device.type not in ('cpu', 'meta'): devices.add(str(device)) + + for module in model.modules(): + for tensor in _iter_runtime_tensors(getattr(module, 'memory', None)): + device = tensor.device + if device.type not in ('cpu', 'meta'): + devices.add(str(device)) except Exception as e: if debug: why = f" ({reason})" if reason else "" @@ -651,14 +706,17 @@ def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debu model.zero_grad(set_to_none=True) cleared_memory_buffers = 0 + cleared_runtime_tensors = 0 for module in model.modules(): - if hasattr(module, 'memory') and torch.is_tensor(getattr(module, 'memory')): - module.memory = None + cleared_tensors = _clear_runtime_memory_attr(module) + if cleared_tensors > 0: cleared_memory_buffers += 1 + cleared_runtime_tensors += cleared_tensors if debug: debug.log( - f"Released model references and cleared {cleared_memory_buffers} runtime memory buffers", + f"Released model references and cleared {cleared_memory_buffers} runtime memory buffers " + f"({cleared_runtime_tensors} tensors)", category="success", ) @@ -1006,13 +1064,17 @@ def _standard_model_movement(model: torch.nn.Module, current_device: torch.devic # Clear VAE memory buffers when moving to CPU if target_type == 'cpu' and model_name == "VAE": cleared_count = 0 + cleared_tensor_count = 0 for module in model.modules(): - if hasattr(module, 'memory') and module.memory is not None: - if torch.is_tensor(module.memory) and (module.memory.is_cuda or module.memory.is_mps): - module.memory = None - cleared_count += 1 + cleared_tensors = _clear_runtime_memory_attr(module) + if cleared_tensors > 0: + cleared_count += 1 + cleared_tensor_count += cleared_tensors if cleared_count > 0 and debug: - debug.log(f"Cleared {cleared_count} VAE memory buffers", category="success") + debug.log( + f"Cleared {cleared_count} VAE memory buffers ({cleared_tensor_count} tensors)", + category="success", + ) # End timer if debug: From b9f603318d530bc263b8dc91ca4602ddccf77f66 Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 11:21:00 +0100 Subject: [PATCH 04/29] Fix release_model_memory containers --- src/core/model_cache.py | 112 ++++++++++++++++++++--------- src/core/model_configuration.py | 59 +++++++-------- src/optimization/memory_manager.py | 1 - 3 files changed, 103 insertions(+), 69 deletions(-) diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 33e586a5..2fa6d886 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -3,6 +3,7 @@ Enables independent DiT and VAE model sharing across multiple upscaler node instances """ +import threading from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING from ..optimization.memory_manager import release_model_memory @@ -24,6 +25,8 @@ def __init__(self): self._vae_models: Dict[str, Tuple[Any, Dict]] = {} # Storage for runner templates: "dit_id+vae_id" -> runner self._runner_templates: Dict[str, Any] = {} + # Synchronizes runner-template claim/set/remove operations + self._runner_templates_lock = threading.RLock() def get_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) -> Optional[Any]: """ @@ -82,9 +85,44 @@ def get_runner(self, dit_id: Optional[int], vae_id: Optional[int], return None runner_key = f"{dit_id}+{vae_id}" - if runner_key in self._runner_templates: - return self._runner_templates[runner_key] - return None + with self._runner_templates_lock: + return self._runner_templates.get(runner_key) + + def claim_runner(self, dit_id: Optional[int], vae_id: Optional[int], + dit_model: str, vae_model: str) -> Tuple[Optional[Any], str]: + """ + Atomically inspect and claim a cached runner template for exclusive reuse. + + Returns: + (template, status) where status is one of: + - "missing": no cached template exists + - "active": template exists but is already in use + - "tainted": template exists but was marked failed/interrupted + - "mismatch": template exists but was built for different DiT/VAE names + - "claimed": template was successfully claimed for reuse + """ + if dit_id is None or vae_id is None: + return None, "missing" + + runner_key = f"{dit_id}+{vae_id}" + with self._runner_templates_lock: + template = self._runner_templates.get(runner_key) + if template is None: + return None, "missing" + + if getattr(template, '_seedvr2_execution_active', False): + return template, "active" + + if getattr(template, '_seedvr2_runner_tainted', False): + return template, "tainted" + + current_dit = getattr(template, '_dit_model_name', None) + current_vae = getattr(template, '_vae_model_name', None) + if current_dit != dit_model or current_vae != vae_model: + return template, "mismatch" + + setattr(template, '_seedvr2_execution_active', True) + return template, "claimed" def set_dit(self, dit_config: Dict[str, Any], model: Any, model_name: str, debug: Optional['Debug'] = None) -> Optional[str]: """ @@ -155,20 +193,21 @@ def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], return None runner_key = f"{dit_id}+{vae_id}" - existing = self._runner_templates.get(runner_key) - if existing is runner: - return runner_key + with self._runner_templates_lock: + existing = self._runner_templates.get(runner_key) + if existing is runner: + return runner_key - replace_existing = False - if existing is not None: - replace_existing = getattr(existing, '_seedvr2_runner_tainted', False) + replace_existing = False + if existing is not None: + replace_existing = getattr(existing, '_seedvr2_runner_tainted', False) - if existing is None or replace_existing: - self._runner_templates[runner_key] = runner - if debug: - action = "replaced" if replace_existing else "cached" - debug.log(f"Runner template {action} in memory: nodes {runner_key}", category="cache", force=True) - return runner_key + if existing is None or replace_existing: + self._runner_templates[runner_key] = runner + if debug: + action = "replaced" if replace_existing else "cached" + debug.log(f"Runner template {action} in memory: nodes {runner_key}", category="cache", force=True) + return runner_key return None @@ -183,22 +222,23 @@ def remove_runner(self, dit_id: Optional[int], vae_id: Optional[int], if dit_id is None or vae_id is None: return False - runner_key = f"{dit_id}+{vae_id}" - cached_runner = self._runner_templates.get(runner_key) - if cached_runner is None: - return False + with self._runner_templates_lock: + runner_key = f"{dit_id}+{vae_id}" + cached_runner = self._runner_templates.get(runner_key) + if cached_runner is None: + return False - if expected_runner is not None and cached_runner is not expected_runner: - if debug: - debug.log( - f"Skipped cached runner removal for nodes {runner_key}: cache entry no longer matches expected runner", - level="WARNING", - category="cache", - force=True, - ) - return False + if expected_runner is not None and cached_runner is not expected_runner: + if debug: + debug.log( + f"Skipped cached runner removal for nodes {runner_key}: cache entry no longer matches expected runner", + level="WARNING", + category="cache", + force=True, + ) + return False - del self._runner_templates[runner_key] + del self._runner_templates[runner_key] if debug: debug.log(f"Removed cached runner template: nodes {runner_key}", category="cache", force=True) return True @@ -231,9 +271,10 @@ def remove_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None del self._dit_models[node_id] # Remove any runner templates that used this DiT - templates_to_remove = [k for k in self._runner_templates.keys() if k.startswith(str(node_id) + "+")] - for template_key in templates_to_remove: - del self._runner_templates[template_key] + with self._runner_templates_lock: + templates_to_remove = [k for k in self._runner_templates.keys() if k.startswith(str(node_id) + "+")] + for template_key in templates_to_remove: + del self._runner_templates[template_key] return True return False @@ -266,9 +307,10 @@ def remove_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None del self._vae_models[node_id] # Remove any runner templates that used this VAE - templates_to_remove = [k for k in self._runner_templates.keys() if k.endswith("+" + str(node_id))] - for template_key in templates_to_remove: - del self._runner_templates[template_key] + with self._runner_templates_lock: + templates_to_remove = [k for k in self._runner_templates.keys() if k.endswith("+" + str(node_id))] + for template_key in templates_to_remove: + del self._runner_templates[template_key] return True return False diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 0462bd8c..3a86cc0c 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -658,30 +658,27 @@ def _acquire_runner( Returns: VideoDiffusionInfer: Runner instance (cached template or newly created) """ - # Try to get runner template from global cache - template = cache_context['global_cache'].get_runner( - cache_context['dit_id'], cache_context['vae_id'], debug + # Try to atomically claim a reusable runner template from global cache + template, template_status = cache_context['global_cache'].claim_runner( + cache_context['dit_id'], + cache_context['vae_id'], + dit_model, + vae_model, ) if template: runner_key = f"{cache_context['dit_id']}+{cache_context['vae_id']}" - if getattr(template, '_seedvr2_execution_active', False): + if template_status == "active": debug.log( f"Cached runner template still marked active: nodes {runner_key}; creating a fresh runner", level="WARNING", category="cache", force=True, ) - cache_context['global_cache'].remove_runner( - cache_context['dit_id'], - cache_context['vae_id'], - debug, - expected_runner=template, - ) return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) - if getattr(template, '_seedvr2_runner_tainted', False): + if template_status == "tainted": debug.log( f"Cached runner template was tainted by a prior failed/interrupted run: nodes {runner_key}; creating a fresh runner", level="WARNING", @@ -696,31 +693,27 @@ def _acquire_runner( ) return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) - # We have a template - check if we can use it - current_dit = getattr(template, '_dit_model_name', None) - current_vae = getattr(template, '_vae_model_name', None) - models_match = (current_dit == dit_model and current_vae == vae_model) - - if models_match: - # Perfect match - reuse template directly + if template_status == "claimed": debug.log(f"Reusing cached runner template: nodes {runner_key}", category="reuse", force=True) cache_context['reusing_runner'] = True return template - else: - debug.log( - f"Cached runner template models no longer match: nodes {runner_key} " - f"({current_dit}/{current_vae} -> {dit_model}/{vae_model}); creating a fresh runner", - level="WARNING", - category="cache", - force=True, - ) - cache_context['global_cache'].remove_runner( - cache_context['dit_id'], - cache_context['vae_id'], - debug, - expected_runner=template, - ) - return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) + + current_dit = getattr(template, '_dit_model_name', None) + current_vae = getattr(template, '_vae_model_name', None) + debug.log( + f"Cached runner template models no longer match: nodes {runner_key} " + f"({current_dit}/{current_vae} -> {dit_model}/{vae_model}); creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + cache_context['global_cache'].remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + expected_runner=template, + ) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) else: # No template - create new runner return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) diff --git a/src/optimization/memory_manager.py b/src/optimization/memory_manager.py index c822886b..2b9c60d6 100644 --- a/src/optimization/memory_manager.py +++ b/src/optimization/memory_manager.py @@ -1368,7 +1368,6 @@ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bo runner._vae_model_name = None # 4. Final memory cleanup - synchronize_visible_accelerators(debug=debug, reason="before final allocator cleanup") clear_memory(debug=debug, deep=True, force=True, timer_name="complete_cleanup") # 5. Clear cuBLAS workspaces From 8e4c89ab9be7969671624ef86e6494f5a87c7fc1 Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 11:27:12 +0100 Subject: [PATCH 05/29] Synchronize runner template access --- src/core/model_cache.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 2fa6d886..254225ef 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -186,8 +186,9 @@ def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], debug: Optional debug instance for logging Returns: - Runner key string (format: "dit_id+vae_id") if cached successfully, - None if either ID is None + Runner key string (format: "dit_id+vae_id") if this call cached or + replaced the template, None if either ID is None or an existing + non-tainted template is intentionally kept. """ if dit_id is None or vae_id is None: return None From e16c827709b209692eec25b35bbb588998bcc236 Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 11:33:10 +0100 Subject: [PATCH 06/29] Synchronize runner template access --- src/core/model_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 254225ef..231a6cb7 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -121,7 +121,7 @@ def claim_runner(self, dit_id: Optional[int], vae_id: Optional[int], if current_dit != dit_model or current_vae != vae_model: return template, "mismatch" - setattr(template, '_seedvr2_execution_active', True) + template._seedvr2_execution_active = True return template, "claimed" def set_dit(self, dit_config: Dict[str, Any], model: Any, model_name: str, debug: Optional['Debug'] = None) -> Optional[str]: From 47600179ef8548a49c1ffba9d2afcc260fbb4490 Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 13:46:33 +0100 Subject: [PATCH 07/29] Address runner cache eviction bug --- src/core/model_configuration.py | 52 +++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 3a86cc0c..b9bfa8a1 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -846,26 +846,48 @@ def configure_runner( ) # Phase 2: Get or create runner + runner = None runner = _acquire_runner( cache_context, dit_model, vae_model, base_cache_dir, debug ) - # Phase 3: Configure runner settings - _configure_runner_settings( - runner, ctx, - encode_tiled, encode_tile_size, encode_tile_overlap, - decode_tiled, decode_tile_size, decode_tile_overlap, - tile_debug, attention_mode, - torch_compile_args_dit, torch_compile_args_vae, - block_swap_config, debug - ) - - # Phase 4: Setup models (load from cache or create new) - _setup_models( - runner, cache_context, dit_model, vae_model, - base_cache_dir, block_swap_config, debug - ) + try: + # Phase 3: Configure runner settings + _configure_runner_settings( + runner, ctx, + encode_tiled, encode_tile_size, encode_tile_overlap, + decode_tiled, decode_tile_size, decode_tile_overlap, + tile_debug, attention_mode, + torch_compile_args_dit, torch_compile_args_vae, + block_swap_config, debug + ) + + # Phase 4: Setup models (load from cache or create new) + _setup_models( + runner, cache_context, dit_model, vae_model, + base_cache_dir, block_swap_config, debug + ) + except BaseException: + if runner is not None and cache_context.get('reusing_runner', False): + runner._seedvr2_runner_tainted = True + runner._seedvr2_execution_active = False + try: + cache_context['global_cache'].remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + expected_runner=runner, + ) + except Exception as cache_error: + if debug is not None: + debug.log( + f"Failed to evict claimed runner after setup failure: {cache_error}", + level="WARNING", + category="cleanup", + force=True, + ) + raise return runner, cache_context From f0376ea5ab6b146fb947fe1073a4feb238d2eaec Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 18:00:08 +0100 Subject: [PATCH 08/29] Keep model claim until cache rewrite --- inference_cli.py | 278 +++++++++++++++++---------- src/core/generation_phases.py | 21 +- src/core/model_cache.py | 296 ++++++++++++++++++++++++----- src/core/model_configuration.py | 131 +++++++++++-- src/interfaces/video_upscaler.py | 11 +- src/optimization/memory_manager.py | 179 +++++++++++++++-- 6 files changed, 739 insertions(+), 177 deletions(-) diff --git a/inference_cli.py b/inference_cli.py index 2d4fff18..0d9a804a 100644 --- a/inference_cli.py +++ b/inference_cli.py @@ -130,8 +130,16 @@ decode_all_batches, postprocess_all_batches ) +from src.core.model_configuration import _evict_claimed_cached_models from src.utils.debug import Debug -from src.optimization.memory_manager import clear_memory, get_gpu_backend, is_cuda_available +from src.optimization.memory_manager import ( + cleanup_text_embeddings, + clear_memory, + complete_cleanup, + get_gpu_backend, + is_cuda_available, + set_model_cache_claimed_state, +) debug = Debug(enabled=False) # Will be enabled via --debug CLI flag @@ -913,103 +921,175 @@ def _process_frames_core( dit_id = "cli_dit" if cache_dit else None vae_id = "cli_vae" if cache_vae else None - runner, cache_context = prepare_runner( - dit_model=args.dit_model, - vae_model=DEFAULT_VAE, - model_dir=model_dir, - debug=debug, - ctx=ctx, - dit_cache=cache_dit, - vae_cache=cache_vae, - dit_id=dit_id, - vae_id=vae_id, - block_swap_config={ - 'blocks_to_swap': args.blocks_to_swap, - 'swap_io_components': args.swap_io_components, - 'offload_device': dit_offload, - }, - encode_tiled=args.vae_encode_tiled, - encode_tile_size=(args.vae_encode_tile_size, args.vae_encode_tile_size), - encode_tile_overlap=(args.vae_encode_tile_overlap, args.vae_encode_tile_overlap), - decode_tiled=args.vae_decode_tiled, - decode_tile_size=(args.vae_decode_tile_size, args.vae_decode_tile_size), - decode_tile_overlap=(args.vae_decode_tile_overlap, args.vae_decode_tile_overlap), - tile_debug=args.tile_debug.lower() if args.tile_debug else "false", - attention_mode=args.attention_mode, - torch_compile_args_dit=torch_compile_args_dit, - torch_compile_args_vae=torch_compile_args_vae - ) - - ctx['cache_context'] = cache_context - if runner_cache is not None: - runner_cache['runner'] = runner - - # Preload text embeddings before Phase 1 to avoid sync stall in Phase 2 - ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug) - debug.log("Loaded text embeddings for DiT", category="dit") - - # Compute generation info and log start (handles prepending internally) - frames_tensor, gen_info = compute_generation_info( - ctx=ctx, - images=frames_tensor, - resolution=args.resolution, - max_resolution=args.max_resolution, - batch_size=args.batch_size, - uniform_batch_size=args.uniform_batch_size, - seed=args.seed, - prepend_frames=args.prepend_frames, - temporal_overlap=args.temporal_overlap, - debug=debug - ) - log_generation_start(gen_info, debug) - - # Phase 1: Encode - ctx = encode_all_batches( - runner, ctx=ctx, images=frames_tensor, - debug=debug, - batch_size=args.batch_size, - uniform_batch_size=args.uniform_batch_size, - seed=args.seed, - progress_callback=None, - temporal_overlap=args.temporal_overlap, - resolution=args.resolution, - max_resolution=args.max_resolution, - input_noise_scale=args.input_noise_scale, - color_correction=args.color_correction - ) - - # Phase 2: Upscale - ctx = upscale_all_batches( - runner, ctx=ctx, debug=debug, progress_callback=None, - seed=args.seed, - latent_noise_scale=args.latent_noise_scale, - cache_model=cache_dit - ) - - # Phase 3: Decode - ctx = decode_all_batches( - runner, ctx=ctx, debug=debug, progress_callback=None, - cache_model=cache_vae - ) - - # Phase 4: Post-process - ctx = postprocess_all_batches( - ctx=ctx, debug=debug, progress_callback=None, - color_correction=args.color_correction, - prepend_frames=0, # Worker mode handles this in main process - temporal_overlap=args.temporal_overlap, - batch_size=args.batch_size - ) - - result_tensor = ctx['final_video'] - - # Convert to CPU and compatible dtype - if result_tensor.is_cuda or result_tensor.is_mps: - result_tensor = result_tensor.cpu() - if result_tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): - result_tensor = result_tensor.to(torch.float32) - - return result_tensor + runner = None + cache_context = None + + def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: + nonlocal runner, ctx + + if runner is not None: + try: + complete_cleanup( + runner=runner, + debug=debug, + dit_cache=dit_cache_flag, + vae_cache=vae_cache_flag, + ) + if dit_cache_flag and getattr(runner, 'dit', None) is not None: + set_model_cache_claimed_state(runner.dit, False) + if vae_cache_flag and getattr(runner, 'vae', None) is not None: + set_model_cache_claimed_state(runner.vae, False) + finally: + runner._seedvr2_execution_active = False + + if not (dit_cache_flag or vae_cache_flag): + runner = None + if runner_cache is not None: + runner_cache.pop('runner', None) + + if ctx is not None: + cleanup_text_embeddings(ctx, debug) + if not (dit_cache_flag or vae_cache_flag): + ctx = None + if runner_cache is not None: + runner_cache.pop('ctx', None) + + try: + runner, cache_context = prepare_runner( + dit_model=args.dit_model, + vae_model=DEFAULT_VAE, + model_dir=model_dir, + debug=debug, + ctx=ctx, + dit_cache=cache_dit, + vae_cache=cache_vae, + dit_id=dit_id, + vae_id=vae_id, + block_swap_config={ + 'blocks_to_swap': args.blocks_to_swap, + 'swap_io_components': args.swap_io_components, + 'offload_device': dit_offload, + }, + encode_tiled=args.vae_encode_tiled, + encode_tile_size=(args.vae_encode_tile_size, args.vae_encode_tile_size), + encode_tile_overlap=(args.vae_encode_tile_overlap, args.vae_encode_tile_overlap), + decode_tiled=args.vae_decode_tiled, + decode_tile_size=(args.vae_decode_tile_size, args.vae_decode_tile_size), + decode_tile_overlap=(args.vae_decode_tile_overlap, args.vae_decode_tile_overlap), + tile_debug=args.tile_debug.lower() if args.tile_debug else "false", + attention_mode=args.attention_mode, + torch_compile_args_dit=torch_compile_args_dit, + torch_compile_args_vae=torch_compile_args_vae + ) + + runner._seedvr2_execution_active = True + runner._seedvr2_runner_tainted = False + + if ( + cache_context is not None + and not cache_context.get('reusing_runner', False) + and cache_context.get('cached_dit') is not None + and cache_context.get('cached_vae') is not None + ): + cache_context['global_cache'].set_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + runner, + debug, + ) + + ctx['cache_context'] = cache_context + if runner_cache is not None: + runner_cache['runner'] = runner + + # Preload text embeddings before Phase 1 to avoid sync stall in Phase 2 + ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug) + debug.log("Loaded text embeddings for DiT", category="dit") + + # Compute generation info and log start (handles prepending internally) + frames_tensor, gen_info = compute_generation_info( + ctx=ctx, + images=frames_tensor, + resolution=args.resolution, + max_resolution=args.max_resolution, + batch_size=args.batch_size, + uniform_batch_size=args.uniform_batch_size, + seed=args.seed, + prepend_frames=args.prepend_frames, + temporal_overlap=args.temporal_overlap, + debug=debug + ) + log_generation_start(gen_info, debug) + + # Phase 1: Encode + ctx = encode_all_batches( + runner, ctx=ctx, images=frames_tensor, + debug=debug, + batch_size=args.batch_size, + uniform_batch_size=args.uniform_batch_size, + seed=args.seed, + progress_callback=None, + temporal_overlap=args.temporal_overlap, + resolution=args.resolution, + max_resolution=args.max_resolution, + input_noise_scale=args.input_noise_scale, + color_correction=args.color_correction + ) + + # Phase 2: Upscale + ctx = upscale_all_batches( + runner, ctx=ctx, debug=debug, progress_callback=None, + seed=args.seed, + latent_noise_scale=args.latent_noise_scale, + cache_model=cache_dit + ) + + # Phase 3: Decode + ctx = decode_all_batches( + runner, ctx=ctx, debug=debug, progress_callback=None, + cache_model=cache_vae + ) + + # Phase 4: Post-process + ctx = postprocess_all_batches( + ctx=ctx, debug=debug, progress_callback=None, + color_correction=args.color_correction, + prepend_frames=0, # Worker mode handles this in main process + temporal_overlap=args.temporal_overlap, + batch_size=args.batch_size + ) + + result_tensor = ctx['final_video'] + + # Convert to CPU and compatible dtype + if result_tensor.is_cuda or result_tensor.is_mps: + result_tensor = result_tensor.cpu() + if result_tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): + result_tensor = result_tensor.to(torch.float32) + + cleanup(dit_cache_flag=cache_dit, vae_cache_flag=cache_vae) + return result_tensor + except BaseException: + if runner is not None: + runner._seedvr2_runner_tainted = True + + if cache_context is not None: + _evict_claimed_cached_models(cache_context, runner, debug) + try: + cache_context['global_cache'].remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + expected_runner=runner, + ) + except BaseException: + pass + + try: + cleanup(dit_cache_flag=False, vae_cache_flag=False) + except BaseException: + pass + raise def _worker_process( @@ -1709,4 +1789,4 @@ def main() -> None: debug.print_footer() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index 3b7e6ea0..eb490443 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -47,6 +47,7 @@ cleanup_dit, cleanup_vae, cleanup_text_embeddings, + is_model_cache_cold, manage_tensor, manage_model_device, release_tensor_memory, @@ -294,12 +295,20 @@ def encode_all_batches( encode_idx = 0 try: + vae_needs_reactivation = runner.vae is not None and is_model_cache_cold(runner.vae) + # Materialize VAE if still on meta device if runner.vae and next(runner.vae.parameters()).device.type == 'meta': materialize_model(runner, "vae", ctx['vae_device'], runner.config, debug) else: + # Cold cached models keep weights/config, but execution state is rebuilt each run. + if vae_needs_reactivation: + debug.log("Rebuilding VAE execution state from cold cache", category="vae", force=True) + manage_model_device(model=runner.vae, target_device=ctx['vae_device'], + model_name="VAE", debug=debug, reason="cold-cache activation", runner=runner) + apply_model_specific_config(runner.vae, runner, runner.config, False, debug) # Model already materialized (cached) - apply any pending configs if needed - if getattr(runner, '_vae_config_needs_application', False): + elif getattr(runner, '_vae_config_needs_application', False): debug.log("Applying updated VAE configuration", category="vae", force=True) apply_model_specific_config(runner.vae, runner, runner.config, False, debug) @@ -616,12 +625,20 @@ def upscale_all_batches( upscale_idx = 0 try: + dit_needs_reactivation = runner.dit is not None and is_model_cache_cold(runner.dit) + # Materialize DiT if still on meta device if runner.dit and next(runner.dit.parameters()).device.type == 'meta': materialize_model(runner, "dit", ctx['dit_device'], runner.config, debug) else: + # Cold cached models keep weights/config, but execution state is rebuilt each run. + if dit_needs_reactivation: + debug.log("Rebuilding DiT execution state from cold cache", category="dit", force=True) + manage_model_device(model=runner.dit, target_device=ctx['dit_device'], + model_name="DiT", debug=debug, reason="cold-cache activation", runner=runner) + apply_model_specific_config(runner.dit, runner, runner.config, True, debug) # Model already materialized (cached) - apply any pending configs if needed - if getattr(runner, '_dit_config_needs_application', False): + elif getattr(runner, '_dit_config_needs_application', False): debug.log("Applying updated DiT configuration", category="dit", force=True) apply_model_specific_config(runner.dit, runner, runner.config, True, debug) diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 231a6cb7..21563ec1 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -5,7 +5,14 @@ import threading from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING -from ..optimization.memory_manager import release_model_memory +from ..optimization.memory_manager import ( + is_model_cache_claimed, + is_model_cache_cold, + iter_model_wrapper_chain, + release_model_memory, + set_model_cache_claimed_state, + set_model_cache_cold_state, +) if TYPE_CHECKING: from ..utils.debug import Debug @@ -25,8 +32,19 @@ def __init__(self): self._vae_models: Dict[str, Tuple[Any, Dict]] = {} # Storage for runner templates: "dit_id+vae_id" -> runner self._runner_templates: Dict[str, Any] = {} + # Synchronizes DiT/VAE model cache claim/set/replace/remove operations + self._model_cache_lock = threading.RLock() # Synchronizes runner-template claim/set/remove operations self._runner_templates_lock = threading.RLock() + + def _models_share_identity(self, cached_model: Any, expected_model: Any) -> bool: + """Return True when two model references point into the same wrapper/base chain.""" + if cached_model is None or expected_model is None: + return False + + cached_ids = {id(model) for model in iter_model_wrapper_chain(cached_model)} + expected_ids = {id(model) for model in iter_model_wrapper_chain(expected_model)} + return bool(cached_ids & expected_ids) def get_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) -> Optional[Any]: """ @@ -43,10 +61,38 @@ def get_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) - return None node_id = dit_config.get('node_id') - if node_id in self._dit_models: - model, stored_config = self._dit_models[node_id] - return model + with self._model_cache_lock: + if node_id in self._dit_models: + model, stored_config = self._dit_models[node_id] + if not is_model_cache_cold(model): + if debug: + debug.log( + f"Cached DiT is still hot or incomplete; waiting for cold cache state before reuse (node {node_id})", + category="cache", + force=True, + ) + return None + if is_model_cache_claimed(model): + if debug: + debug.log( + f"Cached DiT is already claimed by another execution; skipping reuse (node {node_id})", + category="cache", + force=True, + ) + return None + set_model_cache_claimed_state(model, True) + return model return None + + def peek_dit(self, dit_config: Dict[str, Any]) -> Optional[Any]: + """Return the cached DiT model without claiming it.""" + node_id = dit_config.get('node_id') + if node_id is None: + return None + + with self._model_cache_lock: + entry = self._dit_models.get(node_id) + return None if entry is None else entry[0] def get_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None) -> Optional[Any]: """ @@ -63,10 +109,38 @@ def get_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None) - return None node_id = vae_config.get('node_id') - if node_id in self._vae_models: - model, stored_config = self._vae_models[node_id] - return model + with self._model_cache_lock: + if node_id in self._vae_models: + model, stored_config = self._vae_models[node_id] + if not is_model_cache_cold(model): + if debug: + debug.log( + f"Cached VAE is still hot or incomplete; waiting for cold cache state before reuse (node {node_id})", + category="cache", + force=True, + ) + return None + if is_model_cache_claimed(model): + if debug: + debug.log( + f"Cached VAE is already claimed by another execution; skipping reuse (node {node_id})", + category="cache", + force=True, + ) + return None + set_model_cache_claimed_state(model, True) + return model return None + + def peek_vae(self, vae_config: Dict[str, Any]) -> Optional[Any]: + """Return the cached VAE model without claiming it.""" + node_id = vae_config.get('node_id') + if node_id is None: + return None + + with self._model_cache_lock: + entry = self._vae_models.get(node_id) + return None if entry is None else entry[0] def get_runner(self, dit_id: Optional[int], vae_id: Optional[int], debug: Optional['Debug'] = None) -> Optional[Any]: @@ -141,7 +215,22 @@ def set_dit(self, dit_config: Dict[str, Any], model: Any, model_name: str, debug return None node_id = dit_config.get('node_id') - self._dit_models[node_id] = (model, dit_config) + with self._model_cache_lock: + existing = self._dit_models.get(node_id) + if existing is not None: + existing_model, _ = existing + if not self._models_share_identity(existing_model, model) and is_model_cache_claimed(existing_model): + if debug: + debug.log( + f"Skipped caching DiT model for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return None + set_model_cache_cold_state(model, False) + set_model_cache_claimed_state(model, True) + self._dit_models[node_id] = (model, dit_config) if debug: debug.log(f"DiT model cached in memory (node {node_id}): {model_name}", @@ -166,13 +255,86 @@ def set_vae(self, vae_config: Dict[str, Any], model: Any, model_name: str, debug return None node_id = vae_config.get('node_id') - self._vae_models[node_id] = (model, vae_config) + with self._model_cache_lock: + existing = self._vae_models.get(node_id) + if existing is not None: + existing_model, _ = existing + if not self._models_share_identity(existing_model, model) and is_model_cache_claimed(existing_model): + if debug: + debug.log( + f"Skipped caching VAE model for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return None + set_model_cache_cold_state(model, False) + set_model_cache_claimed_state(model, True) + self._vae_models[node_id] = (model, vae_config) if debug: debug.log(f"VAE model cached in memory (node {node_id}): {model_name}", category="cache", force=True) return node_id + + def replace_dit( + self, + dit_config: Dict[str, Any], + model: Any, + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: + """Rewrite a cached DiT entry to a normalized canonical model.""" + node_id = dit_config.get('node_id') + with self._model_cache_lock: + if node_id not in self._dit_models: + return False + + cached_model, stored_config = self._dit_models[node_id] + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached DiT rewrite for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + + self._dit_models[node_id] = (model, stored_config) + if debug: + debug.log(f"Rewrote cached DiT entry to cold canonical model (node {node_id})", category="cache", force=True) + return True + + def replace_vae( + self, + vae_config: Dict[str, Any], + model: Any, + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: + """Rewrite a cached VAE entry to a normalized canonical model.""" + node_id = vae_config.get('node_id') + with self._model_cache_lock: + if node_id not in self._vae_models: + return False + + cached_model, stored_config = self._vae_models[node_id] + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached VAE rewrite for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + + self._vae_models[node_id] = (model, stored_config) + if debug: + debug.log(f"Rewrote cached VAE entry to cold canonical model (node {node_id})", category="cache", force=True) + return True def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], runner: Any, debug: Optional['Debug'] = None) -> Optional[str]: @@ -244,7 +406,12 @@ def remove_runner(self, dit_id: Optional[int], vae_id: Optional[int], debug.log(f"Removed cached runner template: nodes {runner_key}", category="cache", force=True) return True - def remove_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) -> bool: + def remove_dit( + self, + dit_config: Dict[str, Any], + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: """ Remove DiT model from cache if it exists. @@ -259,28 +426,52 @@ def remove_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None Also removes any runner templates that used this DiT model """ node_id = dit_config.get('node_id') - if node_id in self._dit_models: + with self._model_cache_lock: + if node_id not in self._dit_models: + return False + + cached_model, stored_config = self._dit_models[node_id] + if expected_model is None and is_model_cache_claimed(cached_model): + if debug: + debug.log( + f"Skipped cached DiT removal for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return False + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached DiT removal for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + if debug: debug.log(f"Removing cached DiT: {node_id}", category="cache", force=True) - model, stored_config = self._dit_models[node_id] - - # Release model memory - if model is not None: - release_model_memory(model=model, debug=debug) - + model = cached_model del self._dit_models[node_id] - - # Remove any runner templates that used this DiT - with self._runner_templates_lock: - templates_to_remove = [k for k in self._runner_templates.keys() if k.startswith(str(node_id) + "+")] - for template_key in templates_to_remove: - del self._runner_templates[template_key] - - return True - return False + + if model is not None: + release_model_memory(model=model, debug=debug) + + with self._runner_templates_lock: + templates_to_remove = [k for k in self._runner_templates.keys() if k.startswith(str(node_id) + "+")] + for template_key in templates_to_remove: + del self._runner_templates[template_key] + + return True - def remove_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None) -> bool: + def remove_vae( + self, + vae_config: Dict[str, Any], + debug: Optional['Debug'] = None, + expected_model: Optional[Any] = None, + ) -> bool: """ Remove VAE model from cache if it exists. @@ -295,26 +486,45 @@ def remove_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None Also removes any runner templates that used this VAE model """ node_id = vae_config.get('node_id') - if node_id in self._vae_models: + with self._model_cache_lock: + if node_id not in self._vae_models: + return False + + cached_model, stored_config = self._vae_models[node_id] + if expected_model is None and is_model_cache_claimed(cached_model): + if debug: + debug.log( + f"Skipped cached VAE removal for node {node_id}: cache entry is currently claimed by another execution", + level="WARNING", + category="cache", + force=True, + ) + return False + if expected_model is not None and not self._models_share_identity(cached_model, expected_model): + if debug: + debug.log( + f"Skipped cached VAE removal for node {node_id}: cache entry no longer matches the claimed model", + level="WARNING", + category="cache", + force=True, + ) + return False + if debug: debug.log(f"Removing cached VAE: {node_id}", category="cache", force=True) - model, stored_config = self._vae_models[node_id] - - # Release model memory directly - if model is not None: - release_model_memory(model=model, debug=debug) - + model = cached_model del self._vae_models[node_id] - - # Remove any runner templates that used this VAE - with self._runner_templates_lock: - templates_to_remove = [k for k in self._runner_templates.keys() if k.endswith("+" + str(node_id))] - for template_key in templates_to_remove: - del self._runner_templates[template_key] - - return True - return False + + if model is not None: + release_model_memory(model=model, debug=debug) + + with self._runner_templates_lock: + templates_to_remove = [k for k in self._runner_templates.keys() if k.endswith("+" + str(node_id))] + for template_key in templates_to_remove: + del self._runner_templates[template_key] + + return True # Global singleton instance diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index b9bfa8a1..2e683d02 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -75,7 +75,13 @@ validate_attention_mode ) from ..optimization.blockswap import is_blockswap_enabled, validate_blockswap_config, apply_block_swap_to_dit, cleanup_blockswap -from ..optimization.memory_manager import cleanup_dit, cleanup_vae +from ..optimization.memory_manager import ( + cleanup_dit, + cleanup_vae, + is_model_cache_claimed, + set_model_cache_claimed_state, + set_model_cache_cold_state, +) from ..utils.constants import find_model_file @@ -320,7 +326,7 @@ def _update_model_config( elif config_name == 'attention_mode': display_name = 'Attention Mode' - config_changes.append(f"{display_name}: {old_desc} → {new_desc}") + config_changes.append(f"{display_name}: {old_desc} -> {new_desc}") # If nothing changed, reuse model as-is if not any(changes_detected.values()): @@ -592,41 +598,89 @@ def _initialize_cache_context( # Check for cached DiT model with model name validation # Model name validation prevents stale cache when user switches models in UI if dit_cache and dit_model and dit_id is not None: - cached_model = global_cache.get_dit({'node_id': dit_id, 'cache_model': True}, debug) - if cached_model: + cached_model = global_cache.peek_dit({'node_id': dit_id}) + if cached_model is not None: + cached_claimed = is_model_cache_claimed(cached_model) # Verify cached model matches requested model by checking _model_name attribute cached_model_name = getattr(cached_model, '_model_name', None) if cached_model_name == dit_model: # Cache hit with valid model - reuse it - context['cached_dit'] = cached_model + claimed_model = global_cache.get_dit({'node_id': dit_id, 'cache_model': True}, debug) + if claimed_model is not None: + claimed_model_name = getattr(claimed_model, '_model_name', None) + if claimed_model_name == dit_model: + context['cached_dit'] = claimed_model + else: + if claimed_model_name: + debug.log( + f"Claimed DiT no longer matches requested model ({claimed_model_name} -> {dit_model}), " + f"evicting claimed cache entry", + category="cache", + force=True, + ) + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_model) else: # Model changed - remove stale cache and log the change if cached_model_name: - debug.log(f"DiT model changed in cache ({cached_model_name} → {dit_model}), " + debug.log(f"DiT model changed in cache ({cached_model_name} -> {dit_model}), " f"removing stale cached model", category="cache", force=True) - global_cache.remove_dit({'node_id': dit_id}, debug) + if cached_claimed: + debug.log( + f"Cached DiT for node {dit_id} is stale but currently claimed; leaving it in cache until the owning execution releases it", + level="WARNING", + category="cache", + force=True, + ) + else: + global_cache.remove_dit({'node_id': dit_id}, debug) else: # Caching disabled or no ID - clean up any existing cache for this node if dit_id is not None: - global_cache.remove_dit({'node_id': dit_id}, debug) + cached_model = global_cache.peek_dit({'node_id': dit_id}) + if cached_model is not None and not is_model_cache_claimed(cached_model): + global_cache.remove_dit({'node_id': dit_id}, debug) # Check for cached VAE model with model name validation if vae_cache and vae_model and vae_id is not None: - cached_model = global_cache.get_vae({'node_id': vae_id, 'cache_model': True}, debug) - if cached_model: + cached_model = global_cache.peek_vae({'node_id': vae_id}) + if cached_model is not None: + cached_claimed = is_model_cache_claimed(cached_model) # Verify cached model matches requested model by checking _model_name attribute cached_model_name = getattr(cached_model, '_model_name', None) if cached_model_name == vae_model: - context['cached_vae'] = cached_model + claimed_model = global_cache.get_vae({'node_id': vae_id, 'cache_model': True}, debug) + if claimed_model is not None: + claimed_model_name = getattr(claimed_model, '_model_name', None) + if claimed_model_name == vae_model: + context['cached_vae'] = claimed_model + else: + if claimed_model_name: + debug.log( + f"Claimed VAE no longer matches requested model ({claimed_model_name} -> {vae_model}), " + f"evicting claimed cache entry", + category="cache", + force=True, + ) + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_model) else: # Model changed - remove stale cache and log the change if cached_model_name: - debug.log(f"VAE model changed in cache ({cached_model_name} → {vae_model}), " + debug.log(f"VAE model changed in cache ({cached_model_name} -> {vae_model}), " f"removing stale cached model", category="cache", force=True) - global_cache.remove_vae({'node_id': vae_id}, debug) + if cached_claimed: + debug.log( + f"Cached VAE for node {vae_id} is stale but currently claimed; leaving it in cache until the owning execution releases it", + level="WARNING", + category="cache", + force=True, + ) + else: + global_cache.remove_vae({'node_id': vae_id}, debug) else: if vae_id is not None: - global_cache.remove_vae({'node_id': vae_id}, debug) + cached_model = global_cache.peek_vae({'node_id': vae_id}) + if cached_model is not None and not is_model_cache_claimed(cached_model): + global_cache.remove_vae({'node_id': vae_id}, debug) return context @@ -734,8 +788,8 @@ def _create_new_runner( Args: dit_model: DiT model filename (determines config selection) - - Contains "7b" → loads configs_7b/main.yaml - - Otherwise → loads configs_3b/main.yaml + - Contains "7b" -> loads configs_7b/main.yaml + - Otherwise -> loads configs_3b/main.yaml vae_model: VAE model filename (stored for reference, not used in config selection) base_cache_dir: Base directory for model files (not used directly but passed for context) debug: Debug instance for logging and timing @@ -856,6 +910,8 @@ def configure_runner( # Phase 3: Configure runner settings _configure_runner_settings( runner, ctx, + cache_context.get('dit_id') if dit_cache else None, + cache_context.get('vae_id') if vae_cache else None, encode_tiled, encode_tile_size, encode_tile_overlap, decode_tiled, decode_tile_size, decode_tile_overlap, tile_debug, attention_mode, @@ -869,6 +925,7 @@ def configure_runner( base_cache_dir, block_swap_config, debug ) except BaseException: + _evict_claimed_cached_models(cache_context, runner, debug) if runner is not None and cache_context.get('reusing_runner', False): runner._seedvr2_runner_tainted = True runner._seedvr2_execution_active = False @@ -892,9 +949,41 @@ def configure_runner( return runner, cache_context +def _evict_claimed_cached_models( + cache_context: Dict[str, Any], + runner: Optional[VideoDiffusionInfer], + debug: Optional['Debug'] = None, +) -> None: + """ + Evict claimed cached models after activation/setup failure. + + Claimed cached models may be partially materialized or partially reconfigured + when an exception interrupts setup. In that case they must be removed from the + global cache rather than merely unclaimed. + """ + if not cache_context: + return + + global_cache = cache_context.get('global_cache') + if global_cache is None: + return + + dit_id = cache_context.get('dit_id') + if cache_context.get('dit_cache') and dit_id is not None: + expected_dit = (getattr(runner, 'dit', None) if runner is not None else None) or cache_context.get('cached_dit') + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=expected_dit) + + vae_id = cache_context.get('vae_id') + if cache_context.get('vae_cache') and vae_id is not None: + expected_vae = (getattr(runner, 'vae', None) if runner is not None else None) or cache_context.get('cached_vae') + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=expected_vae) + + def _configure_runner_settings( runner: VideoDiffusionInfer, ctx: Dict[str, Any], + dit_cache_node_id: Optional[int], + vae_cache_node_id: Optional[int], encode_tiled: bool, encode_tile_size: Optional[Tuple[int, int]], encode_tile_overlap: Optional[Tuple[int, int]], @@ -963,6 +1052,8 @@ def _configure_runner_settings( runner._vae_offload_device = ctx['vae_offload_device'] runner._tensor_offload_device = ctx['tensor_offload_device'] runner._compute_dtype = ctx['compute_dtype'] + runner._dit_cache_node_id = dit_cache_node_id + runner._vae_cache_node_id = vae_cache_node_id runner.debug = debug @@ -1079,7 +1170,7 @@ def _setup_dit_model( current_dit_name = getattr(runner, '_dit_model_name', None) if current_dit_name and current_dit_name != dit_model: if hasattr(runner, 'dit') and runner.dit is not None: - debug.log(f"DiT model changed ({current_dit_name} → {dit_model}), cleaning old model", + debug.log(f"DiT model changed ({current_dit_name} -> {dit_model}), cleaning old model", category="cache", force=True) cleanup_dit(runner=runner, debug=debug, cache_model=False) @@ -1152,7 +1243,7 @@ def _setup_vae_model( current_vae_name = getattr(runner, '_vae_model_name', None) if current_vae_name and current_vae_name != vae_model: if hasattr(runner, 'vae') and runner.vae is not None: - debug.log(f"VAE model changed ({current_vae_name} → {vae_model}), cleaning old model", + debug.log(f"VAE model changed ({current_vae_name} -> {vae_model}), cleaning old model", category="cache", force=True) cleanup_vae(runner=runner, debug=debug, cache_model=False) @@ -1336,7 +1427,9 @@ def apply_model_specific_config(model: torch.nn.Module, runner: VideoDiffusionIn # Clear the config application flag after successful application if hasattr(runner, '_vae_config_needs_application'): runner._vae_config_needs_application = False - + + set_model_cache_cold_state(model, False) + set_model_cache_claimed_state(model, True) return model diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 5685a622..6c93a256 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -23,10 +23,12 @@ load_text_embeddings, script_directory ) +from ..core.model_configuration import _evict_claimed_cached_models from ..optimization.memory_manager import ( cleanup_text_embeddings, complete_cleanup, - get_device_list + get_device_list, + set_model_cache_claimed_state, ) # Import ComfyUI progress reporting @@ -327,6 +329,10 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: dit_cache=dit_cache, vae_cache=vae_cache, ) + if dit_cache and getattr(runner, 'dit', None) is not None: + set_model_cache_claimed_state(runner.dit, False) + if vae_cache and getattr(runner, 'vae', None) is not None: + set_model_cache_claimed_state(runner.vae, False) finally: runner._seedvr2_execution_active = False @@ -609,6 +615,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: runner._seedvr2_runner_tainted = True if cache_context is not None: + _evict_claimed_cached_models(cache_context, runner, debug) try: cache_context['global_cache'].remove_runner( cache_context.get('dit_id'), @@ -626,7 +633,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: ) try: - cleanup(dit_cache=dit_cache, vae_cache=vae_cache) + cleanup(dit_cache=False, vae_cache=False) except BaseException as cleanup_error: if debug is not None: debug.log( diff --git a/src/optimization/memory_manager.py b/src/optimization/memory_manager.py index 2b9c60d6..c0efd00f 100644 --- a/src/optimization/memory_manager.py +++ b/src/optimization/memory_manager.py @@ -725,6 +725,119 @@ def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debu debug.log(f"Failed to release model memory: {e}", level="WARNING", category="memory", force=True) +def iter_model_wrapper_chain(model: Optional[torch.nn.Module]): + """Yield a model plus any known wrappers/base modules reachable through unwrap attributes.""" + if model is None: + return + + stack = [model] + seen = set() + + while stack: + current = stack.pop() + if current is None: + continue + + obj_id = id(current) + if obj_id in seen: + continue + seen.add(obj_id) + yield current + + for attr in ('_orig_mod', 'dit_model'): + child = getattr(current, attr, None) + if child is not None: + stack.append(child) + + +def set_model_cache_cold_state(model: Optional[torch.nn.Module], is_cold: bool) -> None: + """Mark a cached model and all known wrappers/base objects as cold or hot.""" + for wrapped_model in iter_model_wrapper_chain(model): + setattr(wrapped_model, '_seedvr2_cold_cache', is_cold) + + +def set_model_cache_claimed_state(model: Optional[torch.nn.Module], is_claimed: bool) -> None: + """Mark a cached model and all known wrappers/base objects as claimed or free.""" + for wrapped_model in iter_model_wrapper_chain(model): + setattr(wrapped_model, '_seedvr2_cache_claimed', is_claimed) + + +def is_model_cache_cold(model: Optional[torch.nn.Module]) -> bool: + """Return True when a cached model is in its cold reusable canonical form.""" + return any(getattr(wrapped_model, '_seedvr2_cold_cache', False) for wrapped_model in iter_model_wrapper_chain(model)) + + +def is_model_cache_claimed(model: Optional[torch.nn.Module]) -> bool: + """Return True when a cached model is already leased to an in-flight execution.""" + return any(getattr(wrapped_model, '_seedvr2_cache_claimed', False) for wrapped_model in iter_model_wrapper_chain(model)) + + +def _copy_model_cache_metadata(source: Any, target: Any, attrs: Tuple[str, ...]) -> None: + """Preserve cache metadata when normalizing wrapped models back to their base form.""" + for attr in attrs: + if hasattr(source, attr): + setattr(target, attr, getattr(source, attr)) + + +def _normalize_cached_dit_model(model: torch.nn.Module, debug: Optional['Debug'] = None) -> torch.nn.Module: + """Return a cold canonical DiT model with compile/wrapper state removed.""" + while True: + changed = False + + if hasattr(model, '_orig_mod'): + if debug: + debug.log("Removing torch.compile wrapper from DiT for cold cache storage", category="cleanup") + base_model = model._orig_mod + _copy_model_cache_metadata( + model, + base_model, + ('_model_name', '_config_compile', '_config_swap', '_config_attn'), + ) + model = base_model + changed = True + + if hasattr(model, 'dit_model'): + if debug: + debug.log("Removing DiT compatibility wrapper for cold cache storage", category="cleanup") + base_model = model.dit_model + _copy_model_cache_metadata( + model, + base_model, + ('_model_name', '_config_compile', '_config_swap', '_config_attn'), + ) + model = base_model + changed = True + + if not changed: + break + + release_model_memory(model=model, debug=debug) + set_model_cache_cold_state(model, True) + return model + + +def _normalize_cached_vae_model(model: torch.nn.Module, debug: Optional['Debug'] = None) -> torch.nn.Module: + """Return a cold canonical VAE model with compiled submodules removed.""" + if hasattr(model, 'encoder') and hasattr(model.encoder, '_orig_mod'): + if debug: + debug.log("Removing torch.compile wrapper from VAE encoder for cold cache storage", category="cleanup") + model.encoder = model.encoder._orig_mod + + if hasattr(model, 'decoder') and hasattr(model.decoder, '_orig_mod'): + if debug: + debug.log("Removing torch.compile wrapper from VAE decoder for cold cache storage", category="cleanup") + model.decoder = model.decoder._orig_mod + + if hasattr(model, 'debug'): + model.debug = None + if hasattr(model, 'tensor_offload_device'): + model.tensor_offload_device = None + + release_model_memory(model=model, debug=debug) + set_model_cache_cold_state(model, True) + return model + + def manage_tensor( tensor: torch.Tensor, target_device: torch.device, @@ -838,17 +951,21 @@ def manage_model_device(model: torch.nn.Module, target_device: torch.device, mod if runner and model_name == "DiT": # Import here to avoid circular dependency from .blockswap import is_blockswap_enabled + actual_model = getattr(model, "dit_model", model) # Check if BlockSwap config exists and is enabled has_blockswap_config = ( hasattr(runner, '_dit_block_swap_config') and is_blockswap_enabled(runner._dit_block_swap_config) ) + has_blockswap_runtime_state = ( + getattr(runner, '_blockswap_active', False) or + hasattr(actual_model, '_block_swap_config') or + hasattr(actual_model, '_original_to') or + getattr(actual_model, '_blockswap_bypass_protection', False) + ) - if has_blockswap_config: + if has_blockswap_config and has_blockswap_runtime_state: is_blockswap_model = True - # Get the actual model (handle CompatibleDiT wrapper) - if hasattr(model, "dit_model"): - actual_model = model.dit_model # Get current device try: @@ -1211,10 +1328,14 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup") else: - offload_target = getattr(runner, '_dit_offload_device', None) - if offload_target is None or offload_target == 'none': + if cache_model: offload_target = torch.device('cpu') - reason = "model caching" if cache_model else "releasing GPU memory" + reason = "cold-cache normalization" + else: + offload_target = getattr(runner, '_dit_offload_device', None) + if offload_target is None or offload_target == 'none': + offload_target = torch.device('cpu') + reason = "releasing GPU memory" manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT", debug=debug, reason=reason, runner=runner) elif param_device.type == 'meta' and debug: @@ -1223,10 +1344,25 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool pass # 3. Clean BlockSwap after model movement - if hasattr(runner, "_blockswap_active") and runner._blockswap_active: + if runner.dit is not None and (cache_model or getattr(runner, "_blockswap_active", False)): # Import here to avoid circular dependency from .blockswap import cleanup_blockswap - cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model) + cleanup_blockswap(runner=runner, keep_state_for_cache=False) + + if cache_model and runner.dit is not None: + cached_dit_before_normalization = runner.dit + runner.dit = _normalize_cached_dit_model(runner.dit, debug=debug) + dit_cache_node_id = getattr(runner, '_dit_cache_node_id', None) + if dit_cache_node_id is not None: + from ..core.model_cache import get_global_cache + get_global_cache().replace_dit( + {'node_id': dit_cache_node_id}, + runner.dit, + debug=debug, + expected_model=cached_dit_before_normalization, + ) + if debug: + debug.log("DiT cache normalized to cold CPU model state", category="cache", force=True) # 4. Complete cleanup if not caching if not cache_model: @@ -1291,16 +1427,35 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup") else: - offload_target = getattr(runner, '_vae_offload_device', None) - if offload_target is None or offload_target == 'none': + if cache_model: offload_target = torch.device('cpu') - reason = "model caching" if cache_model else "releasing GPU memory" + reason = "cold-cache normalization" + else: + offload_target = getattr(runner, '_vae_offload_device', None) + if offload_target is None or offload_target == 'none': + offload_target = torch.device('cpu') + reason = "releasing GPU memory" manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE", debug=debug, reason=reason, runner=runner) elif param_device.type == 'meta' and debug: debug.log("VAE on meta device - keeping structure for cache", category="cleanup") except StopIteration: pass + + if cache_model and runner.vae is not None: + cached_vae_before_normalization = runner.vae + runner.vae = _normalize_cached_vae_model(runner.vae, debug=debug) + vae_cache_node_id = getattr(runner, '_vae_cache_node_id', None) + if vae_cache_node_id is not None: + from ..core.model_cache import get_global_cache + get_global_cache().replace_vae( + {'node_id': vae_cache_node_id}, + runner.vae, + debug=debug, + expected_model=cached_vae_before_normalization, + ) + if debug: + debug.log("VAE cache normalized to cold CPU model state", category="cache", force=True) # 3. Complete cleanup if not caching if not cache_model: From 7d28241f0bf25b27940b4aaa41fe8a53ca4928f3 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 16 Mar 2026 04:30:14 +0100 Subject: [PATCH 09/29] Allow reuse of hot model cache --- inference_cli.py | 2 + src/core/model_cache.py | 16 -------- src/interfaces/video_upscaler.py | 2 + src/optimization/memory_manager.py | 66 ++++++++++-------------------- 4 files changed, 26 insertions(+), 60 deletions(-) diff --git a/inference_cli.py b/inference_cli.py index 0d9a804a..78d523c5 100644 --- a/inference_cli.py +++ b/inference_cli.py @@ -984,6 +984,8 @@ def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: runner._seedvr2_execution_active = True runner._seedvr2_runner_tainted = False + runner._seedvr2_dit_phase_cleaned = False + runner._seedvr2_vae_phase_cleaned = False if ( cache_context is not None diff --git a/src/core/model_cache.py b/src/core/model_cache.py index 21563ec1..a4e3d28b 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -64,14 +64,6 @@ def get_dit(self, dit_config: Dict[str, Any], debug: Optional['Debug'] = None) - with self._model_cache_lock: if node_id in self._dit_models: model, stored_config = self._dit_models[node_id] - if not is_model_cache_cold(model): - if debug: - debug.log( - f"Cached DiT is still hot or incomplete; waiting for cold cache state before reuse (node {node_id})", - category="cache", - force=True, - ) - return None if is_model_cache_claimed(model): if debug: debug.log( @@ -112,14 +104,6 @@ def get_vae(self, vae_config: Dict[str, Any], debug: Optional['Debug'] = None) - with self._model_cache_lock: if node_id in self._vae_models: model, stored_config = self._vae_models[node_id] - if not is_model_cache_cold(model): - if debug: - debug.log( - f"Cached VAE is still hot or incomplete; waiting for cold cache state before reuse (node {node_id})", - category="cache", - force=True, - ) - return None if is_model_cache_claimed(model): if debug: debug.log( diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 6c93a256..b0eae88d 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -452,6 +452,8 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: runner._seedvr2_execution_active = True runner._seedvr2_runner_tainted = False + runner._seedvr2_dit_phase_cleaned = False + runner._seedvr2_vae_phase_cleaned = False # If both models were already cached but the runner template had been # invalidated or missing, cache this freshly configured runner now. diff --git a/src/optimization/memory_manager.py b/src/optimization/memory_manager.py index c0efd00f..4354b307 100644 --- a/src/optimization/memory_manager.py +++ b/src/optimization/memory_manager.py @@ -1328,14 +1328,10 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup") else: - if cache_model: + offload_target = getattr(runner, '_dit_offload_device', None) + if offload_target is None or offload_target == 'none': offload_target = torch.device('cpu') - reason = "cold-cache normalization" - else: - offload_target = getattr(runner, '_dit_offload_device', None) - if offload_target is None or offload_target == 'none': - offload_target = torch.device('cpu') - reason = "releasing GPU memory" + reason = "model caching" if cache_model else "releasing GPU memory" manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT", debug=debug, reason=reason, runner=runner) elif param_device.type == 'meta' and debug: @@ -1344,25 +1340,14 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool pass # 3. Clean BlockSwap after model movement - if runner.dit is not None and (cache_model or getattr(runner, "_blockswap_active", False)): + if hasattr(runner, "_blockswap_active") and runner._blockswap_active: # Import here to avoid circular dependency from .blockswap import cleanup_blockswap - cleanup_blockswap(runner=runner, keep_state_for_cache=False) + cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model) if cache_model and runner.dit is not None: - cached_dit_before_normalization = runner.dit - runner.dit = _normalize_cached_dit_model(runner.dit, debug=debug) - dit_cache_node_id = getattr(runner, '_dit_cache_node_id', None) - if dit_cache_node_id is not None: - from ..core.model_cache import get_global_cache - get_global_cache().replace_dit( - {'node_id': dit_cache_node_id}, - runner.dit, - debug=debug, - expected_model=cached_dit_before_normalization, - ) - if debug: - debug.log("DiT cache normalized to cold CPU model state", category="cache", force=True) + set_model_cache_cold_state(runner.dit, False) + runner._seedvr2_dit_phase_cleaned = True # 4. Complete cleanup if not caching if not cache_model: @@ -1379,6 +1364,9 @@ def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if hasattr(runner, '_dit_attention_mode'): delattr(runner, '_dit_attention_mode') + if not cache_model: + runner._seedvr2_dit_phase_cleaned = True + # 5. Clear DiT temporary attributes (should be already cleared in materialize_model) runner._dit_checkpoint = None runner._dit_dtype_override = None @@ -1427,14 +1415,10 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if debug: debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup") else: - if cache_model: + offload_target = getattr(runner, '_vae_offload_device', None) + if offload_target is None or offload_target == 'none': offload_target = torch.device('cpu') - reason = "cold-cache normalization" - else: - offload_target = getattr(runner, '_vae_offload_device', None) - if offload_target is None or offload_target == 'none': - offload_target = torch.device('cpu') - reason = "releasing GPU memory" + reason = "model caching" if cache_model else "releasing GPU memory" manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE", debug=debug, reason=reason, runner=runner) elif param_device.type == 'meta' and debug: @@ -1443,19 +1427,8 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool pass if cache_model and runner.vae is not None: - cached_vae_before_normalization = runner.vae - runner.vae = _normalize_cached_vae_model(runner.vae, debug=debug) - vae_cache_node_id = getattr(runner, '_vae_cache_node_id', None) - if vae_cache_node_id is not None: - from ..core.model_cache import get_global_cache - get_global_cache().replace_vae( - {'node_id': vae_cache_node_id}, - runner.vae, - debug=debug, - expected_model=cached_vae_before_normalization, - ) - if debug: - debug.log("VAE cache normalized to cold CPU model state", category="cache", force=True) + set_model_cache_cold_state(runner.vae, False) + runner._seedvr2_vae_phase_cleaned = True # 3. Complete cleanup if not caching if not cache_model: @@ -1470,6 +1443,9 @@ def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool if hasattr(runner, '_vae_tiling_config'): delattr(runner, '_vae_tiling_config') + if not cache_model: + runner._seedvr2_vae_phase_cleaned = True + # 3. Clear VAE temporary attributes (should be already cleared in materialize_model) runner._vae_checkpoint = None runner._vae_dtype_override = None @@ -1507,10 +1483,12 @@ def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bo # 1. Cleanup any remaining models if they still exist # (This handles cases where phases were skipped or errored) if hasattr(runner, 'dit') and runner.dit is not None: - cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache) + if not (dit_cache and getattr(runner, '_seedvr2_dit_phase_cleaned', False)): + cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache) if hasattr(runner, 'vae') and runner.vae is not None: - cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache) + if not (vae_cache and getattr(runner, '_seedvr2_vae_phase_cleaned', False)): + cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache) # 2. Clear remaining runtime caches clear_runtime_caches(runner=runner, debug=debug) From 549c065309571318232d3b668f43d7093b6adb1e Mon Sep 17 00:00:00 2001 From: xmarre Date: Thu, 19 Mar 2026 16:38:27 +0100 Subject: [PATCH 10/29] Fix cached model claim ownership during teardown --- inference_cli.py | 42 +++++++++++----- src/core/model_cache.py | 9 ++-- src/core/model_configuration.py | 84 ++++++++++++++++++++++++++++---- src/interfaces/video_upscaler.py | 42 +++++++++++----- 4 files changed, 141 insertions(+), 36 deletions(-) diff --git a/inference_cli.py b/inference_cli.py index 78d523c5..5523a8e4 100644 --- a/inference_cli.py +++ b/inference_cli.py @@ -130,7 +130,10 @@ decode_all_batches, postprocess_all_batches ) -from src.core.model_configuration import _evict_claimed_cached_models +from src.core.model_configuration import ( + _evict_claimed_cached_models, + _finalize_claimed_cached_models_for_reuse, +) from src.utils.debug import Debug from src.optimization.memory_manager import ( cleanup_text_embeddings, @@ -928,18 +931,35 @@ def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: nonlocal runner, ctx if runner is not None: + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None try: - complete_cleanup( - runner=runner, - debug=debug, - dit_cache=dit_cache_flag, - vae_cache=vae_cache_flag, - ) - if dit_cache_flag and getattr(runner, 'dit', None) is not None: - set_model_cache_claimed_state(runner.dit, False) - if vae_cache_flag and getattr(runner, 'vae', None) is not None: - set_model_cache_claimed_state(runner.vae, False) + try: + complete_cleanup( + runner=runner, + debug=debug, + dit_cache=dit_cache_flag, + vae_cache=vae_cache_flag, + ) + if dit_cache_flag or vae_cache_flag: + _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) + except Exception: + try: + _evict_claimed_cached_models(cache_context, runner, debug) + except Exception as evict_error: + if debug is not None: + debug.log( + f"Failed to evict claimed cached models while handling prior cleanup/finalize exception: {evict_error}", + level="WARNING", + category="cleanup", + force=True, + ) + raise finally: + if dit_cache_flag and claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if vae_cache_flag and claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) runner._seedvr2_execution_active = False if not (dit_cache_flag or vae_cache_flag): diff --git a/src/core/model_cache.py b/src/core/model_cache.py index a4e3d28b..c7f07965 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -7,7 +7,6 @@ from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING from ..optimization.memory_manager import ( is_model_cache_claimed, - is_model_cache_cold, iter_model_wrapper_chain, release_model_memory, set_model_cache_claimed_state, @@ -269,7 +268,7 @@ def replace_dit( debug: Optional['Debug'] = None, expected_model: Optional[Any] = None, ) -> bool: - """Rewrite a cached DiT entry to a normalized canonical model.""" + """Rewrite a cached DiT entry to the latest claimed model object.""" node_id = dit_config.get('node_id') with self._model_cache_lock: if node_id not in self._dit_models: @@ -288,7 +287,7 @@ def replace_dit( self._dit_models[node_id] = (model, stored_config) if debug: - debug.log(f"Rewrote cached DiT entry to cold canonical model (node {node_id})", category="cache", force=True) + debug.log(f"Refreshed cached DiT entry to the latest claimed model object (node {node_id})", category="cache", force=True) return True def replace_vae( @@ -298,7 +297,7 @@ def replace_vae( debug: Optional['Debug'] = None, expected_model: Optional[Any] = None, ) -> bool: - """Rewrite a cached VAE entry to a normalized canonical model.""" + """Rewrite a cached VAE entry to the latest claimed model object.""" node_id = vae_config.get('node_id') with self._model_cache_lock: if node_id not in self._vae_models: @@ -317,7 +316,7 @@ def replace_vae( self._vae_models[node_id] = (model, stored_config) if debug: - debug.log(f"Rewrote cached VAE entry to cold canonical model (node {node_id})", category="cache", force=True) + debug.log(f"Refreshed cached VAE entry to the latest claimed model object (node {node_id})", category="cache", force=True) return True def set_runner(self, dit_id: Optional[int], vae_id: Optional[int], diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 2e683d02..413ddc4f 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -748,9 +748,31 @@ def _acquire_runner( return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) if template_status == "claimed": - debug.log(f"Reusing cached runner template: nodes {runner_key}", category="reuse", force=True) - cache_context['reusing_runner'] = True - return template + need_dit = bool(cache_context.get('dit_cache') and cache_context.get('dit_id') is not None) + need_vae = bool(cache_context.get('vae_cache') and cache_context.get('vae_id') is not None) + have_dit = (not need_dit) or (cache_context.get('cached_dit') is not None) + have_vae = (not need_vae) or (cache_context.get('cached_vae') is not None) + + if have_dit and have_vae: + debug.log(f"Reusing cached runner template: nodes {runner_key}", category="reuse", force=True) + cache_context['reusing_runner'] = True + return template + + debug.log( + "Runner template matched, but required claimed cached models were not acquired; creating a fresh runner", + level="WARNING", + category="cache", + force=True, + ) + template._seedvr2_runner_tainted = True + template._seedvr2_execution_active = False + cache_context['global_cache'].remove_runner( + cache_context['dit_id'], + cache_context['vae_id'], + debug, + expected_runner=template, + ) + return _create_new_runner(dit_model, vae_model, base_cache_dir, debug) current_dit = getattr(template, '_dit_model_name', None) current_vae = getattr(template, '_vae_model_name', None) @@ -968,15 +990,59 @@ def _evict_claimed_cached_models( if global_cache is None: return + claimed_dit = cache_context.get('cached_dit') + claimed_vae = cache_context.get('cached_vae') + dit_id = cache_context.get('dit_id') - if cache_context.get('dit_cache') and dit_id is not None: - expected_dit = (getattr(runner, 'dit', None) if runner is not None else None) or cache_context.get('cached_dit') - global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=expected_dit) + if cache_context.get('dit_cache') and dit_id is not None and claimed_dit is not None: + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_dit) vae_id = cache_context.get('vae_id') - if cache_context.get('vae_cache') and vae_id is not None: - expected_vae = (getattr(runner, 'vae', None) if runner is not None else None) or cache_context.get('cached_vae') - global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=expected_vae) + if cache_context.get('vae_cache') and vae_id is not None and claimed_vae is not None: + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) + + +def _finalize_claimed_cached_models_for_reuse( + cache_context: Dict[str, Any], + runner: Optional[VideoDiffusionInfer], + debug: Optional['Debug'] = None, +) -> None: + """Refresh claimed cache entries to the post-cleanup released model objects.""" + if not cache_context or runner is None: + return + + global_cache = cache_context.get('global_cache') + if global_cache is None: + return + + claimed_dit = cache_context.get('cached_dit') + claimed_vae = cache_context.get('cached_vae') + + dit_id = cache_context.get('dit_id') + if cache_context.get('dit_cache') and dit_id is not None and claimed_dit is not None: + released_dit = getattr(runner, 'dit', None) + if released_dit is not None: + if global_cache.replace_dit({'node_id': dit_id}, released_dit, debug, expected_model=claimed_dit): + runner.dit = released_dit + else: + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_dit) + runner.dit = None + else: + global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_dit) + runner.dit = None + + vae_id = cache_context.get('vae_id') + if cache_context.get('vae_cache') and vae_id is not None and claimed_vae is not None: + released_vae = getattr(runner, 'vae', None) + if released_vae is not None: + if global_cache.replace_vae({'node_id': vae_id}, released_vae, debug, expected_model=claimed_vae): + runner.vae = released_vae + else: + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) + runner.vae = None + else: + global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) + runner.vae = None def _configure_runner_settings( diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index b0eae88d..168fae8d 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -23,7 +23,10 @@ load_text_embeddings, script_directory ) -from ..core.model_configuration import _evict_claimed_cached_models +from ..core.model_configuration import ( + _evict_claimed_cached_models, + _finalize_claimed_cached_models_for_reuse, +) from ..optimization.memory_manager import ( cleanup_text_embeddings, complete_cleanup, @@ -322,18 +325,35 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # Use complete_cleanup for all cleanup operations if runner: + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None try: - complete_cleanup( - runner=runner, - debug=debug, - dit_cache=dit_cache, - vae_cache=vae_cache, - ) - if dit_cache and getattr(runner, 'dit', None) is not None: - set_model_cache_claimed_state(runner.dit, False) - if vae_cache and getattr(runner, 'vae', None) is not None: - set_model_cache_claimed_state(runner.vae, False) + try: + complete_cleanup( + runner=runner, + debug=debug, + dit_cache=dit_cache, + vae_cache=vae_cache, + ) + if dit_cache or vae_cache: + _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) + except Exception: + try: + _evict_claimed_cached_models(cache_context, runner, debug) + except Exception as evict_error: + if debug is not None: + debug.log( + f"Failed to evict claimed cached models while handling prior cleanup/finalize exception: {evict_error}", + level="WARNING", + category="cleanup", + force=True, + ) + raise finally: + if dit_cache and claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if vae_cache and claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) runner._seedvr2_execution_active = False # Delete runner only if neither model is cached From 92d7a6e0149d69141bb106bc1b9db5999984f8fd Mon Sep 17 00:00:00 2001 From: xmarre Date: Thu, 19 Mar 2026 18:57:59 +0100 Subject: [PATCH 11/29] Fix atomic runner eviction and claim cleanup --- inference_cli.py | 34 +++++++++++++--------- src/core/model_cache.py | 33 +++++++++++++++++++++ src/core/model_configuration.py | 10 ++----- src/interfaces/video_upscaler.py | 50 ++++++++++++++++++-------------- 4 files changed, 84 insertions(+), 43 deletions(-) diff --git a/inference_cli.py b/inference_cli.py index 5523a8e4..52ac41fb 100644 --- a/inference_cli.py +++ b/inference_cli.py @@ -1092,25 +1092,31 @@ def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: cleanup(dit_cache_flag=cache_dit, vae_cache_flag=cache_vae) return result_tensor except BaseException: - if runner is not None: - runner._seedvr2_runner_tainted = True - + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None if cache_context is not None: _evict_claimed_cached_models(cache_context, runner, debug) + if runner is not None and cache_context.get('reusing_runner', False): + try: + cache_context['global_cache'].taint_and_remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + expected_runner=runner, + ) + except BaseException: + pass + + try: try: - cache_context['global_cache'].remove_runner( - cache_context.get('dit_id'), - cache_context.get('vae_id'), - debug, - expected_runner=runner, - ) + cleanup(dit_cache_flag=False, vae_cache_flag=False) except BaseException: pass - - try: - cleanup(dit_cache_flag=False, vae_cache_flag=False) - except BaseException: - pass + finally: + if claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) raise diff --git a/src/core/model_cache.py b/src/core/model_cache.py index c7f07965..eb11fdfb 100644 --- a/src/core/model_cache.py +++ b/src/core/model_cache.py @@ -388,6 +388,39 @@ def remove_runner(self, dit_id: Optional[int], vae_id: Optional[int], if debug: debug.log(f"Removed cached runner template: nodes {runner_key}", category="cache", force=True) return True + + def taint_and_remove_runner(self, + dit_id: Optional[int], + vae_id: Optional[int], + debug: Optional['Debug'] = None, + expected_runner: Optional[Any] = None) -> bool: + """Atomically mark a cached runner template tainted/inactive and remove it.""" + if dit_id is None or vae_id is None: + return False + + runner_key = f"{dit_id}+{vae_id}" + with self._runner_templates_lock: + cached_runner = self._runner_templates.get(runner_key) + if cached_runner is None: + return False + + if expected_runner is not None and cached_runner is not expected_runner: + if debug: + debug.log( + f"Skipped taint+remove for cached runner nodes {runner_key}: cache entry no longer matches expected runner", + level="WARNING", + category="cache", + force=True, + ) + return False + + cached_runner._seedvr2_runner_tainted = True + cached_runner._seedvr2_execution_active = False + del self._runner_templates[runner_key] + + if debug: + debug.log(f"Tainted and removed cached runner template: nodes {runner_key}", category="cache", force=True) + return True def remove_dit( self, diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 413ddc4f..26016774 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -764,9 +764,7 @@ def _acquire_runner( category="cache", force=True, ) - template._seedvr2_runner_tainted = True - template._seedvr2_execution_active = False - cache_context['global_cache'].remove_runner( + cache_context['global_cache'].taint_and_remove_runner( cache_context['dit_id'], cache_context['vae_id'], debug, @@ -949,10 +947,8 @@ def configure_runner( except BaseException: _evict_claimed_cached_models(cache_context, runner, debug) if runner is not None and cache_context.get('reusing_runner', False): - runner._seedvr2_runner_tainted = True - runner._seedvr2_execution_active = False try: - cache_context['global_cache'].remove_runner( + cache_context['global_cache'].taint_and_remove_runner( cache_context.get('dit_id'), cache_context.get('vae_id'), debug, @@ -1007,7 +1003,7 @@ def _finalize_claimed_cached_models_for_reuse( runner: Optional[VideoDiffusionInfer], debug: Optional['Debug'] = None, ) -> None: - """Refresh claimed cache entries to the post-cleanup released model objects.""" + """Refresh or evict claimed cache entries using the released runner-held model refs.""" if not cache_context or runner is None: return diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 168fae8d..a0442fc0 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -633,35 +633,41 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: return io.NodeOutput(sample) except BaseException: - if runner is not None: - runner._seedvr2_runner_tainted = True - + claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None + claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None if cache_context is not None: _evict_claimed_cached_models(cache_context, runner, debug) + if runner is not None and cache_context.get('reusing_runner', False): + try: + cache_context['global_cache'].taint_and_remove_runner( + cache_context.get('dit_id'), + cache_context.get('vae_id'), + debug, + expected_runner=runner, + ) + except Exception as cache_error: + if debug is not None: + debug.log( + f"Failed to evict cached runner while handling prior exception: {cache_error}", + level="WARNING", + category="cleanup", + force=True, + ) + + try: try: - cache_context['global_cache'].remove_runner( - cache_context.get('dit_id'), - cache_context.get('vae_id'), - debug, - expected_runner=runner, - ) - except Exception as cache_error: + cleanup(dit_cache=False, vae_cache=False) + except BaseException as cleanup_error: if debug is not None: debug.log( - f"Failed to evict cached runner while handling prior exception: {cache_error}", + f"Cleanup failed while handling prior exception: {cleanup_error}", level="WARNING", category="cleanup", force=True, ) - - try: - cleanup(dit_cache=False, vae_cache=False) - except BaseException as cleanup_error: - if debug is not None: - debug.log( - f"Cleanup failed while handling prior exception: {cleanup_error}", - level="WARNING", - category="cleanup", - force=True, - ) + finally: + if claimed_dit is not None: + set_model_cache_claimed_state(claimed_dit, False) + if claimed_vae is not None: + set_model_cache_claimed_state(claimed_vae, False) raise From 8a07d157d4460b865dd16e8322d1e256730fe7d0 Mon Sep 17 00:00:00 2001 From: xmarre Date: Fri, 20 Mar 2026 11:45:51 +0100 Subject: [PATCH 12/29] Fix cached claim release after finalize refresh --- inference_cli.py | 30 +++++++++++++++++++++++++----- src/core/model_configuration.py | 13 ++++++++++--- src/interfaces/video_upscaler.py | 8 +++++++- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/inference_cli.py b/inference_cli.py index 52ac41fb..5e1fec4a 100644 --- a/inference_cli.py +++ b/inference_cli.py @@ -933,6 +933,8 @@ def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: if runner is not None: claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None + refreshed_dit = None + refreshed_vae = None try: try: complete_cleanup( @@ -942,7 +944,7 @@ def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: vae_cache=vae_cache_flag, ) if dit_cache_flag or vae_cache_flag: - _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) + refreshed_dit, refreshed_vae = _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) except Exception: try: _evict_claimed_cached_models(cache_context, runner, debug) @@ -958,8 +960,12 @@ def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: finally: if dit_cache_flag and claimed_dit is not None: set_model_cache_claimed_state(claimed_dit, False) + if dit_cache_flag and refreshed_dit is not None and refreshed_dit is not claimed_dit: + set_model_cache_claimed_state(refreshed_dit, False) if vae_cache_flag and claimed_vae is not None: set_model_cache_claimed_state(claimed_vae, False) + if vae_cache_flag and refreshed_vae is not None and refreshed_vae is not claimed_vae: + set_model_cache_claimed_state(refreshed_vae, False) runner._seedvr2_execution_active = False if not (dit_cache_flag or vae_cache_flag): @@ -1104,14 +1110,28 @@ def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None: debug, expected_runner=runner, ) - except BaseException: - pass + except BaseException as cache_error: + if debug is not None: + debug.log( + f"Failed to evict cached runner while handling prior exception " + f"(runner={id(runner)}, dit_id={cache_context.get('dit_id')}, vae_id={cache_context.get('vae_id')}): {cache_error}", + level="WARNING", + category="cleanup", + force=True, + ) try: try: cleanup(dit_cache_flag=False, vae_cache_flag=False) - except BaseException: - pass + except BaseException as cleanup_error: + if debug is not None: + debug.log( + f"Cleanup failed while handling prior exception " + f"(runner={id(runner) if runner is not None else 'none'}, dit_id={cache_context.get('dit_id') if cache_context is not None else 'none'}, vae_id={cache_context.get('vae_id') if cache_context is not None else 'none'}): {cleanup_error}", + level="WARNING", + category="cleanup", + force=True, + ) finally: if claimed_dit is not None: set_model_cache_claimed_state(claimed_dit, False) diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 26016774..c7a21039 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -1002,14 +1002,17 @@ def _finalize_claimed_cached_models_for_reuse( cache_context: Dict[str, Any], runner: Optional[VideoDiffusionInfer], debug: Optional['Debug'] = None, -) -> None: +) -> Tuple[Optional[Any], Optional[Any]]: """Refresh or evict claimed cache entries using the released runner-held model refs.""" + refreshed_dit = None + refreshed_vae = None + if not cache_context or runner is None: - return + return refreshed_dit, refreshed_vae global_cache = cache_context.get('global_cache') if global_cache is None: - return + return refreshed_dit, refreshed_vae claimed_dit = cache_context.get('cached_dit') claimed_vae = cache_context.get('cached_vae') @@ -1019,6 +1022,7 @@ def _finalize_claimed_cached_models_for_reuse( released_dit = getattr(runner, 'dit', None) if released_dit is not None: if global_cache.replace_dit({'node_id': dit_id}, released_dit, debug, expected_model=claimed_dit): + refreshed_dit = released_dit runner.dit = released_dit else: global_cache.remove_dit({'node_id': dit_id}, debug, expected_model=claimed_dit) @@ -1032,6 +1036,7 @@ def _finalize_claimed_cached_models_for_reuse( released_vae = getattr(runner, 'vae', None) if released_vae is not None: if global_cache.replace_vae({'node_id': vae_id}, released_vae, debug, expected_model=claimed_vae): + refreshed_vae = released_vae runner.vae = released_vae else: global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) @@ -1040,6 +1045,8 @@ def _finalize_claimed_cached_models_for_reuse( global_cache.remove_vae({'node_id': vae_id}, debug, expected_model=claimed_vae) runner.vae = None + return refreshed_dit, refreshed_vae + def _configure_runner_settings( runner: VideoDiffusionInfer, diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index a0442fc0..b5643521 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -327,6 +327,8 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: if runner: claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None + refreshed_dit = None + refreshed_vae = None try: try: complete_cleanup( @@ -336,7 +338,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: vae_cache=vae_cache, ) if dit_cache or vae_cache: - _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) + refreshed_dit, refreshed_vae = _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug) except Exception: try: _evict_claimed_cached_models(cache_context, runner, debug) @@ -352,8 +354,12 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: finally: if dit_cache and claimed_dit is not None: set_model_cache_claimed_state(claimed_dit, False) + if dit_cache and refreshed_dit is not None and refreshed_dit is not claimed_dit: + set_model_cache_claimed_state(refreshed_dit, False) if vae_cache and claimed_vae is not None: set_model_cache_claimed_state(claimed_vae, False) + if vae_cache and refreshed_vae is not None and refreshed_vae is not claimed_vae: + set_model_cache_claimed_state(refreshed_vae, False) runner._seedvr2_execution_active = False # Delete runner only if neither model is cached From 08fad5bb341eccbb4185e12ef5d597517e040dfa Mon Sep 17 00:00:00 2001 From: xmarre Date: Fri, 20 Mar 2026 13:22:25 +0100 Subject: [PATCH 13/29] Track newly cached models for claim release --- src/core/generation_phases.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index eb490443..8c4e5455 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -318,11 +318,13 @@ def encode_all_batches( # Cache VAE now that it's fully configured and ready for inference if ctx['cache_context']['vae_cache'] and not ctx['cache_context']['cached_vae']: runner.vae._model_name = ctx['cache_context']['vae_model'] - ctx['cache_context']['global_cache'].set_vae( + cached_vae_id = ctx['cache_context']['global_cache'].set_vae( {'node_id': ctx['cache_context']['vae_id'], 'cache_model': True}, runner.vae, ctx['cache_context']['vae_model'], debug ) - ctx['cache_context']['vae_newly_cached'] = True + if cached_vae_id is not None: + ctx['cache_context']['vae_newly_cached'] = True + ctx['cache_context']['cached_vae'] = runner.vae # If both models now cached, cache runner template dit_is_cached = ctx['cache_context']['cached_dit'] or ctx['cache_context']['dit_newly_cached'] @@ -648,11 +650,13 @@ def upscale_all_batches( # Cache DiT now that it's fully configured and ready for inference if ctx['cache_context']['dit_cache'] and not ctx['cache_context']['cached_dit']: runner.dit._model_name = ctx['cache_context']['dit_model'] - ctx['cache_context']['global_cache'].set_dit( + cached_dit_id = ctx['cache_context']['global_cache'].set_dit( {'node_id': ctx['cache_context']['dit_id'], 'cache_model': True}, runner.dit, ctx['cache_context']['dit_model'], debug ) - ctx['cache_context']['dit_newly_cached'] = True + if cached_dit_id is not None: + ctx['cache_context']['dit_newly_cached'] = True + ctx['cache_context']['cached_dit'] = runner.dit # If both models now cached, cache runner template vae_is_cached = ctx['cache_context']['cached_vae'] or ctx['cache_context']['vae_newly_cached'] From c6466fa1a67714adaacd7809f4de97237cae256a Mon Sep 17 00:00:00 2001 From: xmarre Date: Fri, 20 Mar 2026 19:56:46 +0100 Subject: [PATCH 14/29] Add breadcrumbs around SeedVR2 model prep --- src/core/model_configuration.py | 2 ++ src/interfaces/video_upscaler.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index c7a21039..02ce2abc 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -940,10 +940,12 @@ def configure_runner( ) # Phase 4: Setup models (load from cache or create new) + debug.log("SeedVR2 breadcrumb: before _setup_models", category="runner", force=True) _setup_models( runner, cache_context, dit_model, vae_model, base_cache_dir, block_swap_config, debug ) + debug.log("SeedVR2 breadcrumb: after _setup_models", category="runner", force=True) except BaseException: _evict_claimed_cached_models(cache_context, runner, debug) if runner is not None and cache_context.get('reusing_runner', False): diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index b5643521..75bf9bb1 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -453,6 +453,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: ) # Prepare runner with model state management and global cache + debug.log("SeedVR2 breadcrumb: before prepare_runner", category="runner", force=True) runner, cache_context = prepare_runner( dit_model=dit_model, vae_model=vae_model, @@ -475,6 +476,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: torch_compile_args_dit=dit_torch_compile_args, torch_compile_args_vae=vae_torch_compile_args ) + debug.log("SeedVR2 breadcrumb: after prepare_runner", category="runner", force=True) runner._seedvr2_execution_active = True runner._seedvr2_runner_tainted = False @@ -489,23 +491,32 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: and cache_context.get('cached_dit') is not None and cache_context.get('cached_vae') is not None ): + debug.log("SeedVR2 breadcrumb: before set_runner", category="runner", force=True) cache_context['global_cache'].set_runner( cache_context.get('dit_id'), cache_context.get('vae_id'), runner, debug, ) + debug.log("SeedVR2 breadcrumb: after set_runner", category="runner", force=True) # Store cache context in ctx for use in generation phases ctx['cache_context'] = cache_context + debug.log("SeedVR2 breadcrumb: before load_text_embeddings", category="dit", force=True) # Preload text embeddings before Phase 1 to avoid sync stall in Phase 2 ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug) - debug.log("Loaded text embeddings for DiT", category="dit") + debug.log("SeedVR2 breadcrumb: after load_text_embeddings", category="dit", force=True) + debug.log("SeedVR2 breadcrumb: before log_memory_state", category="memory", force=True) debug.log_memory_state("After model preparation", show_tensors=False, detailed_tensors=False) + debug.log("SeedVR2 breadcrumb: after log_memory_state", category="memory", force=True) + + debug.log("SeedVR2 breadcrumb: before end_timer(model_preparation)", category="runner", force=True) debug.end_timer("model_preparation", "Model preparation", force=True, show_breakdown=True) + debug.log("SeedVR2 breadcrumb: after end_timer(model_preparation)", category="runner", force=True) + debug.log("SeedVR2 breadcrumb: before compute_generation_info", category="generation", force=True) # Compute generation info and log start (handles prepending internally) image, gen_info = compute_generation_info( ctx=ctx, @@ -519,6 +530,7 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: temporal_overlap=temporal_overlap, debug=debug ) + debug.log("SeedVR2 breadcrumb: after compute_generation_info", category="generation", force=True) # Log generation start in consistent format log_generation_start(gen_info, debug) From 9b572eda232eb435a7ff824be6f0e76b13f89af4 Mon Sep 17 00:00:00 2001 From: xmarre Date: Sat, 21 Mar 2026 00:09:23 +0100 Subject: [PATCH 15/29] Add breadcrumbs around video transform setup --- src/core/generation_utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index 9a4cb3c5..f451d934 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -105,11 +105,14 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: existing_transform = ctx.get('video_transform') if existing_transform is not None: + debug.log("SeedVR2 breadcrumb: setup_video_transform using existing transform", category="setup", force=True) if debug else None # Transform exists - check if we need to compute dimensions if 'true_target_dims' in ctx and sample_frame is not None: # Return cached dimensions + recompute padded from sample + debug.log("SeedVR2 breadcrumb: before existing_transform(sample_frame)", category="setup", force=True) if debug else None true_h, true_w = ctx['true_target_dims'] transformed = existing_transform(sample_frame) + debug.log("SeedVR2 breadcrumb: after existing_transform(sample_frame)", category="setup", force=True) if debug else None padded_h, padded_w = transformed.shape[-2:] if debug: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") @@ -119,6 +122,7 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: return 0, 0, 0, 0 # Create transformation pipeline (first time or after cleanup) + debug.log("SeedVR2 breadcrumb: setup_video_transform creating new transform", category="setup", force=True) if debug else None ctx['video_transform'] = prepare_video_transforms(resolution, max_resolution, debug) # Compute dimensions if sample frame provided @@ -128,7 +132,9 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: NaResize(resolution=resolution, mode="side", downsample_only=False, max_resolution=max_resolution), Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) ]) + debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None resized_sample = temp_transform(sample_frame) + debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None true_h, true_w = resized_sample.shape[-2:] # Round to even numbers for video codec compatibility (libx264 requirement) @@ -139,7 +145,9 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: ctx['true_target_dims'] = (true_h, true_w) # Get padded dimensions + debug.log("SeedVR2 breadcrumb: before ctx['video_transform'](sample_frame)", category="setup", force=True) if debug else None transformed_sample = ctx['video_transform'](sample_frame) + debug.log("SeedVR2 breadcrumb: after ctx['video_transform'](sample_frame)", category="setup", force=True) if debug else None padded_h, padded_w = transformed_sample.shape[-2:] if debug: @@ -194,18 +202,23 @@ def compute_generation_info( channels_info = "RGBA" if images.shape[-1] == 4 else "RGB" # Apply prepending if requested + debug.log("SeedVR2 breadcrumb: compute_generation_info before temporal prepend", category="generation", force=True) if debug else None if prepend_frames > 0: images = pad_video_temporal(images, count=prepend_frames, temporal_dim=0, prepend=True, debug=debug) + debug.log("SeedVR2 breadcrumb: compute_generation_info after temporal prepend", category="generation", force=True) if debug else None # Track total frames after prepending total_frames = len(images) ctx['total_frames'] = total_frames # Setup transform and compute dimensions on final frame count + debug.log("SeedVR2 breadcrumb: compute_generation_info before sample_frame", category="generation", force=True) if debug else None sample_frame = images[0].permute(2, 0, 1).unsqueeze(0) + debug.log("SeedVR2 breadcrumb: compute_generation_info before setup_video_transform", category="generation", force=True) if debug else None true_h, true_w, padded_h, padded_w = setup_video_transform( ctx, resolution, max_resolution, debug, sample_frame ) + debug.log("SeedVR2 breadcrumb: compute_generation_info after setup_video_transform", category="generation", force=True) if debug else None del sample_frame info = { @@ -824,4 +837,4 @@ def ensure_precision_initialized( debug.log(f"Model precision: {', '.join(parts)}", category="precision") except Exception as e: - debug.log(f"Could not log model dtypes: {e}", level="WARNING", category="precision", force=True) \ No newline at end of file + debug.log(f"Could not log model dtypes: {e}", level="WARNING", category="precision", force=True) From 5cd8f96aa076f652fb15381522005542ae0a2a5d Mon Sep 17 00:00:00 2001 From: xmarre Date: Sat, 21 Mar 2026 17:12:38 +0100 Subject: [PATCH 16/29] Fix video transform dim planning without live tensor pass --- src/core/generation_phases.py | 2 ++ src/core/generation_utils.py | 61 +++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index 8c4e5455..f70ec979 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -1487,6 +1487,8 @@ def postprocess_all_batches( del ctx['all_ori_lengths'] if 'true_target_dims' in ctx: del ctx['true_target_dims'] + if 'padded_target_dims' in ctx: + del ctx['padded_target_dims'] if 'batch_metadata' in ctx: del ctx['batch_metadata'] if 'input_images' in ctx: diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index f451d934..c628faac 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -106,17 +106,45 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: if existing_transform is not None: debug.log("SeedVR2 breadcrumb: setup_video_transform using existing transform", category="setup", force=True) if debug else None - # Transform exists - check if we need to compute dimensions - if 'true_target_dims' in ctx and sample_frame is not None: - # Return cached dimensions + recompute padded from sample - debug.log("SeedVR2 breadcrumb: before existing_transform(sample_frame)", category="setup", force=True) if debug else None + # Transform exists - return cached dimensions without re-running the pipeline + if 'true_target_dims' in ctx and 'padded_target_dims' in ctx: true_h, true_w = ctx['true_target_dims'] - transformed = existing_transform(sample_frame) - debug.log("SeedVR2 breadcrumb: after existing_transform(sample_frame)", category="setup", force=True) if debug else None - padded_h, padded_w = transformed.shape[-2:] + padded_h, padded_w = ctx['padded_target_dims'] if debug: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") return true_h, true_w, padded_h, padded_w + if sample_frame is not None: + temp_transform = Compose([ + NaResize(resolution=resolution, mode="side", downsample_only=False, max_resolution=max_resolution), + Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) + ]) + debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None + resized_sample = temp_transform(sample_frame) + debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None + resized_h, resized_w = resized_sample.shape[-2:] + + # Round to even numbers for video codec compatibility (libx264 requirement) + true_h = (resized_h // 2) * 2 + true_w = (resized_w // 2) * 2 + + # Cache for later use in trimming + ctx['true_target_dims'] = (true_h, true_w) + + # Compute padded dimensions from the resized shape before even-rounding + padded_h = ((resized_h + 15) // 16) * 16 + padded_w = ((resized_w + 15) // 16) * 16 + ctx['padded_target_dims'] = (padded_h, padded_w) + + if debug: + if true_h == padded_h and true_w == padded_w: + debug.log(f"Target dimensions: {true_w}x{true_h} (no padding needed)", + category="setup", indent_level=1) + else: + debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", + category="setup", indent_level=1) + + del temp_transform, resized_sample + return true_h, true_w, padded_h, padded_w elif debug: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") return 0, 0, 0, 0 @@ -135,20 +163,19 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None resized_sample = temp_transform(sample_frame) debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None - true_h, true_w = resized_sample.shape[-2:] + resized_h, resized_w = resized_sample.shape[-2:] # Round to even numbers for video codec compatibility (libx264 requirement) - true_h = (true_h // 2) * 2 - true_w = (true_w // 2) * 2 + true_h = (resized_h // 2) * 2 + true_w = (resized_w // 2) * 2 # Cache for later use in trimming ctx['true_target_dims'] = (true_h, true_w) - - # Get padded dimensions - debug.log("SeedVR2 breadcrumb: before ctx['video_transform'](sample_frame)", category="setup", force=True) if debug else None - transformed_sample = ctx['video_transform'](sample_frame) - debug.log("SeedVR2 breadcrumb: after ctx['video_transform'](sample_frame)", category="setup", force=True) if debug else None - padded_h, padded_w = transformed_sample.shape[-2:] + + # Compute padded dimensions from the resized shape before even-rounding + padded_h = ((resized_h + 15) // 16) * 16 + padded_w = ((resized_w + 15) // 16) * 16 + ctx['padded_target_dims'] = (padded_h, padded_w) if debug: if true_h == padded_h and true_w == padded_w: @@ -158,7 +185,7 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", category="setup", indent_level=1) - del temp_transform, resized_sample, transformed_sample + del temp_transform, resized_sample return true_h, true_w, padded_h, padded_w return 0, 0, 0, 0 From a915809b20034f684751019598d67c329c716c03 Mon Sep 17 00:00:00 2001 From: xmarre Date: Fri, 3 Apr 2026 04:39:52 +0200 Subject: [PATCH 17/29] Refactor target-dimension probe and add debug breadcrumbs --- src/core/generation_utils.py | 76 ++++++++++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 13 deletions(-) diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index c628faac..1f8583a2 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -44,6 +44,54 @@ script_directory = get_script_directory() +def _log_tensor_debug_state(debug: Optional['Debug'], label: str, tensor: Optional[torch.Tensor]) -> None: + """Log lightweight tensor metadata around transform breadcrumbs.""" + if debug is None: + return + + if tensor is None: + debug.log(f"SeedVR2 breadcrumb: {label}: tensor=None", category="setup", force=True) + return + + shape = tuple(tensor.shape) + dtype = tensor.dtype + device = tensor.device + approx_mib = (tensor.numel() * tensor.element_size()) / (1024 * 1024) + debug.log( + f"SeedVR2 breadcrumb: {label}: shape={shape}, dtype={dtype}, device={device}, approx_mib={approx_mib:.2f}", + category="setup", + force=True, + ) + + +def _resize_sample_frame_for_target_dims( + sample_frame: torch.Tensor, + resolution: int, + max_resolution: int = 0, + debug: Optional['Debug'] = None, +) -> torch.Tensor: + """Apply the target-dimension probe transform with step-level breadcrumbs.""" + resize = NaResize( + resolution=resolution, + mode="side", + downsample_only=False, + max_resolution=max_resolution, + ) + + _log_tensor_debug_state(debug, "temp_transform input", sample_frame) + debug.log("SeedVR2 breadcrumb: before NaResize(sample_frame)", category="setup", force=True) if debug else None + resized_sample = resize(sample_frame) + debug.log("SeedVR2 breadcrumb: after NaResize(sample_frame)", category="setup", force=True) if debug else None + _log_tensor_debug_state(debug, "temp_transform after NaResize", resized_sample) + + debug.log("SeedVR2 breadcrumb: before clamp(sample_frame)", category="setup", force=True) if debug else None + resized_sample = torch.clamp(resized_sample, 0.0, 1.0) + debug.log("SeedVR2 breadcrumb: after clamp(sample_frame)", category="setup", force=True) if debug else None + _log_tensor_debug_state(debug, "temp_transform after clamp", resized_sample) + + return resized_sample + + def prepare_video_transforms(resolution: int, max_resolution: int = 0, debug: Optional['Debug'] = None) -> Compose: """ Prepare optimized video transformation pipeline @@ -114,12 +162,13 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") return true_h, true_w, padded_h, padded_w if sample_frame is not None: - temp_transform = Compose([ - NaResize(resolution=resolution, mode="side", downsample_only=False, max_resolution=max_resolution), - Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - ]) debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None - resized_sample = temp_transform(sample_frame) + resized_sample = _resize_sample_frame_for_target_dims( + sample_frame, + resolution, + max_resolution, + debug, + ) debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None resized_h, resized_w = resized_sample.shape[-2:] @@ -143,7 +192,7 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", category="setup", indent_level=1) - del temp_transform, resized_sample + del resized_sample return true_h, true_w, padded_h, padded_w elif debug: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") @@ -156,12 +205,13 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: # Compute dimensions if sample frame provided if sample_frame is not None: # Get true target size (after resize, before padding) - temp_transform = Compose([ - NaResize(resolution=resolution, mode="side", downsample_only=False, max_resolution=max_resolution), - Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - ]) debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None - resized_sample = temp_transform(sample_frame) + resized_sample = _resize_sample_frame_for_target_dims( + sample_frame, + resolution, + max_resolution, + debug, + ) debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None resized_h, resized_w = resized_sample.shape[-2:] @@ -184,8 +234,8 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: else: debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", category="setup", indent_level=1) - - del temp_transform, resized_sample + + del resized_sample return true_h, true_w, padded_h, padded_w return 0, 0, 0, 0 From 6dcbf02c6add0019007c5e92c30ef7dda161b9d5 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 02:54:09 +0200 Subject: [PATCH 18/29] Remove repository metadata and documentation --- .codex | 0 src/core/generation_phases.py | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 .codex diff --git a/.codex b/.codex new file mode 100644 index 00000000..e69de29b diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index f70ec979..3804a2ae 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -51,6 +51,7 @@ manage_tensor, manage_model_device, release_tensor_memory, + synchronize_device, release_tensor_collection ) from ..optimization.performance import ( @@ -1257,6 +1258,9 @@ def postprocess_all_batches( sample = sample_thwc.permute(0, 3, 1, 2) # [T, H, W, C] → [T, C, H, W] # Move to VAE device for processing + phase4_probe_enabled = getattr(debug, "enabled", False) + if phase4_probe_enabled: + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before manage_tensor(sample->vae_device)", category="setup", force=True) sample = manage_tensor( tensor=sample, target_device=ctx['vae_device'], @@ -1266,15 +1270,24 @@ def postprocess_all_batches( reason="post-processing", indent_level=1 ) + if phase4_probe_enabled: + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after manage_tensor(sample->vae_device)", category="setup", force=True) # Reconstruct transformed video on-demand for color correction input_video = None if color_correction != "none" and ctx.get('batch_metadata') is not None: if batch_idx < len(ctx['batch_metadata']) and ctx['batch_metadata'][batch_idx] is not None: # Reconstruct transformation + if phase4_probe_enabled: + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before _reconstruct_and_transform_batch", category="setup", force=True) transformed_video = _reconstruct_and_transform_batch(ctx, batch_idx, debug) + if phase4_probe_enabled: + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after _reconstruct_and_transform_batch", category="setup", force=True) + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before optimized_single_video_rearrange", category="setup", force=True) input_video = optimized_single_video_rearrange(transformed_video) del transformed_video + if phase4_probe_enabled: + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after optimized_single_video_rearrange", category="setup", force=True) # For batches after the first with temporal overlap, the overlap frames # were blended in Phase 3 and are not part of this slice. Skip them. @@ -1307,6 +1320,8 @@ def postprocess_all_batches( # Ensure both tensors are on same device (GPU) for color correction if input_video.device != sample.device: + if phase4_probe_enabled: + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before manage_tensor(input_video->sample.device)", category="setup", force=True) input_video = manage_tensor( tensor=input_video, target_device=sample.device, @@ -1315,27 +1330,48 @@ def postprocess_all_batches( reason="color correction", indent_level=1 ) + if phase4_probe_enabled: + debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after manage_tensor(input_video->sample.device)", category="setup", force=True) # Apply selected color correction method debug.start_timer(f"color_correction_{color_correction}") + color_correction_applied = False if color_correction == "lab": debug.log("Applying LAB perceptual color transfer", category="video", force=True, indent_level=1) sample = lab_color_transfer(sample, input_video, debug, luminance_weight=0.8) + color_correction_applied = True elif color_correction == "wavelet_adaptive": debug.log("Applying wavelet with adaptive saturation correction", category="video", force=True, indent_level=1) sample = wavelet_adaptive_color_correction(sample, input_video, debug) + color_correction_applied = True elif color_correction == "wavelet": debug.log("Applying wavelet color reconstruction", category="video", force=True, indent_level=1) sample = wavelet_reconstruction(sample, input_video, debug) + color_correction_applied = True elif color_correction == "hsv": debug.log("Applying HSV hue-conditional saturation matching", category="video", force=True, indent_level=1) sample = hsv_saturation_histogram_match(sample, input_video, debug) + color_correction_applied = True elif color_correction == "adain": debug.log("Applying AdaIN color correction", category="video", force=True, indent_level=1) sample = adaptive_instance_normalization(sample, input_video) + color_correction_applied = True else: debug.log(f"Unknown color correction method: {color_correction}", level="WARNING", category="video", force=True, indent_level=1) + + if phase4_probe_enabled and color_correction_applied and sample.device.type != "cpu": + debug.log( + f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before synchronize_device(after {color_correction})", + category="setup", + force=True, + ) + synchronize_device(sample.device, debug=debug, reason=f"phase4 batch {info_idx+1} after {color_correction}") + debug.log( + f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after synchronize_device(after {color_correction})", + category="setup", + force=True, + ) debug.end_timer(f"color_correction_{color_correction}", f"Color correction ({color_correction})") From a292b2bd0def31acf16931c8f8a78f1e55e7cfd0 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 20:54:04 +0200 Subject: [PATCH 19/29] Remove repository metadata and documentation files --- src/core/generation_utils.py | 75 ++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 24 deletions(-) diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index 1f8583a2..b349f5d0 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -64,32 +64,63 @@ def _log_tensor_debug_state(debug: Optional['Debug'], label: str, tensor: Option ) -def _resize_sample_frame_for_target_dims( +def _compute_side_resize_output_dims( + input_height: int, + input_width: int, + resolution: int, + max_resolution: int = 0, +) -> Tuple[int, int]: + """Compute output dims for the target-dimension probe path. + + Matches SideResize with downsample_only=False and the current max_size + second pass, without executing the resize kernel. + """ + short, long = (input_width, input_height) if input_width <= input_height else (input_height, input_width) + + # Match torchvision Resize(int) behavior used by SideResize: + # shortest edge = resolution, longest edge scaled with floor(int(...)) + resized_short = resolution + resized_long = int(resolution * long / short) + + if input_width <= input_height: + resized_w, resized_h = resized_short, resized_long + else: + resized_w, resized_h = resized_long, resized_short + + # Match SideResize's second-pass max_size handling exactly. + if max_resolution > 0 and max(resized_h, resized_w) > max_resolution: + scale = max_resolution / max(resized_h, resized_w) + resized_h = round(resized_h * scale) + resized_w = round(resized_w * scale) + + return resized_h, resized_w + + +def _compute_sample_frame_target_dims( sample_frame: torch.Tensor, resolution: int, max_resolution: int = 0, debug: Optional['Debug'] = None, -) -> torch.Tensor: - """Apply the target-dimension probe transform with step-level breadcrumbs.""" - resize = NaResize( +) -> Tuple[int, int]: + """Compute probe target dimensions with breadcrumbs but without running a real resize.""" + input_h, input_w = sample_frame.shape[-2:] + + _log_tensor_debug_state(debug, "temp_transform input", sample_frame) + debug.log("SeedVR2 breadcrumb: before compute_target_dims(sample_frame)", category="setup", force=True) if debug else None + resized_h, resized_w = _compute_side_resize_output_dims( + input_height=input_h, + input_width=input_w, resolution=resolution, - mode="side", - downsample_only=False, max_resolution=max_resolution, ) + debug.log( + f"SeedVR2 breadcrumb: computed temp_transform target dims: input=({input_h}, {input_w}) -> resized=({resized_h}, {resized_w})", + category="setup", + force=True, + ) if debug else None + debug.log("SeedVR2 breadcrumb: after compute_target_dims(sample_frame)", category="setup", force=True) if debug else None - _log_tensor_debug_state(debug, "temp_transform input", sample_frame) - debug.log("SeedVR2 breadcrumb: before NaResize(sample_frame)", category="setup", force=True) if debug else None - resized_sample = resize(sample_frame) - debug.log("SeedVR2 breadcrumb: after NaResize(sample_frame)", category="setup", force=True) if debug else None - _log_tensor_debug_state(debug, "temp_transform after NaResize", resized_sample) - - debug.log("SeedVR2 breadcrumb: before clamp(sample_frame)", category="setup", force=True) if debug else None - resized_sample = torch.clamp(resized_sample, 0.0, 1.0) - debug.log("SeedVR2 breadcrumb: after clamp(sample_frame)", category="setup", force=True) if debug else None - _log_tensor_debug_state(debug, "temp_transform after clamp", resized_sample) - - return resized_sample + return resized_h, resized_w def prepare_video_transforms(resolution: int, max_resolution: int = 0, debug: Optional['Debug'] = None) -> Compose: @@ -163,14 +194,13 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: return true_h, true_w, padded_h, padded_w if sample_frame is not None: debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None - resized_sample = _resize_sample_frame_for_target_dims( + resized_h, resized_w = _compute_sample_frame_target_dims( sample_frame, resolution, max_resolution, debug, ) debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None - resized_h, resized_w = resized_sample.shape[-2:] # Round to even numbers for video codec compatibility (libx264 requirement) true_h = (resized_h // 2) * 2 @@ -192,7 +222,6 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", category="setup", indent_level=1) - del resized_sample return true_h, true_w, padded_h, padded_w elif debug: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") @@ -206,14 +235,13 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: if sample_frame is not None: # Get true target size (after resize, before padding) debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None - resized_sample = _resize_sample_frame_for_target_dims( + resized_h, resized_w = _compute_sample_frame_target_dims( sample_frame, resolution, max_resolution, debug, ) debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None - resized_h, resized_w = resized_sample.shape[-2:] # Round to even numbers for video codec compatibility (libx264 requirement) true_h = (resized_h // 2) * 2 @@ -235,7 +263,6 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log(f"Target dimensions: {true_w}x{true_h} (padded to {padded_w}x{padded_h} for processing)", category="setup", indent_level=1) - del resized_sample return true_h, true_w, padded_h, padded_w return 0, 0, 0, 0 From 0b0dbbd28f2a57f9305b4601958e9bb20cde04f0 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 21:02:35 +0200 Subject: [PATCH 20/29] Remove repository metadata and documentation files --- src/core/generation_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index b349f5d0..9b269752 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -193,14 +193,14 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") return true_h, true_w, padded_h, padded_w if sample_frame is not None: - debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None + debug.log("SeedVR2 breadcrumb: before compute_target_dims(sample_frame)", category="setup", force=True) if debug else None resized_h, resized_w = _compute_sample_frame_target_dims( sample_frame, resolution, max_resolution, debug, ) - debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None + debug.log("SeedVR2 breadcrumb: after compute_target_dims(sample_frame)", category="setup", force=True) if debug else None # Round to even numbers for video codec compatibility (libx264 requirement) true_h = (resized_h // 2) * 2 @@ -234,14 +234,14 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: # Compute dimensions if sample frame provided if sample_frame is not None: # Get true target size (after resize, before padding) - debug.log("SeedVR2 breadcrumb: before temp_transform(sample_frame)", category="setup", force=True) if debug else None + debug.log("SeedVR2 breadcrumb: before compute_target_dims(sample_frame)", category="setup", force=True) if debug else None resized_h, resized_w = _compute_sample_frame_target_dims( sample_frame, resolution, max_resolution, debug, ) - debug.log("SeedVR2 breadcrumb: after temp_transform(sample_frame)", category="setup", force=True) if debug else None + debug.log("SeedVR2 breadcrumb: after compute_target_dims(sample_frame)", category="setup", force=True) if debug else None # Round to even numbers for video codec compatibility (libx264 requirement) true_h = (resized_h // 2) * 2 From ae690052a7843ca0a380b543914927fdb47d29a7 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 21:17:00 +0200 Subject: [PATCH 21/29] Remove repository metadata and documentation files --- src/core/generation_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index 9b269752..01249323 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -29,9 +29,12 @@ import os import torch -from typing import Dict, List, Optional, Tuple, Any, Callable, Union +from typing import Dict, List, Optional, Tuple, Any, Callable, Union, TYPE_CHECKING from torchvision.transforms import Compose, Lambda, Normalize +if TYPE_CHECKING: + from ..utils.debug import Debug + from .model_configuration import configure_runner from .infer import VideoDiffusionInfer from ..data.image.transforms.divisible_crop import DivisiblePad From ca7f1a2d21a1a65ee9e452188589fcf0ed92c74c Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 21:37:18 +0200 Subject: [PATCH 22/29] Remove repository metadata and documentation files --- src/core/generation_phases.py | 46 ++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index 3804a2ae..b02640b6 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -294,6 +294,7 @@ def encode_all_batches( ctx['batch_metadata'] = [None] * num_encode_batches encode_idx = 0 + validated_computed_target_dims = False try: vae_needs_reactivation = runner.vae is not None and is_model_cache_cold(runner.vae) @@ -424,6 +425,33 @@ def encode_all_batches( transformed_video = ctx['video_transform'](rgb_video) + if ( + getattr(debug, "enabled", False) + and not validated_computed_target_dims + and 'padded_target_dims' in ctx + and 'true_target_dims' in ctx + ): + actual_padded_h, actual_padded_w = transformed_video.shape[-2:] + expected_padded_h, expected_padded_w = ctx['padded_target_dims'] + expected_true_h, expected_true_w = ctx['true_target_dims'] + + if (actual_padded_h, actual_padded_w) != (expected_padded_h, expected_padded_w): + msg = ( + "Computed target dims mismatch: " + f"expected padded {expected_padded_w}x{expected_padded_h}, " + f"actual transform output {actual_padded_w}x{actual_padded_h}, " + f"cached true target {expected_true_w}x{expected_true_h}" + ) + debug.log(msg, level="ERROR", category="setup", force=True) + raise RuntimeError(msg) + + debug.log( + f"Validated computed target dims against actual transform output: padded {expected_padded_w}x{expected_padded_h}, true {expected_true_w}x{expected_true_h}", + category="setup", + force=True, + ) + validated_computed_target_dims = True + # Apply input noise if requested (to reduce artifacts at high resolutions) if input_noise_scale > 0: debug.log(f"Applying input noise (scale: {input_noise_scale:.2f})", category="video", indent_level=1) @@ -1302,7 +1330,23 @@ def postprocess_all_batches( # Trim spatial dimensions to true target size if 'true_target_dims' in ctx: true_h, true_w = ctx['true_target_dims'] - if input_video.shape[-2] != true_h or input_video.shape[-1] != true_w: + current_h, current_w = input_video.shape[-2:] + if current_h != true_h or current_w != true_w: + if current_h < true_h or current_w < true_w: + msg = ( + "Reconstructed input spatial dims smaller than expected true target dims: " + f"{current_w}x{current_h} < {true_w}x{true_h}" + ) + if debug: + debug.log(msg, level="ERROR", category="video", force=True) + raise RuntimeError(msg) + + if debug: + debug.log( + f"Trimming reconstructed input spatial padding: {current_w}x{current_h} → {true_w}x{true_h}", + category="video", + indent_level=1, + ) input_video = input_video[:, :, :true_h, :true_w] # Apply color correction if enabled (RGB only) From ee506adb183047c5e6fec785e6995eeb00636a2e Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 21:47:02 +0200 Subject: [PATCH 23/29] Remove repository metadata and documentation files --- src/core/generation_phases.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index b02640b6..466adda8 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -426,8 +426,7 @@ def encode_all_batches( transformed_video = ctx['video_transform'](rgb_video) if ( - getattr(debug, "enabled", False) - and not validated_computed_target_dims + not validated_computed_target_dims and 'padded_target_dims' in ctx and 'true_target_dims' in ctx ): @@ -442,14 +441,16 @@ def encode_all_batches( f"actual transform output {actual_padded_w}x{actual_padded_h}, " f"cached true target {expected_true_w}x{expected_true_h}" ) - debug.log(msg, level="ERROR", category="setup", force=True) + if debug is not None: + debug.log(msg, level="ERROR", category="setup", force=True) raise RuntimeError(msg) - debug.log( - f"Validated computed target dims against actual transform output: padded {expected_padded_w}x{expected_padded_h}, true {expected_true_w}x{expected_true_h}", - category="setup", - force=True, - ) + if getattr(debug, "enabled", False): + debug.log( + f"Validated computed target dims against actual transform output: padded {expected_padded_w}x{expected_padded_h}, true {expected_true_w}x{expected_true_h}", + category="setup", + force=True, + ) validated_computed_target_dims = True # Apply input noise if requested (to reduce artifacts at high resolutions) From f5c46f403a600c977794646ab547ea6eb80ccb61 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 23:38:08 +0200 Subject: [PATCH 24/29] Align Phase 4 reconstructed input transform with Phase 1 device path --- src/core/generation_phases.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index 466adda8..c00c4a21 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -143,8 +143,15 @@ def _reconstruct_and_transform_batch( Transformed video in CTHW format, ready for color correction """ start_idx, end_idx, uniform_padding = ctx['batch_metadata'][batch_idx] + phase4_probe_enabled = getattr(debug, "enabled", False) # Prepare video batch + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _prepare_video_batch", + category="setup", + force=True, + ) video = _prepare_video_batch( images=ctx['input_images'], start_idx=start_idx, @@ -153,9 +160,50 @@ def _reconstruct_and_transform_batch( debug=None, log_info=False ) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _prepare_video_batch", + category="setup", + force=True, + ) + + # Mirror Phase 1 ordering: move to the VAE device before padding and transform. + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before manage_tensor(video->vae_device)", + category="setup", + force=True, + ) + video = manage_tensor( + tensor=video, + target_device=ctx['vae_device'], + tensor_name=f"reconstructed_video_batch_{batch_idx+1}", + dtype=ctx['compute_dtype'], + debug=debug, + reason="Phase 4 input reconstruction", + indent_level=1, + ) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after manage_tensor(video->vae_device)", + category="setup", + force=True, + ) # Apply 4n+1 padding using shared helper + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _apply_4n1_padding", + category="setup", + force=True, + ) video = _apply_4n1_padding(video) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _apply_4n1_padding", + category="setup", + force=True, + ) # Extract RGB and transform if ctx.get('is_rgba', False): @@ -163,7 +211,19 @@ def _reconstruct_and_transform_batch( else: rgb_video = video + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before video_transform(rgb_video)", + category="setup", + force=True, + ) transformed_video = ctx['video_transform'](rgb_video) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after video_transform(rgb_video)", + category="setup", + force=True, + ) del video From 29aae4a5e0c52b3bb8447551f49750ad9489c29f Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 23:42:08 +0200 Subject: [PATCH 25/29] Revert "Align Phase 4 reconstructed input transform with Phase 1 device path" This reverts commit f5c46f403a600c977794646ab547ea6eb80ccb61. --- src/core/generation_phases.py | 60 ----------------------------------- 1 file changed, 60 deletions(-) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index c00c4a21..466adda8 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -143,15 +143,8 @@ def _reconstruct_and_transform_batch( Transformed video in CTHW format, ready for color correction """ start_idx, end_idx, uniform_padding = ctx['batch_metadata'][batch_idx] - phase4_probe_enabled = getattr(debug, "enabled", False) # Prepare video batch - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _prepare_video_batch", - category="setup", - force=True, - ) video = _prepare_video_batch( images=ctx['input_images'], start_idx=start_idx, @@ -160,50 +153,9 @@ def _reconstruct_and_transform_batch( debug=None, log_info=False ) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _prepare_video_batch", - category="setup", - force=True, - ) - - # Mirror Phase 1 ordering: move to the VAE device before padding and transform. - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before manage_tensor(video->vae_device)", - category="setup", - force=True, - ) - video = manage_tensor( - tensor=video, - target_device=ctx['vae_device'], - tensor_name=f"reconstructed_video_batch_{batch_idx+1}", - dtype=ctx['compute_dtype'], - debug=debug, - reason="Phase 4 input reconstruction", - indent_level=1, - ) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after manage_tensor(video->vae_device)", - category="setup", - force=True, - ) # Apply 4n+1 padding using shared helper - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _apply_4n1_padding", - category="setup", - force=True, - ) video = _apply_4n1_padding(video) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _apply_4n1_padding", - category="setup", - force=True, - ) # Extract RGB and transform if ctx.get('is_rgba', False): @@ -211,19 +163,7 @@ def _reconstruct_and_transform_batch( else: rgb_video = video - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before video_transform(rgb_video)", - category="setup", - force=True, - ) transformed_video = ctx['video_transform'](rgb_video) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after video_transform(rgb_video)", - category="setup", - force=True, - ) del video From 5a043be8b5d5dc918aaf18a660027158f8d72557 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 6 Apr 2026 23:51:44 +0200 Subject: [PATCH 26/29] Revert "Revert "Align Phase 4 reconstructed input transform with Phase 1 device path"" This reverts commit 29aae4a5e0c52b3bb8447551f49750ad9489c29f. --- src/core/generation_phases.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index 466adda8..c00c4a21 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -143,8 +143,15 @@ def _reconstruct_and_transform_batch( Transformed video in CTHW format, ready for color correction """ start_idx, end_idx, uniform_padding = ctx['batch_metadata'][batch_idx] + phase4_probe_enabled = getattr(debug, "enabled", False) # Prepare video batch + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _prepare_video_batch", + category="setup", + force=True, + ) video = _prepare_video_batch( images=ctx['input_images'], start_idx=start_idx, @@ -153,9 +160,50 @@ def _reconstruct_and_transform_batch( debug=None, log_info=False ) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _prepare_video_batch", + category="setup", + force=True, + ) + + # Mirror Phase 1 ordering: move to the VAE device before padding and transform. + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before manage_tensor(video->vae_device)", + category="setup", + force=True, + ) + video = manage_tensor( + tensor=video, + target_device=ctx['vae_device'], + tensor_name=f"reconstructed_video_batch_{batch_idx+1}", + dtype=ctx['compute_dtype'], + debug=debug, + reason="Phase 4 input reconstruction", + indent_level=1, + ) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after manage_tensor(video->vae_device)", + category="setup", + force=True, + ) # Apply 4n+1 padding using shared helper + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _apply_4n1_padding", + category="setup", + force=True, + ) video = _apply_4n1_padding(video) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _apply_4n1_padding", + category="setup", + force=True, + ) # Extract RGB and transform if ctx.get('is_rgba', False): @@ -163,7 +211,19 @@ def _reconstruct_and_transform_batch( else: rgb_video = video + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before video_transform(rgb_video)", + category="setup", + force=True, + ) transformed_video = ctx['video_transform'](rgb_video) + if phase4_probe_enabled: + debug.log( + f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after video_transform(rgb_video)", + category="setup", + force=True, + ) del video From 431ca7b280d7dbbacbe38f169be01a85867f6252 Mon Sep 17 00:00:00 2001 From: xmarre Date: Thu, 16 Apr 2026 15:33:27 +0200 Subject: [PATCH 27/29] Add DiT tiling GUI patch wiring --- src/core/generation_utils.py | 9 ++ src/core/infer.py | 193 ++++++++++++++++++++++++++--- src/core/model_configuration.py | 17 ++- src/interfaces/dit_model_loader.py | 46 ++++++- src/interfaces/video_upscaler.py | 6 + 5 files changed, 252 insertions(+), 19 deletions(-) diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index 01249323..e02b8728 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -556,6 +556,9 @@ def prepare_runner( decode_tile_size: Optional[Tuple[int, int]] = None, decode_tile_overlap: Optional[Tuple[int, int]] = None, tile_debug: str = "false", + dit_tiled: bool = False, + dit_tile_size: Optional[Tuple[int, int]] = None, + dit_tile_overlap: Optional[Tuple[int, int]] = None, attention_mode: str = 'sdpa', torch_compile_args_dit: Optional[Dict[str, Any]] = None, torch_compile_args_vae: Optional[Dict[str, Any]] = None @@ -582,6 +585,9 @@ def prepare_runner( decode_tile_size: Tile size for decoding (height, width) decode_tile_overlap: Tile overlap for decoding (height, width) tile_debug: Tile visualization mode (false/encode/decode) + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size (height, width) in latent-space pixels + dit_tile_overlap: Spatial overlap (height, width) between DiT tiles in latent-space pixels attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args_dit: Optional torch.compile configuration for DiT model torch_compile_args_vae: Optional torch.compile configuration for VAE model @@ -626,6 +632,9 @@ def prepare_runner( decode_tile_size=decode_tile_size, decode_tile_overlap=decode_tile_overlap, tile_debug=tile_debug, + dit_tiled=dit_tiled, + dit_tile_size=dit_tile_size, + dit_tile_overlap=dit_tile_overlap, attention_mode=attention_mode, torch_compile_args_dit=torch_compile_args_dit, torch_compile_args_vae=torch_compile_args_vae diff --git a/src/core/infer.py b/src/core/infer.py index a0869cae..0431b05b 100644 --- a/src/core/infer.py +++ b/src/core/infer.py @@ -39,7 +39,9 @@ def __init__(self, config: DictConfig, debug: 'Debug', encode_tile_overlap: Tuple[int, int] = (64, 64), decode_tiled: bool = False, decode_tile_size: Tuple[int, int] = (512, 512), decode_tile_overlap: Tuple[int, int] = (64, 64), - tile_debug: str = "false"): + tile_debug: str = "false", + dit_tiled: bool = False, dit_tile_size: Tuple[int, int] = (128, 128), + dit_tile_overlap: Tuple[int, int] = (16, 16)): self.config = config self.debug = debug # Store separate encode and decode tiling parameters @@ -50,6 +52,9 @@ def __init__(self, config: DictConfig, debug: 'Debug', self.decode_tile_size = decode_tile_size self.decode_tile_overlap = decode_tile_overlap self.tile_debug = tile_debug + self.dit_tiled = dit_tiled + self.dit_tile_size = dit_tile_size + self.dit_tile_overlap = dit_tile_overlap def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor: t, h, w, c = latent.shape @@ -310,9 +315,68 @@ def get_lin_function(x1, y1, x2, y2): timesteps = timesteps * self.schedule.T return timesteps - - @torch.no_grad() - def inference( + @staticmethod + def _tile_axis_starts(length: int, tile: int, overlap: int) -> List[int]: + if length <= tile: + return [0] + + stride = max(1, tile - overlap) + starts: List[int] = [] + start = 0 + while True: + starts.append(start) + if start + tile >= length: + break + next_start = min(start + stride, length - tile) + if next_start <= start: + break + start = next_start + return starts + + @staticmethod + def _tile_blend_vector( + length: int, + overlap: int, + is_start_edge: bool, + is_end_edge: bool, + device: torch.device, + dtype: torch.dtype, + ) -> Tensor: + weight = torch.ones((length,), device=device, dtype=dtype) + if overlap <= 0 or length <= 1: + return weight + + ramp_extent = min(overlap, length - 1) + if ramp_extent <= 0: + return weight + + ramp = torch.linspace(1.0 / (ramp_extent + 1), 1.0, steps=ramp_extent, device=device, dtype=dtype) + if not is_start_edge: + weight[:ramp_extent] = ramp + if not is_end_edge: + weight[-ramp_extent:] = torch.minimum(weight[-ramp_extent:], torch.flip(ramp, dims=[0])) + return weight + + def _dit_blend_mask( + self, + tile_h: int, + tile_w: int, + y0: int, + y1: int, + x0: int, + x1: int, + full_h: int, + full_w: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tensor: + overlap_h = max(0, min(self.dit_tile_overlap[0], tile_h - 1)) + overlap_w = max(0, min(self.dit_tile_overlap[1], tile_w - 1)) + weight_y = self._tile_blend_vector(tile_h, overlap_h, y0 == 0, y1 >= full_h, device, dtype) + weight_x = self._tile_blend_vector(tile_w, overlap_w, x0 == 0, x1 >= full_w, device, dtype) + return (weight_y[:, None] * weight_x[None, :]).view(1, tile_h, tile_w, 1) + + def _inference_flat( self, noises: List[Tensor], conditions: List[Tensor], @@ -323,15 +387,12 @@ def inference( assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg) batch_size = len(noises) - # Return if empty. if batch_size == 0: return [] - - # Set cfg scale + if cfg_scale is None: cfg_scale = self.config.diffusion.cfg.scale - - # Text embeddings. + assert type(texts_pos[0]) is type(texts_neg[0]) if isinstance(texts_pos[0], str): text_pos_embeds, text_pos_shapes = self.text_encode(texts_pos) @@ -350,11 +411,10 @@ def inference( else: text_pos_embeds, text_pos_shapes = na.flatten(texts_pos) text_neg_embeds, text_neg_shapes = na.flatten(texts_neg) - - # Flatten. + latents, latents_shapes = na.flatten(noises) latents_cond, _ = na.flatten(conditions) - + latents = self.sampler.sample( x=latents, f=lambda args: classifier_free_guidance_dispatcher( @@ -384,12 +444,115 @@ def inference( latents = na.unflatten(latents, latents_shapes) - # Clean up temporary tensors del latents_cond del latents_shapes del text_pos_embeds del text_neg_embeds del text_pos_shapes del text_neg_shapes - - return latents \ No newline at end of file + + return latents + + def _inference_tiled_single( + self, + noise: Tensor, + condition: Tensor, + texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + cfg_scale: Optional[float] = None, + ) -> Tensor: + if noise.ndim != 4 or condition.ndim != 4: + return self._inference_flat([noise], [condition], texts_pos, texts_neg, cfg_scale=cfg_scale)[0] + + _, full_h, full_w, _ = noise.shape + tile_h = max(1, min(self.dit_tile_size[0], full_h)) + tile_w = max(1, min(self.dit_tile_size[1], full_w)) + + if full_h <= tile_h and full_w <= tile_w: + return self._inference_flat([noise], [condition], texts_pos, texts_neg, cfg_scale=cfg_scale)[0] + + overlap_h = max(0, min(self.dit_tile_overlap[0], tile_h - 1)) + overlap_w = max(0, min(self.dit_tile_overlap[1], tile_w - 1)) + y_starts = self._tile_axis_starts(full_h, tile_h, overlap_h) + x_starts = self._tile_axis_starts(full_w, tile_w, overlap_w) + tile_count = len(y_starts) * len(x_starts) + + if self.debug is not None: + self.debug.log( + f"Using DiT tiled inference ({tile_count} tiles, size {tile_h}x{tile_w}, overlap {overlap_h}x{overlap_w})", + category="dit", + force=True, + indent_level=1, + ) + + result = None + weight_sum = None + tile_index = 0 + + for y0 in y_starts: + y1 = min(y0 + tile_h, full_h) + for x0 in x_starts: + x1 = min(x0 + tile_w, full_w) + tile_index += 1 + if self.debug is not None and (tile_index == 1 or tile_index == tile_count or tile_index % 4 == 0): + self.debug.log( + f"DiT tile {tile_index}/{tile_count}: y={y0}:{y1}, x={x0}:{x1}", + category="dit", + indent_level=2, + ) + + noise_tile = noise[:, y0:y1, x0:x1, :] + condition_tile = condition[:, y0:y1, x0:x1, :] + tile_result = self._inference_flat([noise_tile], [condition_tile], texts_pos, texts_neg, cfg_scale=cfg_scale)[0] + + if result is None: + result = torch.zeros_like(noise) + weight_sum = torch.zeros((*noise.shape[:-1], 1), device=noise.device, dtype=tile_result.dtype) + + blend_mask = self._dit_blend_mask( + tile_h=y1 - y0, + tile_w=x1 - x0, + y0=y0, + y1=y1, + x0=x0, + x1=x1, + full_h=full_h, + full_w=full_w, + device=tile_result.device, + dtype=tile_result.dtype, + ) + + result[:, y0:y1, x0:x1, :] += tile_result * blend_mask + weight_sum[:, y0:y1, x0:x1, :] += blend_mask + + del noise_tile, condition_tile, tile_result, blend_mask + + weight_sum = torch.clamp(weight_sum, min=torch.finfo(weight_sum.dtype).eps) + result = result / weight_sum + del weight_sum + return result + + @torch.no_grad() + def inference( + self, + noises: List[Tensor], + conditions: List[Tensor], + texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + cfg_scale: Optional[float] = None, + ) -> List[Tensor]: + if len(noises) == 0: + return [] + + if not self.dit_tiled or len(noises) != 1: + return self._inference_flat(noises, conditions, texts_pos, texts_neg, cfg_scale=cfg_scale) + + return [ + self._inference_tiled_single( + noise=noises[0], + condition=conditions[0], + texts_pos=texts_pos, + texts_neg=texts_neg, + cfg_scale=cfg_scale, + ) + ] diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index 02ce2abc..c81d4381 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -857,6 +857,9 @@ def configure_runner( decode_tile_size: Optional[Tuple[int, int]] = None, decode_tile_overlap: Optional[Tuple[int, int]] = None, tile_debug: str = "false", + dit_tiled: bool = False, + dit_tile_size: Optional[Tuple[int, int]] = None, + dit_tile_overlap: Optional[Tuple[int, int]] = None, attention_mode: str = 'sdpa', torch_compile_args_dit: Optional[Dict[str, Any]] = None, torch_compile_args_vae: Optional[Dict[str, Any]] = None @@ -884,6 +887,9 @@ def configure_runner( decode_tile_size: Tile size for decoding (height, width) decode_tile_overlap: Tile overlap for decoding (height, width) tile_debug: Tile visualization mode (false/encode/decode) + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size (height, width) in latent-space pixels + dit_tile_overlap: Spatial overlap (height, width) between DiT tiles in latent-space pixels attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args_dit: Optional torch.compile configuration for DiT model torch_compile_args_vae: Optional torch.compile configuration for VAE model @@ -934,7 +940,7 @@ def configure_runner( cache_context.get('vae_id') if vae_cache else None, encode_tiled, encode_tile_size, encode_tile_overlap, decode_tiled, decode_tile_size, decode_tile_overlap, - tile_debug, attention_mode, + tile_debug, dit_tiled, dit_tile_size, dit_tile_overlap, attention_mode, torch_compile_args_dit, torch_compile_args_vae, block_swap_config, debug ) @@ -1062,6 +1068,9 @@ def _configure_runner_settings( decode_tile_size: Optional[Tuple[int, int]], decode_tile_overlap: Optional[Tuple[int, int]], tile_debug: str, + dit_tiled: bool, + dit_tile_size: Optional[Tuple[int, int]], + dit_tile_overlap: Optional[Tuple[int, int]], attention_mode: str, torch_compile_args_dit: Optional[Dict[str, Any]], torch_compile_args_vae: Optional[Dict[str, Any]], @@ -1086,6 +1095,9 @@ def _configure_runner_settings( decode_tile_size: Tile dimensions (height, width) for decoding in pixels decode_tile_overlap: Overlap dimensions (height, width) between decoding tiles tile_debug: Tile visualization mode (false/encode/decode) + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size (height, width) in latent-space pixels + dit_tile_overlap: Spatial overlap (height, width) between DiT tiles in latent-space pixels attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args_dit: torch.compile configuration for DiT model or None torch_compile_args_vae: torch.compile configuration for VAE model or None @@ -1100,6 +1112,9 @@ def _configure_runner_settings( runner.decode_tile_size = decode_tile_size runner.decode_tile_overlap = decode_tile_overlap runner.tile_debug = tile_debug + runner.dit_tiled = dit_tiled + runner.dit_tile_size = dit_tile_size + runner.dit_tile_overlap = dit_tile_overlap # Store the new configs temporarily for later comparison # Don't set them as attributes yet - let the update functions handle that diff --git a/src/interfaces/dit_model_loader.py b/src/interfaces/dit_model_loader.py index 9d8204e8..2ca5a441 100644 --- a/src/interfaces/dit_model_loader.py +++ b/src/interfaces/dit_model_loader.py @@ -5,7 +5,7 @@ from comfy_api.latest import io from comfy_execution.utils import get_executing_context -from typing import Dict, Any, Tuple +from typing import Dict, Any from ..utils.model_registry import get_available_dit_models, DEFAULT_DIT from ..optimization.memory_manager import get_device_list @@ -124,6 +124,39 @@ def define_schema(cls) -> io.Schema: "Provides 20-40% speedup with compatible PyTorch 2.0+ and Triton installation." ) ), + io.Boolean.Input("dit_tiled", + default=False, + optional=True, + tooltip=( + "Enable spatial tiling for the DiT upscaling phase.\n" + "Reduces peak VRAM during final SeedVR2 diffusion inference by processing latent tiles with overlap blending.\n" + "Slower than full-frame DiT inference, but can prevent VRAM overflow on large crops." + ) + ), + io.Int.Input("dit_tile_size", + default=128, + min=32, + max=2048, + step=8, + optional=True, + tooltip=( + "Spatial tile size for DiT inference in latent-space pixels (default: 128).\n" + "Smaller tiles reduce VRAM further but increase runtime and may reduce global consistency.\n" + "Only used when dit_tiled is enabled." + ) + ), + io.Int.Input("dit_tile_overlap", + default=16, + min=0, + max=512, + step=1, + optional=True, + tooltip=( + "Overlap between DiT latent tiles in pixels (default: 16).\n" + "Higher overlap reduces visible seams but increases compute.\n" + "Only used when dit_tiled is enabled." + ) + ), ], outputs=[ io.Custom("SEEDVR2_DIT").Output( @@ -136,7 +169,8 @@ def define_schema(cls) -> io.Schema: def execute(cls, model: str, device: str, offload_device: str = "none", cache_model: bool = False, blocks_to_swap: int = 0, swap_io_components: bool = False, attention_mode: str = "sdpa", - torch_compile_args: Dict[str, Any] = None) -> io.NodeOutput: + torch_compile_args: Dict[str, Any] = None, dit_tiled: bool = False, + dit_tile_size: int = 128, dit_tile_overlap: int = 16) -> io.NodeOutput: """ Create DiT model configuration for SeedVR2 main node @@ -149,6 +183,9 @@ def execute(cls, model: str, device: str, offload_device: str = "none", swap_io_components: Whether to offload I/O components (requires offload_device != device) attention_mode: Attention computation backend ('sdpa', 'flash_attn_2', 'flash_attn_3', 'sageattn_2', or 'sageattn_3') torch_compile_args: Optional torch.compile configuration from settings node + dit_tiled: Enable spatial DiT tiling during upscaling + dit_tile_size: Spatial DiT tile size in latent-space pixels + dit_tile_overlap: Spatial overlap between DiT tiles in latent-space pixels Returns: NodeOutput containing configuration dictionary for SeedVR2 main node @@ -174,7 +211,10 @@ def execute(cls, model: str, device: str, offload_device: str = "none", "swap_io_components": swap_io_components, "attention_mode": attention_mode, "torch_compile_args": torch_compile_args, + "dit_tiled": dit_tiled, + "dit_tile_size": dit_tile_size, + "dit_tile_overlap": dit_tile_overlap, "node_id": get_executing_context().node_id, } - return io.NodeOutput(config) \ No newline at end of file + return io.NodeOutput(config) diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 75bf9bb1..73e0b6e9 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -415,6 +415,9 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: # TorchCompile args (optional connection, can be None) dit_torch_compile_args = dit.get("torch_compile_args") + dit_tiled = dit.get("dit_tiled", False) + dit_tile_size = max(1, int(dit.get("dit_tile_size", 128))) + dit_tile_overlap = max(0, int(dit.get("dit_tile_overlap", 16))) vae_torch_compile_args = vae.get("torch_compile_args") # Print header @@ -472,6 +475,9 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: decode_tile_size=(decode_tile_size, decode_tile_size), decode_tile_overlap=(decode_tile_overlap, decode_tile_overlap), tile_debug=tile_debug, + dit_tiled=dit_tiled, + dit_tile_size=(dit_tile_size, dit_tile_size), + dit_tile_overlap=(dit_tile_overlap, dit_tile_overlap), attention_mode=attention_mode, torch_compile_args_dit=dit_torch_compile_args, torch_compile_args_vae=vae_torch_compile_args From 474688a76245bd65f2605041ad5bc41580207b47 Mon Sep 17 00:00:00 2001 From: xmarre Date: Fri, 17 Apr 2026 09:38:14 +0200 Subject: [PATCH 28/29] Remove temporary SeedVR2 breadcrumb tracing --- src/core/generation_phases.py | 76 +------------------------------- src/core/generation_utils.py | 42 +----------------- src/core/model_configuration.py | 2 - src/interfaces/video_upscaler.py | 13 +----- 4 files changed, 3 insertions(+), 130 deletions(-) diff --git a/src/core/generation_phases.py b/src/core/generation_phases.py index c00c4a21..971365b3 100644 --- a/src/core/generation_phases.py +++ b/src/core/generation_phases.py @@ -143,15 +143,8 @@ def _reconstruct_and_transform_batch( Transformed video in CTHW format, ready for color correction """ start_idx, end_idx, uniform_padding = ctx['batch_metadata'][batch_idx] - phase4_probe_enabled = getattr(debug, "enabled", False) - + # Prepare video batch - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _prepare_video_batch", - category="setup", - force=True, - ) video = _prepare_video_batch( images=ctx['input_images'], start_idx=start_idx, @@ -160,20 +153,8 @@ def _reconstruct_and_transform_batch( debug=None, log_info=False ) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _prepare_video_batch", - category="setup", - force=True, - ) # Mirror Phase 1 ordering: move to the VAE device before padding and transform. - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before manage_tensor(video->vae_device)", - category="setup", - force=True, - ) video = manage_tensor( tensor=video, target_device=ctx['vae_device'], @@ -183,27 +164,9 @@ def _reconstruct_and_transform_batch( reason="Phase 4 input reconstruction", indent_level=1, ) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after manage_tensor(video->vae_device)", - category="setup", - force=True, - ) # Apply 4n+1 padding using shared helper - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before _apply_4n1_padding", - category="setup", - force=True, - ) video = _apply_4n1_padding(video) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after _apply_4n1_padding", - category="setup", - force=True, - ) # Extract RGB and transform if ctx.get('is_rgba', False): @@ -211,19 +174,7 @@ def _reconstruct_and_transform_batch( else: rgb_video = video - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} before video_transform(rgb_video)", - category="setup", - force=True, - ) transformed_video = ctx['video_transform'](rgb_video) - if phase4_probe_enabled: - debug.log( - f"SeedVR2 breadcrumb: reconstruct batch {batch_idx+1} after video_transform(rgb_video)", - category="setup", - force=True, - ) del video @@ -1348,8 +1299,6 @@ def postprocess_all_batches( # Move to VAE device for processing phase4_probe_enabled = getattr(debug, "enabled", False) - if phase4_probe_enabled: - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before manage_tensor(sample->vae_device)", category="setup", force=True) sample = manage_tensor( tensor=sample, target_device=ctx['vae_device'], @@ -1359,24 +1308,15 @@ def postprocess_all_batches( reason="post-processing", indent_level=1 ) - if phase4_probe_enabled: - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after manage_tensor(sample->vae_device)", category="setup", force=True) # Reconstruct transformed video on-demand for color correction input_video = None if color_correction != "none" and ctx.get('batch_metadata') is not None: if batch_idx < len(ctx['batch_metadata']) and ctx['batch_metadata'][batch_idx] is not None: # Reconstruct transformation - if phase4_probe_enabled: - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before _reconstruct_and_transform_batch", category="setup", force=True) transformed_video = _reconstruct_and_transform_batch(ctx, batch_idx, debug) - if phase4_probe_enabled: - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after _reconstruct_and_transform_batch", category="setup", force=True) - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before optimized_single_video_rearrange", category="setup", force=True) input_video = optimized_single_video_rearrange(transformed_video) del transformed_video - if phase4_probe_enabled: - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after optimized_single_video_rearrange", category="setup", force=True) # For batches after the first with temporal overlap, the overlap frames # were blended in Phase 3 and are not part of this slice. Skip them. @@ -1425,8 +1365,6 @@ def postprocess_all_batches( # Ensure both tensors are on same device (GPU) for color correction if input_video.device != sample.device: - if phase4_probe_enabled: - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before manage_tensor(input_video->sample.device)", category="setup", force=True) input_video = manage_tensor( tensor=input_video, target_device=sample.device, @@ -1435,8 +1373,6 @@ def postprocess_all_batches( reason="color correction", indent_level=1 ) - if phase4_probe_enabled: - debug.log(f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after manage_tensor(input_video->sample.device)", category="setup", force=True) # Apply selected color correction method debug.start_timer(f"color_correction_{color_correction}") @@ -1466,17 +1402,7 @@ def postprocess_all_batches( debug.log(f"Unknown color correction method: {color_correction}", level="WARNING", category="video", force=True, indent_level=1) if phase4_probe_enabled and color_correction_applied and sample.device.type != "cpu": - debug.log( - f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} before synchronize_device(after {color_correction})", - category="setup", - force=True, - ) synchronize_device(sample.device, debug=debug, reason=f"phase4 batch {info_idx+1} after {color_correction}") - debug.log( - f"SeedVR2 breadcrumb: phase4 batch {info_idx+1} after synchronize_device(after {color_correction})", - category="setup", - force=True, - ) debug.end_timer(f"color_correction_{color_correction}", f"Color correction ({color_correction})") diff --git a/src/core/generation_utils.py b/src/core/generation_utils.py index e02b8728..e7c69c68 100644 --- a/src/core/generation_utils.py +++ b/src/core/generation_utils.py @@ -47,26 +47,6 @@ script_directory = get_script_directory() -def _log_tensor_debug_state(debug: Optional['Debug'], label: str, tensor: Optional[torch.Tensor]) -> None: - """Log lightweight tensor metadata around transform breadcrumbs.""" - if debug is None: - return - - if tensor is None: - debug.log(f"SeedVR2 breadcrumb: {label}: tensor=None", category="setup", force=True) - return - - shape = tuple(tensor.shape) - dtype = tensor.dtype - device = tensor.device - approx_mib = (tensor.numel() * tensor.element_size()) / (1024 * 1024) - debug.log( - f"SeedVR2 breadcrumb: {label}: shape={shape}, dtype={dtype}, device={device}, approx_mib={approx_mib:.2f}", - category="setup", - force=True, - ) - - def _compute_side_resize_output_dims( input_height: int, input_width: int, @@ -105,23 +85,14 @@ def _compute_sample_frame_target_dims( max_resolution: int = 0, debug: Optional['Debug'] = None, ) -> Tuple[int, int]: - """Compute probe target dimensions with breadcrumbs but without running a real resize.""" + """Compute probe target dimensions without running a real resize.""" input_h, input_w = sample_frame.shape[-2:] - - _log_tensor_debug_state(debug, "temp_transform input", sample_frame) - debug.log("SeedVR2 breadcrumb: before compute_target_dims(sample_frame)", category="setup", force=True) if debug else None resized_h, resized_w = _compute_side_resize_output_dims( input_height=input_h, input_width=input_w, resolution=resolution, max_resolution=max_resolution, ) - debug.log( - f"SeedVR2 breadcrumb: computed temp_transform target dims: input=({input_h}, {input_w}) -> resized=({resized_h}, {resized_w})", - category="setup", - force=True, - ) if debug else None - debug.log("SeedVR2 breadcrumb: after compute_target_dims(sample_frame)", category="setup", force=True) if debug else None return resized_h, resized_w @@ -187,7 +158,6 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: existing_transform = ctx.get('video_transform') if existing_transform is not None: - debug.log("SeedVR2 breadcrumb: setup_video_transform using existing transform", category="setup", force=True) if debug else None # Transform exists - return cached dimensions without re-running the pipeline if 'true_target_dims' in ctx and 'padded_target_dims' in ctx: true_h, true_w = ctx['true_target_dims'] @@ -196,14 +166,12 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: debug.log("Reusing pre-initialized video transformation pipeline", category="reuse") return true_h, true_w, padded_h, padded_w if sample_frame is not None: - debug.log("SeedVR2 breadcrumb: before compute_target_dims(sample_frame)", category="setup", force=True) if debug else None resized_h, resized_w = _compute_sample_frame_target_dims( sample_frame, resolution, max_resolution, debug, ) - debug.log("SeedVR2 breadcrumb: after compute_target_dims(sample_frame)", category="setup", force=True) if debug else None # Round to even numbers for video codec compatibility (libx264 requirement) true_h = (resized_h // 2) * 2 @@ -231,20 +199,17 @@ def setup_video_transform(ctx: Dict[str, Any], resolution: int, max_resolution: return 0, 0, 0, 0 # Create transformation pipeline (first time or after cleanup) - debug.log("SeedVR2 breadcrumb: setup_video_transform creating new transform", category="setup", force=True) if debug else None ctx['video_transform'] = prepare_video_transforms(resolution, max_resolution, debug) # Compute dimensions if sample frame provided if sample_frame is not None: # Get true target size (after resize, before padding) - debug.log("SeedVR2 breadcrumb: before compute_target_dims(sample_frame)", category="setup", force=True) if debug else None resized_h, resized_w = _compute_sample_frame_target_dims( sample_frame, resolution, max_resolution, debug, ) - debug.log("SeedVR2 breadcrumb: after compute_target_dims(sample_frame)", category="setup", force=True) if debug else None # Round to even numbers for video codec compatibility (libx264 requirement) true_h = (resized_h // 2) * 2 @@ -309,23 +274,18 @@ def compute_generation_info( channels_info = "RGBA" if images.shape[-1] == 4 else "RGB" # Apply prepending if requested - debug.log("SeedVR2 breadcrumb: compute_generation_info before temporal prepend", category="generation", force=True) if debug else None if prepend_frames > 0: images = pad_video_temporal(images, count=prepend_frames, temporal_dim=0, prepend=True, debug=debug) - debug.log("SeedVR2 breadcrumb: compute_generation_info after temporal prepend", category="generation", force=True) if debug else None # Track total frames after prepending total_frames = len(images) ctx['total_frames'] = total_frames # Setup transform and compute dimensions on final frame count - debug.log("SeedVR2 breadcrumb: compute_generation_info before sample_frame", category="generation", force=True) if debug else None sample_frame = images[0].permute(2, 0, 1).unsqueeze(0) - debug.log("SeedVR2 breadcrumb: compute_generation_info before setup_video_transform", category="generation", force=True) if debug else None true_h, true_w, padded_h, padded_w = setup_video_transform( ctx, resolution, max_resolution, debug, sample_frame ) - debug.log("SeedVR2 breadcrumb: compute_generation_info after setup_video_transform", category="generation", force=True) if debug else None del sample_frame info = { diff --git a/src/core/model_configuration.py b/src/core/model_configuration.py index c81d4381..28c2c9dc 100644 --- a/src/core/model_configuration.py +++ b/src/core/model_configuration.py @@ -946,12 +946,10 @@ def configure_runner( ) # Phase 4: Setup models (load from cache or create new) - debug.log("SeedVR2 breadcrumb: before _setup_models", category="runner", force=True) _setup_models( runner, cache_context, dit_model, vae_model, base_cache_dir, block_swap_config, debug ) - debug.log("SeedVR2 breadcrumb: after _setup_models", category="runner", force=True) except BaseException: _evict_claimed_cached_models(cache_context, runner, debug) if runner is not None and cache_context.get('reusing_runner', False): diff --git a/src/interfaces/video_upscaler.py b/src/interfaces/video_upscaler.py index 73e0b6e9..8592dcc4 100644 --- a/src/interfaces/video_upscaler.py +++ b/src/interfaces/video_upscaler.py @@ -456,7 +456,6 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: ) # Prepare runner with model state management and global cache - debug.log("SeedVR2 breadcrumb: before prepare_runner", category="runner", force=True) runner, cache_context = prepare_runner( dit_model=dit_model, vae_model=vae_model, @@ -482,7 +481,6 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: torch_compile_args_dit=dit_torch_compile_args, torch_compile_args_vae=vae_torch_compile_args ) - debug.log("SeedVR2 breadcrumb: after prepare_runner", category="runner", force=True) runner._seedvr2_execution_active = True runner._seedvr2_runner_tainted = False @@ -497,32 +495,24 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: and cache_context.get('cached_dit') is not None and cache_context.get('cached_vae') is not None ): - debug.log("SeedVR2 breadcrumb: before set_runner", category="runner", force=True) cache_context['global_cache'].set_runner( cache_context.get('dit_id'), cache_context.get('vae_id'), runner, debug, ) - debug.log("SeedVR2 breadcrumb: after set_runner", category="runner", force=True) # Store cache context in ctx for use in generation phases ctx['cache_context'] = cache_context - debug.log("SeedVR2 breadcrumb: before load_text_embeddings", category="dit", force=True) # Preload text embeddings before Phase 1 to avoid sync stall in Phase 2 ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug) - debug.log("SeedVR2 breadcrumb: after load_text_embeddings", category="dit", force=True) + debug.log("Loaded text embeddings for DiT", category="dit") - debug.log("SeedVR2 breadcrumb: before log_memory_state", category="memory", force=True) debug.log_memory_state("After model preparation", show_tensors=False, detailed_tensors=False) - debug.log("SeedVR2 breadcrumb: after log_memory_state", category="memory", force=True) - debug.log("SeedVR2 breadcrumb: before end_timer(model_preparation)", category="runner", force=True) debug.end_timer("model_preparation", "Model preparation", force=True, show_breakdown=True) - debug.log("SeedVR2 breadcrumb: after end_timer(model_preparation)", category="runner", force=True) - debug.log("SeedVR2 breadcrumb: before compute_generation_info", category="generation", force=True) # Compute generation info and log start (handles prepending internally) image, gen_info = compute_generation_info( ctx=ctx, @@ -536,7 +526,6 @@ def cleanup(dit_cache: bool = False, vae_cache: bool = False) -> None: temporal_overlap=temporal_overlap, debug=debug ) - debug.log("SeedVR2 breadcrumb: after compute_generation_info", category="generation", force=True) # Log generation start in consistent format log_generation_start(gen_info, debug) From 1673ddc7d87fb176e4121f5040e76b4e9b8ecba1 Mon Sep 17 00:00:00 2001 From: xmarre <54859656+xmarre@users.noreply.github.com> Date: Mon, 22 Jun 2026 00:09:13 +0200 Subject: [PATCH 29/29] Make BF16 import probe safe --- src/optimization/compatibility.py | 81 ++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 8 deletions(-) diff --git a/src/optimization/compatibility.py b/src/optimization/compatibility.py index c462022b..147b8b55 100644 --- a/src/optimization/compatibility.py +++ b/src/optimization/compatibility.py @@ -681,21 +681,86 @@ def _check_conv3d_memory_bug(): # Bfloat16 CUBLAS support +# +# Original behavior performed a CUDA BF16 matmul probe at import time and +# re-raised most CUDA errors. Under WSL / Blackwell / newer PyTorch builds this +# can fail with ``CUDA driver error: unknown error`` while ComfyUI is merely +# importing custom nodes, which prevents the whole SeedVR2 node pack from +# loading. Import-time CUDA probes must be best-effort only. +_SEEDVR2_BFLOAT16_PATCH = "wsl-safe-import-probe-2026-06-21" + + +def _env_flag(name: str, default: bool = False) -> bool: + value = os.environ.get(name) + if value is None: + return default + return value.strip().lower() in ("1", "true", "yes", "y", "on") + + +def _env_flag_set(name: str): + value = os.environ.get(name) + if value is None: + return None + value = value.strip().lower() + if value in ("1", "true", "yes", "y", "on"): + return True + if value in ("0", "false", "no", "n", "off"): + return False + return None + + def _probe_bfloat16_support() -> bool: - if not torch.cuda.is_available(): + """ + Import-safe BF16 capability selection. + + Defaults to float16 without touching CUDA at import time. This keeps the + custom node importable even if the CUDA context/driver is temporarily in a + bad state during ComfyUI startup. + + Environment controls: + SEEDVR2_FORCE_BFLOAT16=1 -> force BF16 on, no probe + SEEDVR2_FORCE_BFLOAT16=0 -> force BF16 off, no probe + SEEDVR2_IMPORT_BFLOAT16_PROBE=1 -> run best-effort CUDA probe at import + """ + forced = _env_flag_set("SEEDVR2_FORCE_BFLOAT16") + if forced is True: + print("[SeedVR2] BF16 forced on via SEEDVR2_FORCE_BFLOAT16=1", flush=True) return True + if forced is False: + print("[SeedVR2] BF16 forced off via SEEDVR2_FORCE_BFLOAT16=0; using float16", flush=True) + return False + + # Safer default: do not allocate CUDA tensors while ComfyUI is importing + # custom nodes. Users who want automatic probing can opt in explicitly. + if not _env_flag("SEEDVR2_IMPORT_BFLOAT16_PROBE", default=False): + print("[SeedVR2] Import-time BF16 CUDA probe skipped; using float16. Set SEEDVR2_IMPORT_BFLOAT16_PROBE=1 to probe.", flush=True) + return False + try: - a = torch.randn(8, 8, dtype=torch.bfloat16, device='cuda:0') - _ = torch.matmul(a, a) - del a - return True - except RuntimeError as e: - if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e): + if not torch.cuda.is_available(): return False - raise + with torch.no_grad(): + a = torch.empty((8, 8), dtype=torch.bfloat16, device="cuda:0") + b = torch.matmul(a, a) + torch.cuda.synchronize() + del a, b + return True + except BaseException as e: + # Never let an optional import-time feature probe kill the node pack. + try: + print( + f"[SeedVR2] BF16 CUDA import probe failed; disabling BF16 for this session: " + f"{type(e).__name__}: {e}", + flush=True, + ) + except Exception: + pass + return False + BFLOAT16_SUPPORTED = _probe_bfloat16_support() COMPUTE_DTYPE = torch.bfloat16 if BFLOAT16_SUPPORTED else torch.float16 +print(f"[SeedVR2] compute dtype selected at import: {COMPUTE_DTYPE}", flush=True) def call_rope_with_stability(method, *args, **kwargs):