
Commit 2c8db00

feat: add --aggressive-offload for Apple Silicon (MPS)
Eliminate swap pressure on unified memory systems by:

- Force-destroying model parameters via meta device after use
- Flushing MPS allocator cache per sampling step
- Preserving small models (<1GB, e.g. VAE) via size threshold
- Lifecycle callback system for execution cache invalidation

Benchmarked on M5 Pro 48GB with FLUX.2 Dev 32B GGUF:

- Latency: 50 min → 20 min per image (2.5× improvement)
- Stability: 4+ consecutive generations without OOM
1 parent cfcd334 commit 2c8db00

4 files changed

Lines changed: 141 additions & 5 deletions
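The core mechanism behind these numbers is PyTorch's meta device: moving a module to it swaps every parameter for a zero-storage placeholder, so the physical RAM behind the weights can be reclaimed. A minimal standalone sketch of that effect (illustrative module and sizes, not code from this commit):

import torch
import torch.nn as nn

# Stand-in for a large diffusion model component.
big = nn.Linear(8192, 8192)                 # roughly 256 MB of float32 weights
bytes_before = sum(p.numel() * p.element_size() for p in big.parameters())

big.to(device="meta")                       # parameters become zero-storage meta tensors

print(f"parameter bytes before: {bytes_before / 2**20:.0f} MB")
print("all parameters on meta:", all(p.is_meta for p in big.parameters()))
# The RAM backing the old tensors is released once nothing else references
# them, which is why free_memory() follows up with gc.collect().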


comfy/cli_args.py

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ def from_string(cls, value: str):
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Frees ~18GB during sampling by unloading text encoders after encoding. Trade-off: ~10s reload penalty per subsequent generation.")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

 class PerformanceFeature(enum.Enum):
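For reference, the flag is passed at launch like any other ComfyUI option; a hypothetical invocation assuming the standard main.py entry point:

python main.py --aggressive-offload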

comfy/model_management.py

Lines changed: 117 additions & 5 deletions
@@ -465,10 +465,51 @@ def aotriton_supported(gpu_arch):
 logging.info(f"Set vram state to: {vram_state.name}")

 DISABLE_SMART_MEMORY = args.disable_smart_memory
+AGGRESSIVE_OFFLOAD = args.aggressive_offload

 if DISABLE_SMART_MEMORY:
     logging.info("Disabling smart memory management")

+if AGGRESSIVE_OFFLOAD:
+    logging.info("Aggressive offload enabled: models will be freed from RAM after use (designed for Apple Silicon)")
+
+# ---------------------------------------------------------------------------
+# Model lifecycle callbacks — on_model_destroyed
+# ---------------------------------------------------------------------------
+# Why not comfy.hooks? The existing hook system (comfy/hooks.py) is scoped
+# to *sampling conditioning* — LoRA weight injection, transformer_options,
+# and keyframe scheduling. It has no concept of model-management lifecycle
+# events such as "a model's parameters were deallocated".
+#
+# This lightweight callback list fills that gap. It is intentionally minimal
+# (append-only, no priorities, no removal) because the only current consumer
+# is the execution-engine cache invalidator registered in PromptExecutor.
+# If upstream adopts a formal lifecycle-event bus in the future, these
+# callbacks should migrate to that system.
+# ---------------------------------------------------------------------------
+_on_model_destroyed_callbacks: list = []
+
+
+def register_model_destroyed_callback(callback):
+    """Register a listener for post-destruction lifecycle events.
+
+    After ``free_memory`` moves one or more models to the ``meta`` device
+    (aggressive offload), every registered callback is invoked once with a
+    *reason* string describing the batch (e.g. ``"batch"``).
+
+    Typical usage — executed by ``PromptExecutor.__init__``::
+
+        def _invalidate(reason):
+            executor.caches.outputs.clear_all()
+        register_model_destroyed_callback(_invalidate)
+
+    Args:
+        callback: ``Callable[[str], None]`` — receives a human-readable
+            reason string. Must be safe to call from within the
+            ``free_memory`` critical section (no heavy I/O, no model loads).
+    """
+    _on_model_destroyed_callbacks.append(callback)
+
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
         if device.type == "cuda":
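Taken together, the registry is nothing more than an append-only list that free_memory drains once per destruction batch. A condensed sketch of the round trip (the on_destroyed handler and events list are illustrative, not part of the diff):

import comfy.model_management as mm

events = []

def on_destroyed(reason):
    # Must stay cheap: it runs inside free_memory's critical section.
    events.append(reason)

mm.register_model_destroyed_callback(on_destroyed)

# Later, after free_memory() meta-destroys at least one model, it invokes
# every registered callback once with a reason string:
#     for cb in _on_model_destroyed_callbacks:
#         cb("batch")
# leaving events == ["batch"].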
@@ -640,14 +681,20 @@ def offloaded_memory(loaded_models, device):
 WINDOWS = any(platform.win32_ver())

 EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
-if WINDOWS:
+if cpu_state == CPUState.MPS:
+    # macOS with Apple Silicon: shared memory means OS needs more headroom.
+    # Reserve 4 GB for macOS + system services to prevent swap thrashing.
+    EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024
+    logging.info("MPS detected: reserving 4 GB for macOS system overhead")
+elif WINDOWS:
     import comfy.windows
     EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
     if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
         EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
     def get_free_ram():
         return comfy.windows.get_free_ram()
-else:
+
+if not WINDOWS:
     def get_free_ram():
         return psutil.virtual_memory().available

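To put the 4 GB reservation in context on a unified-memory Mac, it can be compared with the working-set size PyTorch reports for MPS. A sketch, assuming torch.mps.recommended_max_memory() is available (PyTorch 2.3 or newer):

import torch

EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024        # headroom for macOS + system services

if torch.backends.mps.is_available():
    total = torch.mps.recommended_max_memory()       # bytes the driver suggests MPS may use
    usable = total - EXTRA_RESERVED_VRAM
    print(f"MPS budget after reservation: {usable / 2**30:.1f} of {total / 2**30:.1f} GiB")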
@@ -669,14 +716,25 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins

     for i in range(len(current_loaded_models) -1, -1, -1):
         shift_model = current_loaded_models[i]
-        if device is None or shift_model.device == device:
+        # On Apple Silicon SHARED mode, CPU RAM == GPU VRAM (same physical memory).
+        # Bypass the device filter so CPU-loaded models (like CLIP) can be freed.
+        device_match = (device is None or shift_model.device == device)
+        if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
+            device_match = True
+        if device_match:
             if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                 shift_model.currently_used = False

     can_unload_sorted = sorted(can_unload)
+    # Collect models to destroy via meta device AFTER the unload loop completes,
+    # so we don't kill weakrefs of models still being iterated.
+    _meta_destroy_queue = []
     for x in can_unload_sorted:
         i = x[-1]
+        # Guard: weakref may already be dead from a previous iteration
+        if current_loaded_models[i].model is None:
+            continue
         memory_to_free = 1e32
         pins_to_free = 1e32
         if not DISABLE_SMART_MEMORY or device is None:
@@ -687,15 +745,66 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
             #as that works on-demand.
             memory_required -= current_loaded_models[i].model.loaded_size()
             memory_to_free = 0
+
+        # Aggressive offload for Apple Silicon: force-unload unused models
+        # regardless of free memory, since CPU RAM == GPU VRAM.
+        if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
+            if not current_loaded_models[i].currently_used:
+                memory_to_free = 1e32 # Force unload
+                model_name = current_loaded_models[i].model.model.__class__.__name__
+                model_size_mb = current_loaded_models[i].model_memory() / (1024 * 1024)
+                logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM")
+
         if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
             logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+            # Queue for meta device destruction after loop completes.
+            # Only destroy large models (>1 GB) — small models like the VAE (160 MB)
+            # are kept because the execution cache may reuse their patcher across
+            # workflow nodes (e.g. vae_loader is cached while vae_decode runs later).
+            if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
+                if current_loaded_models[i].model is not None:
+                    model_size = current_loaded_models[i].model_memory()
+                    if model_size > 1024 * 1024 * 1024: # Only meta-destroy models > 1 GB
+                        _meta_destroy_queue.append(i)
             unloaded_model.append(i)
         if pins_to_free > 0:
-            logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-            current_loaded_models[i].model.partially_unload_ram(pins_to_free)
+            if current_loaded_models[i].model is not None:
+                logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+                current_loaded_models[i].model.partially_unload_ram(pins_to_free)
+
+    # --- Phase 2: Deferred meta-device destruction -------------------------
+    # Move parameters of queued models to the 'meta' device. This replaces
+    # every nn.Parameter with a zero-storage meta tensor, releasing physical
+    # RAM on unified-memory systems (Apple Silicon). The operation is
+    # deferred until *after* the unload loop to avoid invalidating weakrefs
+    # that other iterations may still reference.
+    for i in _meta_destroy_queue:
+        try:
+            model_ref = current_loaded_models[i].model
+            if model_ref is None:
+                continue
+            inner_model = model_ref.model
+            model_name = inner_model.__class__.__name__
+            param_count = sum(p.numel() * p.element_size() for p in inner_model.parameters())
+            inner_model.to(device="meta")
+            logging.info(f"[aggressive-offload] Moved {model_name} params to meta device, freed {param_count / (1024**2):.0f} MB")
+        except Exception as e:
+            logging.warning(f"[aggressive-offload] Failed to move model to meta: {e}")
+
+    # --- Phase 3: Notify lifecycle listeners --------------------------------
+    # Fire on_model_destroyed callbacks *once* after the entire batch has been
+    # processed, not per-model. This lets the execution engine clear its
+    # output cache in a single operation (see PromptExecutor.__init__).
+    if _meta_destroy_queue and _on_model_destroyed_callbacks:
+        for cb in _on_model_destroyed_callbacks:
+            cb("batch")
+        logging.info(f"[aggressive-offload] Invalidated execution cache after destroying {len(_meta_destroy_queue)} model(s)")

     for x in can_unload_sorted:
         i = x[-1]
+        # Guard: weakref may be dead after cache invalidation (meta device move)
+        if current_loaded_models[i].model is None:
+            continue
         ram_to_free = ram_required - psutil.virtual_memory().available
         if ram_to_free <= 0 and i not in unloaded_model:
             continue
@@ -708,6 +817,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins

     if len(unloaded_model) > 0:
         soft_empty_cache()
+        if AGGRESSIVE_OFFLOAD:
+            gc.collect() # Force Python GC to release model tensors
+            soft_empty_cache() # Second pass to free MPS allocator cache
     elif device is not None:
         if vram_state != VRAMState.HIGH_VRAM:
             mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
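Outside ComfyUI, the equivalent end-of-batch cleanup on Apple Silicon boils down to dropping references, collecting, and emptying the MPS allocator cache. A sketch of that sequence (soft_empty_cache() itself is ComfyUI's wrapper; the helper below is hypothetical):

import gc
import torch

def flush_unified_memory():
    # Collect cyclic garbage that may still pin tensor storages.
    gc.collect()
    # Hand cached allocator blocks back to the OS on MPS builds.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()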

comfy/samplers.py

Lines changed: 12 additions & 0 deletions
@@ -748,6 +748,18 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
         if callback is not None:
             k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

+        # On Apple Silicon MPS, flush the allocator pool between steps to prevent
+        # progressive memory fragmentation and swap thrashing. Wrapping the callback
+        # here (rather than patching individual samplers) covers all sampler variants.
+        import comfy.model_management
+        if noise.device.type == "mps" and getattr(comfy.model_management, "AGGRESSIVE_OFFLOAD", False):
+            _inner_callback = k_callback
+            def _mps_flush_callback(x):
+                if _inner_callback is not None:
+                    _inner_callback(x)
+                torch.mps.empty_cache()
+            k_callback = _mps_flush_callback
+
         samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
         samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
         return samples
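To observe what the per-step flush buys, the MPS driver pool can be sampled before and after empty_cache() in a toy loop. A sketch, assuming torch.mps.driver_allocated_memory() is available on an MPS build of PyTorch:

import torch

assert torch.backends.mps.is_available()
dev = torch.device("mps")

for step in range(4):
    x = torch.randn(4096, 4096, device=dev) @ torch.randn(4096, 4096, device=dev)
    del x                                            # drop the step's intermediate
    before = torch.mps.driver_allocated_memory()
    torch.mps.empty_cache()                          # what the wrapped k_callback does each step
    after = torch.mps.driver_allocated_memory()
    print(f"step {step}: driver pool {before / 2**20:.0f} MB -> {after / 2**20:.0f} MB")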

execution.py

Lines changed: 11 additions & 0 deletions
@@ -651,6 +651,17 @@ def __init__(self, server, cache_type=False, cache_args=None):
         self.cache_type = cache_type
         self.server = server
         self.reset()
+        # Register callback so model_management can invalidate cached outputs
+        # after destroying a model via meta device move (aggressive offload).
+        # NOTE: self.caches is resolved at call time (not capture time), so this
+        # callback remains valid even if reset() replaces self.caches later.
+        import comfy.model_management as mm
+        if mm.AGGRESSIVE_OFFLOAD:
+            executor = self
+            def _invalidate_cache(reason):
+                logging.info(f"[aggressive-offload] Invalidating execution cache ({reason})")
+                executor.caches.outputs.clear_all()
+            mm.register_model_destroyed_callback(_invalidate_cache)

     def reset(self):
         self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
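The NOTE about call-time resolution is ordinary Python closure behavior: the callback captures the executor object, not its caches attribute, so a later reset() that rebinds self.caches is still honored. A self-contained illustration (the Holder class is hypothetical, not ComfyUI code):

class Holder:
    def __init__(self):
        self.caches = {"outputs": ["old"]}

holder = Holder()
callbacks = []

def _invalidate(reason):
    # holder.caches is looked up at call time, so it follows rebinding.
    holder.caches["outputs"].clear()

callbacks.append(_invalidate)

holder.caches = {"outputs": ["new"]}     # analogous to PromptExecutor.reset()
for cb in callbacks:
    cb("batch")
print(holder.caches)                     # {'outputs': []} -> the new cache was cleared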
