Skip to content

Commit 07c3887

Browse files
committed
test: add unit tests for --aggressive-offload (12 tests)
1 parent 2c8db00 commit 07c3887

1 file changed

Lines changed: 254 additions & 0 deletions

File tree

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
"""Tests for the aggressive-offload memory management feature.
2+
3+
These tests validate the Apple Silicon (MPS) memory optimisation path without
4+
requiring a GPU or actual model weights. Every test mocks the relevant model
5+
and cache structures so the suite can run in CI on any platform.
6+
"""
7+
8+
import pytest
9+
import types
10+
import torch
11+
import torch.nn as nn
12+
13+
# ---------------------------------------------------------------------------
14+
# Fixtures & helpers
15+
# ---------------------------------------------------------------------------
16+
17+
class FakeLinearModel(nn.Module):
    """Tiny nn.Module whose single parameter occupies a predictable amount of RAM."""

    def __init__(self, size_mb: float = 2.0):
        super().__init__()
        # A float32 element is 4 bytes, so size_mb megabytes correspond to
        # size_mb * 1024**2 / 4 elements in the parameter tensor.
        element_count = int(size_mb * 1024 * 1024 / 4)
        self.weight = nn.Parameter(torch.zeros(element_count, dtype=torch.float32))
25+
26+
27+
class FakeModelPatcher:
    """Stand-in for the slice of ModelPatcher that model_management.free_memory touches."""

    def __init__(self, size_mb: float = 2.0):
        # Backing module plus the byte count we report to the memory manager.
        self.model = FakeLinearModel(size_mb)
        self._loaded_size = int(size_mb * 1024 * 1024)

    def loaded_size(self):
        """Report the fake resident size in bytes."""
        return self._loaded_size

    def is_dynamic(self):
        """These fakes are always static; the dynamic path is out of scope."""
        return False
39+
40+
41+
class FakeLoadedModel:
    """Stand-in for the LoadedModel entries kept in current_loaded_models."""

    def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
        self._model = patcher
        self.currently_used = currently_used

    @property
    def model(self):
        """Expose the wrapped patcher, mirroring LoadedModel.model."""
        return self._model

    def model_memory(self):
        """Bytes this entry claims to occupy (delegated to the patcher)."""
        return self._model.loaded_size()

    def model_unload(self, _memory_to_free):
        """Pretend the unload always succeeds."""
        return True

    def model_load(self, _device, _keep_loaded):
        """Loading is a deliberate no-op for these fakes."""
        pass
60+
61+
62+
# ---------------------------------------------------------------------------
63+
# 1. BasicCache.clear_all()
64+
# ---------------------------------------------------------------------------
65+
66+
class TestBasicCacheClearAll:
    """Exercise BasicCache.clear_all() as a supported public API."""

    def test_clear_all_empties_cache_and_subcaches(self):
        """A single clear_all() call must empty both backing dicts."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        cache = BasicCache(CacheKeySetInputSignature)
        cache.cache.update({"key1": "value1", "key2": "value2"})
        cache.subcaches["sub1"] = "subvalue1"

        cache.clear_all()

        assert not cache.cache
        assert not cache.subcaches

    def test_clear_all_is_idempotent(self):
        """Repeated clear_all() calls on an empty cache must succeed quietly."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        cache = BasicCache(CacheKeySetInputSignature)
        for _ in range(2):
            cache.clear_all()  # each call must remain a harmless no-op

        assert not cache.cache
92+
93+
94+
# ---------------------------------------------------------------------------
95+
# 2. Callback registration & dispatch
96+
# ---------------------------------------------------------------------------
97+
98+
class TestModelDestroyedCallbacks:
    """Exercise the on_model_destroyed lifecycle callback registry."""

    def setup_method(self):
        """Snapshot and empty the registry so every test starts from a clean slate."""
        import comfy.model_management as mm
        self._original = mm._on_model_destroyed_callbacks.copy()
        mm._on_model_destroyed_callbacks.clear()

    def teardown_method(self):
        """Restore the snapshot so other tests see the pre-existing registry."""
        import comfy.model_management as mm
        mm._on_model_destroyed_callbacks.clear()
        mm._on_model_destroyed_callbacks.extend(self._original)

    def test_register_single_callback(self):
        import comfy.model_management as mm

        seen = []
        mm.register_model_destroyed_callback(seen.append)

        assert len(mm._on_model_destroyed_callbacks) == 1

        # Simulate the dispatch loop the destroyer performs.
        for callback in mm._on_model_destroyed_callbacks:
            callback("test")
        assert seen == ["test"]

    def test_register_multiple_callbacks(self):
        """Multiple registrants must all fire — no silent overwrites."""
        import comfy.model_management as mm

        first, second = [], []
        mm.register_model_destroyed_callback(first.append)
        mm.register_model_destroyed_callback(second.append)

        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert first == ["batch"]
        assert second == ["batch"]

    def test_callback_receives_reason_string(self):
        """The callback signature is (reason: str) -> None."""
        import comfy.model_management as mm

        captured = {}

        def record(reason):
            captured["reason"] = reason
            captured["type"] = type(reason).__name__

        mm.register_model_destroyed_callback(record)
        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert captured == {"reason": "batch", "type": "str"}
