Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
b9a3816
Investigate CUDA cleanup hang
xmarre Mar 15, 2026
3a5411d
Fix CUDA cleanup reuse hangs
xmarre Mar 15, 2026
516da73
Fix runner template cache handling
xmarre Mar 15, 2026
b9f6033
Fix release_model_memory containers
xmarre Mar 15, 2026
8e4c89a
Synchronize runner template access
xmarre Mar 15, 2026
e16c827
Synchronize runner template access
xmarre Mar 15, 2026
1d02d47
Merge pull request #1 from xmarre/codex/fix-cuda-state-cleanup-hang
xmarre Mar 15, 2026
4760017
Address runner cache eviction bug
xmarre Mar 15, 2026
f0376ea
Keep model claim until cache rewrite
xmarre Mar 15, 2026
7d28241
Allow reuse of hot model cache
xmarre Mar 16, 2026
549c065
Fix cached model claim ownership during teardown
xmarre Mar 19, 2026
92d7a6e
Fix atomic runner eviction and claim cleanup
xmarre Mar 19, 2026
8a07d15
Fix cached claim release after finalize refresh
xmarre Mar 20, 2026
4b2c202
Merge pull request #2 from xmarre/codex/apply-patch-instructions
xmarre Mar 20, 2026
08fad5b
Track newly cached models for claim release
xmarre Mar 20, 2026
c6466fa
Add breadcrumbs around SeedVR2 model prep
xmarre Mar 20, 2026
9b572ed
Add breadcrumbs around video transform setup
xmarre Mar 20, 2026
5cd8f96
Fix video transform dim planning without live tensor pass
xmarre Mar 21, 2026
04e00da
Merge pull request #3 from xmarre/codex/stop-full-transform-for-padding
xmarre Mar 21, 2026
a915809
Refactor target-dimension probe and add debug breadcrumbs
xmarre Apr 3, 2026
20627f5
Merge pull request #4 from xmarre/codex/refactor-targetdim-probe
xmarre Apr 3, 2026
6dcbf02
Remove repository metadata and documentation
xmarre Apr 6, 2026
40c491b
Merge pull request #5 from xmarre/codex/investigate-batch-2-stall
xmarre Apr 6, 2026
a292b2b
Remove repository metadata and documentation files
xmarre Apr 6, 2026
0b0dbbd
Remove repository metadata and documentation files
xmarre Apr 6, 2026
ae69005
Remove repository metadata and documentation files
xmarre Apr 6, 2026
ca7f1a2
Remove repository metadata and documentation files
xmarre Apr 6, 2026
ee506ad
Remove repository metadata and documentation files
xmarre Apr 6, 2026
828f5fd
Merge pull request #6 from xmarre/codex/fix-probe-resize-hang
xmarre Apr 6, 2026
f5c46f4
Align Phase 4 reconstructed input transform with Phase 1 device path
xmarre Apr 6, 2026
29aae4a
Revert "Align Phase 4 reconstructed input transform with Phase 1 devi…
xmarre Apr 6, 2026
5a043be
Revert "Revert "Align Phase 4 reconstructed input transform with Phas…
xmarre Apr 6, 2026
aaa98e9
Merge pull request #7 from xmarre/codex/reapply-phase4-device-path
xmarre Apr 6, 2026
431ca7b
Add DiT tiling GUI patch wiring
xmarre Apr 16, 2026
d14078b
Merge pull request #8 from xmarre/codex/seedvr2-dit-tiling-pr
xmarre Apr 16, 2026
474688a
Remove temporary SeedVR2 breadcrumb tracing
xmarre Apr 17, 2026
b352c80
Merge pull request #9 from xmarre/codex/remove-seedvr2-breadcrumb-log…
xmarre Apr 17, 2026
1673ddc
Make BF16 import probe safe
xmarre Jun 21, 2026
b62b5b6
Merge pull request #10 from xmarre/codex/apply-patch-for-bf16-support…
xmarre Jun 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added .codex
Empty file.
326 changes: 227 additions & 99 deletions inference_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,19 @@
decode_all_batches,
postprocess_all_batches
)
from src.core.model_configuration import (
_evict_claimed_cached_models,
_finalize_claimed_cached_models_for_reuse,
)
from src.utils.debug import Debug
from src.optimization.memory_manager import clear_memory, get_gpu_backend, is_cuda_available
from src.optimization.memory_manager import (
cleanup_text_embeddings,
clear_memory,
complete_cleanup,
get_gpu_backend,
is_cuda_available,
set_model_cache_claimed_state,
)
debug = Debug(enabled=False) # Will be enabled via --debug CLI flag


