@@ -465,10 +465,51 @@ def aotriton_supported(gpu_arch):
logging.info(f"Set vram state to: {vram_state.name}")

# Snapshot CLI flags into module-level constants so the rest of this module
# never has to reach back into the parsed args object.
DISABLE_SMART_MEMORY = args.disable_smart_memory
AGGRESSIVE_OFFLOAD = args.aggressive_offload

if DISABLE_SMART_MEMORY:
    logging.info("Disabling smart memory management")

if AGGRESSIVE_OFFLOAD:
    logging.info("Aggressive offload enabled: models will be freed from RAM after use (designed for Apple Silicon)")
475+
# ---------------------------------------------------------------------------
# Model lifecycle callbacks — on_model_destroyed
# ---------------------------------------------------------------------------
# Why not comfy.hooks? The existing hook system (comfy/hooks.py) is scoped
# to *sampling conditioning* — LoRA weight injection, transformer_options,
# and keyframe scheduling. It has no concept of model-management lifecycle
# events such as "a model's parameters were deallocated".
#
# This lightweight callback list fills that gap. It is intentionally minimal
# (append-only, no priorities, no removal) because the only current consumer
# is the execution-engine cache invalidator registered in PromptExecutor.
# If upstream adopts a formal lifecycle-event bus in the future, these
# callbacks should migrate to that system.
# ---------------------------------------------------------------------------
_on_model_destroyed_callbacks: list = []


def register_model_destroyed_callback(callback):
    """Register a listener for post-destruction lifecycle events.

    After ``free_memory`` moves one or more models to the ``meta`` device
    (aggressive offload), every registered callback is invoked once with a
    *reason* string describing the batch (e.g. ``"batch"``).

    Args:
        callback: ``Callable[[str], None]`` — receives a human-readable
            reason string. Must be safe to call from within the
            ``free_memory`` critical section (no heavy I/O, no model loads).

    Raises:
        TypeError: if ``callback`` is not callable. Validating here fails
            fast at registration time instead of aborting the batch
            notification loop inside ``free_memory`` much later.
    """
    if not callable(callback):
        raise TypeError(f"callback must be callable, got {type(callback).__name__}")
    _on_model_destroyed_callbacks.append(callback)
512+
472513def get_torch_device_name (device ):
473514 if hasattr (device , 'type' ):
474515 if device .type == "cuda" :
@@ -640,14 +681,20 @@ def offloaded_memory(loaded_models, device):
WINDOWS = any(platform.win32_ver())

EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if cpu_state == CPUState.MPS:
    # Apple Silicon unified memory: CPU and GPU share the same physical RAM,
    # so reserve a larger 4 GB cushion for macOS and system services to
    # avoid swap thrashing.
    EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024
    logging.info("MPS detected: reserving 4 GB for macOS system overhead")
elif WINDOWS:
    import comfy.windows
    # Windows needs a bigger reserve because of its shared-VRAM behavior.
    EXTRA_RESERVED_VRAM = 600 * 1024 * 1024
    if total_vram > (15 * 1024):
        # 16 GB+ cards get an extra 100 MB of headroom on top.
        EXTRA_RESERVED_VRAM += 100 * 1024 * 1024

    def get_free_ram():
        return comfy.windows.get_free_ram()

if not WINDOWS:
    # Covers both MPS and generic Linux/macOS: fall back to psutil.
    def get_free_ram():
        return psutil.virtual_memory().available
