1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -158,6 +158,7 @@ def from_string(cls, value: str):
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Moves all models larger than 1 GB to a virtual (meta) device between runs, preventing swap pressure on disk. Small models like the VAE are preserved. Trade-off: models are reloaded from disk on subsequent generations.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

class PerformanceFeature(enum.Enum):
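A quick way to exercise the new flag once this patch is applied is to start the server with it and watch for the startup log line added in model_management.py below (assuming ComfyUI's standard main.py entry point):

    python main.py --aggressive-offload

The flag is independent of --disable-smart-memory; with it enabled, models larger than 1 GB are moved to the meta device between runs, so subsequent generations reload them from disk.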
129 changes: 124 additions & 5 deletions comfy/model_management.py
@@ -465,10 +465,51 @@ def aotriton_supported(gpu_arch):
logging.info(f"Set vram state to: {vram_state.name}")

DISABLE_SMART_MEMORY = args.disable_smart_memory
AGGRESSIVE_OFFLOAD = args.aggressive_offload

if DISABLE_SMART_MEMORY:
logging.info("Disabling smart memory management")

if AGGRESSIVE_OFFLOAD:
logging.info("Aggressive offload enabled: models will be freed from RAM after use (designed for Apple Silicon)")

# ---------------------------------------------------------------------------
# Model lifecycle callbacks — on_model_destroyed
# ---------------------------------------------------------------------------
# Why not comfy.hooks? The existing hook system (comfy/hooks.py) is scoped
# to *sampling conditioning* — LoRA weight injection, transformer_options,
# and keyframe scheduling. It has no concept of model-management lifecycle
# events such as "a model's parameters were deallocated".
#
# This lightweight callback list fills that gap. It is intentionally minimal
# (append-only, no priorities, no removal) because the only current consumer
# is the execution-engine cache invalidator registered in PromptExecutor.
# If upstream adopts a formal lifecycle-event bus in the future, these
# callbacks should migrate to that system.
# ---------------------------------------------------------------------------
_on_model_destroyed_callbacks: list = []


def register_model_destroyed_callback(callback):
"""Register a listener for post-destruction lifecycle events.

After ``free_memory`` moves one or more models to the ``meta`` device
(aggressive offload), every registered callback is invoked once with a
*reason* string describing the batch (e.g. ``"batch"``).

Typical usage — executed by ``PromptExecutor.__init__``::

def _invalidate(reason):
executor.caches.outputs.clear_all()
register_model_destroyed_callback(_invalidate)

Args:
callback: ``Callable[[str], None]`` — receives a human-readable
reason string. Must be safe to call from within the
``free_memory`` critical section (no heavy I/O, no model loads).
"""
_on_model_destroyed_callbacks.append(callback)

def get_torch_device_name(device):
if hasattr(device, 'type'):
if device.type == "cuda":
@@ -640,14 +681,21 @@ def offloaded_memory(loaded_models, device):
WINDOWS = any(platform.win32_ver())

EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
if cpu_state == CPUState.MPS and AGGRESSIVE_OFFLOAD:
# macOS with Apple Silicon + aggressive offload: shared memory means the OS
# needs more headroom. Reserve 4 GB for macOS and system services to
# prevent swap thrashing during model destruction/reload cycles.
EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024
logging.info("MPS detected with --aggressive-offload: reserving 4 GB for macOS system overhead")
elif WINDOWS:
import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
def get_free_ram():
return comfy.windows.get_free_ram()
else:

if not WINDOWS:
def get_free_ram():
return psutil.virtual_memory().available

@@ -669,14 +717,25 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins

for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if device is None or shift_model.device == device:
# On Apple Silicon SHARED mode, CPU RAM == GPU VRAM (same physical memory).
# Bypass the device filter so CPU-loaded models (like CLIP) can be freed.
device_match = (device is None or shift_model.device == device)
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
device_match = True
if device_match:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False

can_unload_sorted = sorted(can_unload)
# Collect models to destroy via meta device AFTER the unload loop completes,
# so we don't kill weakrefs of models still being iterated.
_meta_destroy_queue = []
for x in can_unload_sorted:
i = x[-1]
# Guard: weakref may already be dead from a previous iteration
if current_loaded_models[i].model is None:
continue
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY or device is None:
@@ -687,15 +746,72 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_to_free = 0

# Aggressive offload for Apple Silicon: force-unload unused models
# regardless of free memory, since CPU RAM == GPU VRAM.
# Only force-unload models > 1 GB — small models like the VAE (160 MB)
# are preserved to avoid unnecessary reload from disk.
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
model_ref = current_loaded_models[i].model
if model_ref is not None and not current_loaded_models[i].currently_used:
model_size = current_loaded_models[i].model_memory()
if model_size > 1024 * 1024 * 1024: # 1 GB threshold
memory_to_free = 1e32 # Force unload
inner = getattr(model_ref, "model", None)
model_name = inner.__class__.__name__ if inner is not None else "unknown"
model_size_mb = model_size / (1024 * 1024)
logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM")

if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
# Queue for meta device destruction after loop completes.
# Only destroy large models (>1 GB) — small models like the VAE (160 MB)
# are kept because the execution cache may reuse their patcher across
# workflow nodes (e.g. vae_loader is cached while vae_decode runs later).
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
if current_loaded_models[i].model is not None:
model_size = current_loaded_models[i].model_memory()
if model_size > 1024 * 1024 * 1024: # Only meta-destroy models > 1 GB
_meta_destroy_queue.append(i)
unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
if current_loaded_models[i].model is not None:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)

# --- Phase 2: Deferred meta-device destruction -------------------------
# Move parameters of queued models to the 'meta' device. This replaces
# every nn.Parameter with a zero-storage meta tensor, releasing physical
# RAM on unified-memory systems (Apple Silicon). The operation is
# deferred until *after* the unload loop to avoid invalidating weakrefs
# that other iterations may still reference.
for i in _meta_destroy_queue:
try:
model_ref = current_loaded_models[i].model
if model_ref is None:
continue
inner_model = model_ref.model
model_name = inner_model.__class__.__name__
param_bytes = sum(p.numel() * p.element_size() for p in inner_model.parameters())
inner_model.to(device="meta")
logging.info(f"[aggressive-offload] Moved {model_name} params to meta device, freed {param_bytes / (1024**2):.0f} MB")
except Exception as e:
logging.warning(f"[aggressive-offload] Failed to move model to meta: {e}")

# --- Phase 3: Notify lifecycle listeners --------------------------------
# Fire on_model_destroyed callbacks *once* after the entire batch has been
# processed, not per-model. This lets the execution engine clear its
# output cache in a single operation (see PromptExecutor.__init__).
if _meta_destroy_queue and _on_model_destroyed_callbacks:
for cb in _on_model_destroyed_callbacks:
cb("batch")
logging.info(f"[aggressive-offload] Invalidated execution cache after destroying {len(_meta_destroy_queue)} model(s)")

for x in can_unload_sorted:
i = x[-1]
# Guard: weakref may be dead after cache invalidation (meta device move)
if current_loaded_models[i].model is None:
continue
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
@@ -708,6 +824,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins

if len(unloaded_model) > 0:
soft_empty_cache()
if AGGRESSIVE_OFFLOAD:
gc.collect() # Force Python GC to release model tensors
soft_empty_cache() # Second pass to free MPS allocator cache
elif device is not None:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
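The meta-device move in Phase 2 relies on standard PyTorch behaviour rather than anything ComfyUI-specific: nn.Module.to(device="meta") swaps every parameter for a zero-storage meta tensor that keeps only shape and dtype. A minimal standalone sketch (the Linear layer and sizes are illustrative, not taken from ComfyUI):

    import torch

    layer = torch.nn.Linear(4096, 4096)    # ~64 MB of fp32 weights resident in RAM
    print(layer.weight.device)              # cpu
    print(layer.weight.numel() * layer.weight.element_size())  # 67108864 bytes

    layer.to(device="meta")                 # parameters become zero-storage meta tensors
    print(layer.weight.device)              # meta
    print(layer.weight.is_meta)             # True

    # A forward pass on a meta-device module fails because the tensors hold no data.
    # That is why Phase 3 must invalidate the execution cache: cached outputs that
    # still reference the destroyed model cannot be reused and the model has to be
    # reloaded from disk before the next generation.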
12 changes: 12 additions & 0 deletions comfy/samplers.py
@@ -748,6 +748,18 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

# On Apple Silicon MPS, flush the allocator pool between steps to prevent
# progressive memory fragmentation and swap thrashing. Wrapping the callback
# here (rather than patching individual samplers) covers all sampler variants.
import comfy.model_management
if noise.device.type == "mps" and getattr(comfy.model_management, "AGGRESSIVE_OFFLOAD", False):
_inner_callback = k_callback
def _mps_flush_callback(x):
if _inner_callback is not None:
_inner_callback(x)
torch.mps.empty_cache()
k_callback = _mps_flush_callback

samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples
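torch.mps.empty_cache() is the MPS counterpart of torch.cuda.empty_cache(): it releases unoccupied blocks held by the caching allocator. The wrapper above keeps the original progress callback intact, and the pattern can be sanity-checked without ComfyUI; the fake sampler loop below is illustrative only:

    import torch

    def fake_sampler(steps, callback):
        # stand-in for a k-diffusion sampler loop: one callback invocation per step
        for i in range(steps):
            callback({"i": i})

    seen = []
    inner_callback = lambda x: seen.append(x["i"])

    def mps_flush_callback(x):
        if inner_callback is not None:
            inner_callback(x)
        if torch.backends.mps.is_available():   # guard so the sketch also runs off-macOS
            torch.mps.empty_cache()

    fake_sampler(3, mps_flush_callback)
    print(seen)   # [0, 1, 2]: the inner callback still sees every step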
27 changes: 27 additions & 0 deletions comfy_execution/caching.py
@@ -190,6 +190,17 @@ def _clean_subcaches(self):
for key in to_remove:
del self.subcaches[key]

def clear_all(self):
"""Drop all cached outputs unconditionally.

This is the public API for external subsystems (e.g. aggressive model
offloading) that need to invalidate every cached result — for instance
after model parameters have been moved to the ``meta`` device and the
cached tensors are no longer usable.
"""
self.cache.clear()
self.subcaches.clear()

def clean_unused(self):
assert self.initialized
self._clean_cache()
@@ -417,6 +428,10 @@ def all_node_ids(self):
def clean_unused(self):
pass

def clear_all(self):
"""No-op: null backend has nothing to invalidate."""
pass

def poll(self, **kwargs):
pass

@@ -450,6 +465,13 @@ async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
for node_id in node_ids:
self._mark_used(node_id)

def clear_all(self):
"""Drop all cached outputs and reset LRU bookkeeping."""
super().clear_all()
self.used_generation.clear()
self.children.clear()
self.min_generation = 0

def clean_unused(self):
while len(self.cache) > self.max_size and self.min_generation < self.generation:
self.min_generation += 1
@@ -508,6 +530,11 @@ def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, 0, enable_providers=enable_providers)
self.timestamps = {}

def clear_all(self):
"""Drop all cached outputs and reset RAM-pressure bookkeeping."""
super().clear_all()
self.timestamps.clear()

def clean_unused(self):
self._clean_subcaches()

11 changes: 11 additions & 0 deletions execution.py
@@ -651,6 +651,17 @@ def __init__(self, server, cache_type=False, cache_args=None):
self.cache_type = cache_type
self.server = server
self.reset()
# Register callback so model_management can invalidate cached outputs
# after destroying a model via meta device move (aggressive offload).
# NOTE: self.caches is resolved at call time (not capture time), so this
# callback remains valid even if reset() replaces self.caches later.
import comfy.model_management as mm
if mm.AGGRESSIVE_OFFLOAD:
executor = self
def _invalidate_cache(reason):
logging.info(f"[aggressive-offload] Invalidating execution cache ({reason})")
executor.caches.outputs.clear_all()
mm.register_model_destroyed_callback(_invalidate_cache)

def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)