155+
156+
157+
# ---------------------------------------------------------------------------
158+
# 3. Meta-device destruction threshold
159+
# ---------------------------------------------------------------------------
160+
161+
class TestMetaDeviceThreshold:
    """Only models larger than 1 GB should be queued for meta-device destruction."""

    def test_small_model_not_destroyed(self):
        """A 160 MB model (VAE-sized) must NOT be moved to the meta device."""
        model = FakeLinearModel(size_mb=160)

        # Re-create the size/threshold comparison performed by free_memory.
        model_size = sum(p.numel() * p.element_size() for p in model.parameters())
        threshold = 1024 * 1024 * 1024  # 1 GB

        assert model_size < threshold, (
            f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )
        # Parameters must remain on a real (non-meta) device.
        assert model.weight.device.type != "meta"

    def test_large_model_above_threshold(self):
        """A 2 GB model (UNET/CLIP-sized) must BE above the destruction threshold."""
        # NOTE(review): this physically allocates ~2 GB of zeros — consider a
        # meta-device fixture if CI memory becomes a problem. TODO confirm.
        model = FakeLinearModel(size_mb=2048)

        model_size = sum(p.numel() * p.element_size() for p in model.parameters())
        threshold = 1024 * 1024 * 1024  # 1 GB

        assert model_size > threshold, (
            f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )

    def test_meta_device_move_releases_storage(self):
        """Moving parameters to 'meta' must place them on the meta device."""
        model = FakeLinearModel(size_mb=2)
        assert model.weight.device.type != "meta"

        model.to(device="meta")

        assert model.weight.device.type == "meta"
        # A meta tensor keeps its logical shape but lives on a virtual device
        # with no physical backing — that is what actually frees the RAM.
        assert model.weight.nelement() > 0  # logical shape survives the move
        assert model.weight.untyped_storage().device.type == "meta"
201+
202+
203+
# ---------------------------------------------------------------------------
204+
# 4. MPS flush conditionality
205+
# ---------------------------------------------------------------------------
206+
207+
class TestMpsFlushConditionality:
208+
"""Verify the MPS flush only activates under correct conditions."""
209+
210+
def test_flush_requires_aggressive_offload_flag(self):
211+
"""The MPS flush in samplers is gated on AGGRESSIVE_OFFLOAD."""
212+
import comfy.model_management as mm
213+
214+
# When False, flush should NOT be injected
215+
original = getattr(mm, "AGGRESSIVE_OFFLOAD", False)
216+
try:
217+
mm.AGGRESSIVE_OFFLOAD = False
218+
assert not (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
219+
220+
mm.AGGRESSIVE_OFFLOAD = True
221+
assert (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
222+
finally:
223+
mm.AGGRESSIVE_OFFLOAD = original
224+
225+
def test_flush_requires_mps_device(self):
226+
"""The flush condition checks device.type == 'mps'."""
227+
# Simulate CPU device — flush should not activate
228+
cpu_device = torch.device("cpu")
229+
assert cpu_device.type != "mps"
230+
231+
# Simulate MPS device string check
232+
if torch.backends.mps.is_available():
233+
mps_device = torch.device("mps")
234+
assert mps_device.type == "mps"
235+
236+
237+
# ---------------------------------------------------------------------------
238+
# 5. AGGRESSIVE_OFFLOAD flag integration
239+
# ---------------------------------------------------------------------------
240+
241+
class TestAggressiveOffloadFlag:
    """The --aggressive-offload CLI flag must be exposed correctly."""

    def test_flag_exists_in_model_management(self):
        """AGGRESSIVE_OFFLOAD must be importable from model_management as a bool."""
        import comfy.model_management as mm

        assert hasattr(mm, "AGGRESSIVE_OFFLOAD")
        # bool cannot be subclassed, so an exact type check equals isinstance here.
        assert type(mm.AGGRESSIVE_OFFLOAD) is bool

    def test_flag_defaults_from_cli_args(self):
        """The flag should be sourced from the parsed cli_args namespace."""
        import comfy.cli_args as cli_args

        assert hasattr(cli_args, "args")
        assert hasattr(cli_args.args, "aggressive_offload")

0 commit comments

Comments
 (0)