Expand Down Expand Up @@ -913,103 +924,220 @@ def _process_frames_core(
dit_id = "cli_dit" if cache_dit else None
vae_id = "cli_vae" if cache_vae else None

runner, cache_context = prepare_runner(
dit_model=args.dit_model,
vae_model=DEFAULT_VAE,
model_dir=model_dir,
debug=debug,
ctx=ctx,
dit_cache=cache_dit,
vae_cache=cache_vae,
dit_id=dit_id,
vae_id=vae_id,
block_swap_config={
'blocks_to_swap': args.blocks_to_swap,
'swap_io_components': args.swap_io_components,
'offload_device': dit_offload,
},
encode_tiled=args.vae_encode_tiled,
encode_tile_size=(args.vae_encode_tile_size, args.vae_encode_tile_size),
encode_tile_overlap=(args.vae_encode_tile_overlap, args.vae_encode_tile_overlap),
decode_tiled=args.vae_decode_tiled,
decode_tile_size=(args.vae_decode_tile_size, args.vae_decode_tile_size),
decode_tile_overlap=(args.vae_decode_tile_overlap, args.vae_decode_tile_overlap),
tile_debug=args.tile_debug.lower() if args.tile_debug else "false",
attention_mode=args.attention_mode,
torch_compile_args_dit=torch_compile_args_dit,
torch_compile_args_vae=torch_compile_args_vae
)

ctx['cache_context'] = cache_context
if runner_cache is not None:
runner_cache['runner'] = runner

# Preload text embeddings before Phase 1 to avoid sync stall in Phase 2
ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug)
debug.log("Loaded text embeddings for DiT", category="dit")

# Compute generation info and log start (handles prepending internally)
frames_tensor, gen_info = compute_generation_info(
ctx=ctx,
images=frames_tensor,
resolution=args.resolution,
max_resolution=args.max_resolution,
batch_size=args.batch_size,
uniform_batch_size=args.uniform_batch_size,
seed=args.seed,
prepend_frames=args.prepend_frames,
temporal_overlap=args.temporal_overlap,
debug=debug
)
log_generation_start(gen_info, debug)

# Phase 1: Encode
ctx = encode_all_batches(
runner, ctx=ctx, images=frames_tensor,
debug=debug,
batch_size=args.batch_size,
uniform_batch_size=args.uniform_batch_size,
seed=args.seed,
progress_callback=None,
temporal_overlap=args.temporal_overlap,
resolution=args.resolution,
max_resolution=args.max_resolution,
input_noise_scale=args.input_noise_scale,
color_correction=args.color_correction
)

# Phase 2: Upscale
ctx = upscale_all_batches(
runner, ctx=ctx, debug=debug, progress_callback=None,
seed=args.seed,
latent_noise_scale=args.latent_noise_scale,
cache_model=cache_dit
)

# Phase 3: Decode
ctx = decode_all_batches(
runner, ctx=ctx, debug=debug, progress_callback=None,
cache_model=cache_vae
)

# Phase 4: Post-process
ctx = postprocess_all_batches(
ctx=ctx, debug=debug, progress_callback=None,
color_correction=args.color_correction,
prepend_frames=0, # Worker mode handles this in main process
temporal_overlap=args.temporal_overlap,
batch_size=args.batch_size
)