653700
@@ -669,14 +716,25 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
669716
670717 for i in range (len (current_loaded_models ) - 1 , - 1 , - 1 ):
671718 shift_model = current_loaded_models [i ]
672- if device is None or shift_model .device == device :
719+ # On Apple Silicon SHARED mode, CPU RAM == GPU VRAM (same physical memory).
720+ # Bypass the device filter so CPU-loaded models (like CLIP) can be freed.
721+ device_match = (device is None or shift_model .device == device )
722+ if AGGRESSIVE_OFFLOAD and vram_state == VRAMState .SHARED :
723+ device_match = True
724+ if device_match :
673725 if shift_model not in keep_loaded and not shift_model .is_dead ():
674726 can_unload .append ((- shift_model .model_offloaded_memory (), sys .getrefcount (shift_model .model ), shift_model .model_memory (), i ))
675727 shift_model .currently_used = False
676728
677729 can_unload_sorted = sorted (can_unload )
730+ # Collect models to destroy via meta device AFTER the unload loop completes,
731+ # so we don't kill weakrefs of models still being iterated.
732+ _meta_destroy_queue = []
678733 for x in can_unload_sorted :
679734 i = x [- 1 ]
735+ # Guard: weakref may already be dead from a previous iteration
736+ if current_loaded_models [i ].model is None :
737+ continue
680738 memory_to_free = 1e32
681739 pins_to_free = 1e32
682740 if not DISABLE_SMART_MEMORY or device is None :
@@ -687,15 +745,66 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
687745 #as that works on-demand.
688746 memory_required -= current_loaded_models [i ].model .loaded_size ()
689747 memory_to_free = 0
748+
749+ # Aggressive offload for Apple Silicon: force-unload unused models
750+ # regardless of free memory, since CPU RAM == GPU VRAM.
751+ if AGGRESSIVE_OFFLOAD and vram_state == VRAMState .SHARED :
752+ if not current_loaded_models [i ].currently_used :
753+ memory_to_free = 1e32 # Force unload
754+ model_name = current_loaded_models [i ].model .model .__class__ .__name__
755+ model_size_mb = current_loaded_models [i ].model_memory () / (1024 * 1024 )
756+ logging .info (f"[aggressive-offload] Force-unloading { model_name } ({ model_size_mb :.0f} MB) from shared RAM" )
757+
690758 if memory_to_free > 0 and current_loaded_models [i ].model_unload (memory_to_free ):
691759 logging .debug (f"Unloading { current_loaded_models [i ].model .model .__class__ .__name__ } " )
760+ # Queue for meta device destruction after loop completes.
761+ # Only destroy large models (>1 GB) — small models like the VAE (160 MB)
762+ # are kept because the execution cache may reuse their patcher across
763+ # workflow nodes (e.g. vae_loader is cached while vae_decode runs later).
764+ if AGGRESSIVE_OFFLOAD and vram_state == VRAMState .SHARED :
765+ if current_loaded_models [i ].model is not None :
766+ model_size = current_loaded_models [i ].model_memory ()
767+ if model_size > 1024 * 1024 * 1024 : # Only meta-destroy models > 1 GB
768+ _meta_destroy_queue .append (i )
692769 unloaded_model .append (i )
693770 if pins_to_free > 0 :
694- logging .debug (f"PIN Unloading { current_loaded_models [i ].model .model .__class__ .__name__ } " )
695- current_loaded_models [i ].model .partially_unload_ram (pins_to_free )
771+ if current_loaded_models [i ].model is not None :
772+ logging .debug (f"PIN Unloading { current_loaded_models [i ].model .model .__class__ .__name__ } " )
773+ current_loaded_models [i ].model .partially_unload_ram (pins_to_free )
774+
775+ # --- Phase 2: Deferred meta-device destruction -------------------------
776+ # Move parameters of queued models to the 'meta' device. This replaces
777+ # every nn.Parameter with a zero-storage meta tensor, releasing physical
778+ # RAM on unified-memory systems (Apple Silicon). The operation is
779+ # deferred until *after* the unload loop to avoid invalidating weakrefs
780+ # that other iterations may still reference.
781+ for i in _meta_destroy_queue :
782+ try :
783+ model_ref = current_loaded_models [i ].model
784+ if model_ref is None :
785+ continue
786+ inner_model = model_ref .model
787+ model_name = inner_model .__class__ .__name__
788+ param_count = sum (p .numel () * p .element_size () for p in inner_model .parameters ())
789+ inner_model .to (device = "meta" )
790+ logging .info (f"[aggressive-offload] Moved { model_name } params to meta device, freed { param_count / (1024 ** 2 ):.0f} MB" )
791+ except Exception as e :
792+ logging .warning (f"[aggressive-offload] Failed to move model to meta: { e } " )
793+
794+ # --- Phase 3: Notify lifecycle listeners --------------------------------
795+ # Fire on_model_destroyed callbacks *once* after the entire batch has been
796+ # processed, not per-model. This lets the execution engine clear its
797+ # output cache in a single operation (see PromptExecutor.__init__).
798+ if _meta_destroy_queue and _on_model_destroyed_callbacks :
799+ for cb in _on_model_destroyed_callbacks :
800+ cb ("batch" )
801+ logging .info (f"[aggressive-offload] Invalidated execution cache after destroying { len (_meta_destroy_queue )} model(s)" )
696802
697803 for x in can_unload_sorted :
698804 i = x [- 1 ]
805+ # Guard: weakref may be dead after cache invalidation (meta device move)
806+ if current_loaded_models [i ].model is None :
807+ continue
699808 ram_to_free = ram_required - psutil .virtual_memory ().available
700809 if ram_to_free <= 0 and i not in unloaded_model :
701810 continue
@@ -708,6 +817,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
708817
709818 if len (unloaded_model ) > 0 :
710819 soft_empty_cache ()
820+ if AGGRESSIVE_OFFLOAD :
821+ gc .collect () # Force Python GC to release model tensors
822+ soft_empty_cache () # Second pass to free MPS allocator cache
711823 elif device is not None :
712824 if vram_state != VRAMState .HIGH_VRAM :
713825 mem_free_total , mem_free_torch = get_free_memory (device , torch_free_too = True )
0 commit comments