Skip to content

Commit 552f3a6

Browse files
authored
Merge branch 'master' into feat/api-nodes/TopazVideo-Astra2
2 parents d247e70 + ef6722f commit 552f3a6

9 files changed

Lines changed: 261 additions & 76 deletions

File tree

comfy/ldm/lightricks/av_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
1717
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
1818
import comfy.ldm.common_dit
19+
import comfy.model_prefetch
1920

2021
class CompressedTimestep:
2122
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@@ -907,9 +908,11 @@ def _process_transformer_blocks(
907908
"""Process transformer blocks for LTXAV."""
908909
patches_replace = transformer_options.get("patches_replace", {})
909910
blocks_replace = patches_replace.get("dit", {})
911+
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
910912

911913
# Process transformer blocks
912914
for i, block in enumerate(self.transformer_blocks):
915+
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
913916
if ("double_block", i) in blocks_replace:
914917

915918
def block_wrap(args):
@@ -982,6 +985,8 @@ def block_wrap(args):
982985
a_prompt_timestep=a_prompt_timestep,
983986
)
984987

988+
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
989+
985990
return [vx, ax]
986991

987992
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

comfy/lora.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
from __future__ import annotations
20+
import comfy.memory_management
2021
import comfy.utils
2122
import comfy.model_management
2223
import comfy.model_base
@@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
473474
weight = old_weight
474475

475476
return weight
477+
478+
def prefetch_prepared_value(value, allocate_buffer, stream):
479+
if isinstance(value, torch.Tensor):
480+
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
481+
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
482+
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
483+
elif isinstance(value, weight_adapter.WeightAdapterBase):
484+
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
485+
elif isinstance(value, tuple):
486+
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
487+
elif isinstance(value, list):
488+
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
489+
490+
return value

comfy/model_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,11 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
214214
if "latent_shapes" in extra_conds:
215215
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
216216

217+
transformer_options = transformer_options.copy()
218+
transformer_options["prefetch_dynamic_vbars"] = (
219+
self.current_patcher is not None and self.current_patcher.is_dynamic()
220+
)
221+
217222
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
218223
if len(model_output) > 1 and not torch.is_tensor(model_output):
219224
model_output, _ = utils.pack_latents(model_output)

comfy/model_management.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import comfy.memory_management
3232
import comfy.utils
3333
import comfy.quant_ops
34+
import comfy_aimdo.vram_buffer
3435

3536
class VRAMState(Enum):
3637
DISABLED = 0 #No vram present: no need to move models to vram
@@ -1175,6 +1176,10 @@ def current_stream(device):
11751176

11761177
STREAM_CAST_BUFFERS = {}
11771178
LARGEST_CASTED_WEIGHT = (None, 0)
1179+
STREAM_AIMDO_CAST_BUFFERS = {}
1180+
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
1181+
1182+
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
11781183

11791184
def get_cast_buffer(offload_stream, device, size, ref):
11801185
global LARGEST_CASTED_WEIGHT
@@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
12081213

12091214
return cast_buffer
12101215

1216+
def get_aimdo_cast_buffer(offload_stream, device):
1217+
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
1218+
if cast_buffer is None:
1219+
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
1220+
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
1221+
1222+
return cast_buffer
12111223
def reset_cast_buffers():
12121224
global LARGEST_CASTED_WEIGHT
1225+
global LARGEST_AIMDO_CASTED_WEIGHT
1226+
12131227
LARGEST_CASTED_WEIGHT = (None, 0)
1214-
for offload_stream in STREAM_CAST_BUFFERS:
1215-
offload_stream.synchronize()
1228+
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
1229+
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
1230+
if offload_stream is not None:
1231+
offload_stream.synchronize()
12161232
synchronize()
1233+
12171234
STREAM_CAST_BUFFERS.clear()
1235+
STREAM_AIMDO_CAST_BUFFERS.clear()
12181236
soft_empty_cache()
12191237

12201238
def get_offload_stream(device):

comfy/model_patcher.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,20 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
121121
self.patches = patches
122122
self.convert_func = convert_func # TODO: remove
123123
self.set_func = set_func
124+
self.prepared_patches = None
125+
126+
def prepare(self, allocate_buffer, stream):
127+
self.prepared_patches = [
128+
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
129+
for patch in self.patches[self.key]
130+
]
131+
132+
def clear_prepared(self):
133+
self.prepared_patches = None
124134

125135
def __call__(self, weight):
126-
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
136+
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
137+
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
127138

128139
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
129140

comfy/model_prefetch.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import comfy_aimdo.model_vbar
2+
import comfy.model_management
3+
import comfy.ops
4+
5+
PREFETCH_QUEUES = []
6+
7+
def cleanup_prefetched_modules(comfy_modules):
8+
for s in comfy_modules:
9+
prefetch = getattr(s, "_prefetch", None)
10+
if prefetch is None:
11+
continue
12+
for param_key in ("weight", "bias"):
13+
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
14+
if lowvram_fn is not None:
15+
lowvram_fn.clear_prepared()
16+
if prefetch["signature"] is not None:
17+
comfy_aimdo.model_vbar.vbar_unpin(s._v)
18+
delattr(s, "_prefetch")
19+
20+
def cleanup_prefetch_queues():
21+
global PREFETCH_QUEUES
22+
23+
for queue in PREFETCH_QUEUES:
24+
for entry in queue:
25+
if entry is None or not isinstance(entry, tuple):
26+
continue
27+
_, prefetch_state = entry
28+
comfy_modules = prefetch_state[1]
29+
if comfy_modules is not None:
30+
cleanup_prefetched_modules(comfy_modules)
31+
PREFETCH_QUEUES = []
32+
33+
def prefetch_queue_pop(queue, device, module):
34+
if queue is None:
35+
return
36+
37+
consumed = queue.pop(0)
38+
if consumed is not None:
39+
offload_stream, prefetch_state = consumed
40+
offload_stream.wait_stream(comfy.model_management.current_stream(device))
41+
_, comfy_modules = prefetch_state
42+
if comfy_modules is not None:
43+
cleanup_prefetched_modules(comfy_modules)
44+
45+
prefetch = queue[0]
46+
if prefetch is not None:
47+
comfy_modules = []
48+
for s in prefetch.modules():
49+
if hasattr(s, "_v"):
50+
comfy_modules.append(s)
51+
52+
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
53+
comfy.model_management.sync_stream(device, offload_stream)
54+
queue[0] = (offload_stream, (prefetch, comfy_modules))
55+
56+
def make_prefetch_queue(queue, device, transformer_options):
57+
if (not transformer_options.get("prefetch_dynamic_vbars", False)
58+
or comfy.model_management.NUM_STREAMS == 0
59+
or comfy.model_management.is_device_cpu(device)
60+
or not comfy.model_management.device_supports_non_blocking(device)):
61+
return None
62+
63+
queue = [None] + queue + [None]
64+
PREFETCH_QUEUES.append(queue)
65+
return queue

0 commit comments

Comments
 (0)