result_tensor = ctx['final_video']

# Convert to CPU and compatible dtype
if result_tensor.is_cuda or result_tensor.is_mps:
result_tensor = result_tensor.cpu()
if result_tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2):
result_tensor = result_tensor.to(torch.float32)

return result_tensor
runner = None
cache_context = None

def cleanup(dit_cache_flag: bool = False, vae_cache_flag: bool = False) -> None:
nonlocal runner, ctx

if runner is not None:
claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None
claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None
refreshed_dit = None
refreshed_vae = None
try:
try:
complete_cleanup(
runner=runner,
debug=debug,
dit_cache=dit_cache_flag,
vae_cache=vae_cache_flag,
)
if dit_cache_flag or vae_cache_flag:
refreshed_dit, refreshed_vae = _finalize_claimed_cached_models_for_reuse(cache_context, runner, debug)
except Exception:
try:
_evict_claimed_cached_models(cache_context, runner, debug)
except Exception as evict_error:
if debug is not None:
debug.log(
f"Failed to evict claimed cached models while handling prior cleanup/finalize exception: {evict_error}",
level="WARNING",
category="cleanup",
force=True,
)
raise
finally:
if dit_cache_flag and claimed_dit is not None:
set_model_cache_claimed_state(claimed_dit, False)
if dit_cache_flag and refreshed_dit is not None and refreshed_dit is not claimed_dit:
set_model_cache_claimed_state(refreshed_dit, False)
if vae_cache_flag and claimed_vae is not None:
set_model_cache_claimed_state(claimed_vae, False)
if vae_cache_flag and refreshed_vae is not None and refreshed_vae is not claimed_vae:
set_model_cache_claimed_state(refreshed_vae, False)
runner._seedvr2_execution_active = False

if not (dit_cache_flag or vae_cache_flag):
runner = None
if runner_cache is not None:
runner_cache.pop('runner', None)

if ctx is not None:
cleanup_text_embeddings(ctx, debug)
if not (dit_cache_flag or vae_cache_flag):
ctx = None
if runner_cache is not None:
runner_cache.pop('ctx', None)

try:
runner, cache_context = prepare_runner(
dit_model=args.dit_model,
vae_model=DEFAULT_VAE,
model_dir=model_dir,
debug=debug,
ctx=ctx,
dit_cache=cache_dit,
vae_cache=cache_vae,
dit_id=dit_id,
vae_id=vae_id,
block_swap_config={
'blocks_to_swap': args.blocks_to_swap,
'swap_io_components': args.swap_io_components,
'offload_device': dit_offload,
},
encode_tiled=args.vae_encode_tiled,
encode_tile_size=(args.vae_encode_tile_size, args.vae_encode_tile_size),
encode_tile_overlap=(args.vae_encode_tile_overlap, args.vae_encode_tile_overlap),
decode_tiled=args.vae_decode_tiled,
decode_tile_size=(args.vae_decode_tile_size, args.vae_decode_tile_size),
decode_tile_overlap=(args.vae_decode_tile_overlap, args.vae_decode_tile_overlap),
tile_debug=args.tile_debug.lower() if args.tile_debug else "false",
attention_mode=args.attention_mode,
torch_compile_args_dit=torch_compile_args_dit,
torch_compile_args_vae=torch_compile_args_vae
)

runner._seedvr2_execution_active = True
runner._seedvr2_runner_tainted = False
runner._seedvr2_dit_phase_cleaned = False
runner._seedvr2_vae_phase_cleaned = False

if (
cache_context is not None
and not cache_context.get('reusing_runner', False)
and cache_context.get('cached_dit') is not None
and cache_context.get('cached_vae') is not None
):
cache_context['global_cache'].set_runner(
cache_context.get('dit_id'),
cache_context.get('vae_id'),
runner,
debug,
)

ctx['cache_context'] = cache_context
if runner_cache is not None:
runner_cache['runner'] = runner

# Preload text embeddings before Phase 1 to avoid sync stall in Phase 2
ctx['text_embeds'] = load_text_embeddings(script_directory, ctx['dit_device'], ctx['compute_dtype'], debug)
debug.log("Loaded text embeddings for DiT", category="dit")

# Compute generation info and log start (handles prepending internally)
frames_tensor, gen_info = compute_generation_info(
ctx=ctx,
images=frames_tensor,
resolution=args.resolution,
max_resolution=args.max_resolution,
batch_size=args.batch_size,
uniform_batch_size=args.uniform_batch_size,
seed=args.seed,
prepend_frames=args.prepend_frames,
temporal_overlap=args.temporal_overlap,
debug=debug
)
log_generation_start(gen_info, debug)

# Phase 1: Encode
ctx = encode_all_batches(
runner, ctx=ctx, images=frames_tensor,
debug=debug,
batch_size=args.batch_size,
uniform_batch_size=args.uniform_batch_size,
seed=args.seed,
progress_callback=None,
temporal_overlap=args.temporal_overlap,
resolution=args.resolution,
max_resolution=args.max_resolution,
input_noise_scale=args.input_noise_scale,
color_correction=args.color_correction
)

# Phase 2: Upscale
ctx = upscale_all_batches(
runner, ctx=ctx, debug=debug, progress_callback=None,
seed=args.seed,
latent_noise_scale=args.latent_noise_scale,
cache_model=cache_dit
)

# Phase 3: Decode
ctx = decode_all_batches(
runner, ctx=ctx, debug=debug, progress_callback=None,
cache_model=cache_vae
)

# Phase 4: Post-process
ctx = postprocess_all_batches(
ctx=ctx, debug=debug, progress_callback=None,
color_correction=args.color_correction,
prepend_frames=0, # Worker mode handles this in main process
temporal_overlap=args.temporal_overlap,
batch_size=args.batch_size
)

result_tensor = ctx['final_video']

# Convert to CPU and compatible dtype
if result_tensor.is_cuda or result_tensor.is_mps:
result_tensor = result_tensor.cpu()
if result_tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2):
result_tensor = result_tensor.to(torch.float32)

cleanup(dit_cache_flag=cache_dit, vae_cache_flag=cache_vae)
return result_tensor
except BaseException:
claimed_dit = cache_context.get('cached_dit') if cache_context is not None else None
claimed_vae = cache_context.get('cached_vae') if cache_context is not None else None
if cache_context is not None:
_evict_claimed_cached_models(cache_context, runner, debug)
if runner is not None and cache_context.get('reusing_runner', False):
try:
cache_context['global_cache'].taint_and_remove_runner(
cache_context.get('dit_id'),
cache_context.get('vae_id'),
debug,
expected_runner=runner,
)
except BaseException as cache_error:
if debug is not None:
debug.log(
f"Failed to evict cached runner while handling prior exception "
f"(runner={id(runner)}, dit_id={cache_context.get('dit_id')}, vae_id={cache_context.get('vae_id')}): {cache_error}",
level="WARNING",
category="cleanup",
force=True,
)

try:
try:
cleanup(dit_cache_flag=False, vae_cache_flag=False)
except BaseException as cleanup_error:
if debug is not None:
debug.log(
f"Cleanup failed while handling prior exception "
f"(runner={id(runner) if runner is not None else 'none'}, dit_id={cache_context.get('dit_id') if cache_context is not None else 'none'}, vae_id={cache_context.get('vae_id') if cache_context is not None else 'none'}): {cleanup_error}",
level="WARNING",
category="cleanup",
force=True,
)
finally:
if claimed_dit is not None:
set_model_cache_claimed_state(claimed_dit, False)
if claimed_vae is not None:
set_model_cache_claimed_state(claimed_vae, False)
raise


def _worker_process(
Expand Down Expand Up @@ -1709,4 +1837,4 @@ def main() -> None:
debug.print_footer()

if __name__ == "__main__":
main()
main()
Loading