|
| 1 | +import comfy_aimdo.model_vbar |
| 2 | +import comfy.model_management |
| 3 | +import comfy.ops |
| 4 | + |
# Registry of every queue handed out by make_prefetch_queue(); it is walked
# and then reset to a fresh list by cleanup_prefetch_queues().
PREFETCH_QUEUES: list = []
| 6 | + |
def cleanup_prefetched_modules(comfy_modules):
    """Tear down the ``_prefetch`` state attached to each module.

    Modules whose ``_prefetch`` attribute is missing or ``None`` are left
    untouched. For every other module:

    * ``clear_prepared()`` is called on ``weight_lowvram_function`` and
      ``bias_lowvram_function`` (in that order) when the attribute exists
      and is not ``None``;
    * ``comfy_aimdo.model_vbar.vbar_unpin`` is called with the module's
      ``_v`` when ``_prefetch["signature"]`` is not ``None``;
    * the ``_prefetch`` attribute is removed.

    Args:
        comfy_modules: Iterable of module objects to clean up.
    """
    for module in comfy_modules:
        state = getattr(module, "_prefetch", None)
        if state is not None:
            # Lazy generator: each attribute is looked up right before it is
            # cleared, weight first, then bias.
            lowvram_fns = (
                getattr(module, f"{name}_lowvram_function", None)
                for name in ("weight", "bias")
            )
            for fn in lowvram_fns:
                if fn is not None:
                    fn.clear_prepared()

            # NOTE(review): unpin is issued once per module, not once per
            # parameter -- confirm against the upstream layout.
            if state["signature"] is not None:
                comfy_aimdo.model_vbar.vbar_unpin(module._v)

            del module._prefetch
| 19 | + |
def cleanup_prefetch_queues():
    """Clean up every registered prefetch queue and reset the registry.

    Each tuple entry found in any queue of ``PREFETCH_QUEUES`` is unpacked as
    ``(_, prefetch_state)``; when ``prefetch_state[1]`` is not ``None`` it is
    passed to ``cleanup_prefetched_modules``. Non-tuple entries (including
    ``None``) are skipped. Afterwards ``PREFETCH_QUEUES`` is rebound to a new
    empty list.
    """
    global PREFETCH_QUEUES

    # ``None`` is not a tuple, so a single isinstance test covers both the
    # "empty slot" and "not yet prefetched" cases.
    tuple_entries = (
        entry
        for pending in PREFETCH_QUEUES
        for entry in pending
        if isinstance(entry, tuple)
    )
    for _, state in tuple_entries:
        modules = state[1]
        if modules is not None:
            cleanup_prefetched_modules(modules)

    PREFETCH_QUEUES = []
| 32 | + |
def prefetch_queue_pop(queue, device, module):
    """Advance a prefetch queue by one slot.

    The head entry is popped. If it is not ``None`` it is unpacked as
    ``(offload_stream, (_, comfy_modules))``: the offload stream is made to
    wait on ``comfy.model_management.current_stream(device)`` and, when
    ``comfy_modules`` is not ``None``, they go through
    ``cleanup_prefetched_modules``.

    If the new head is not ``None``, its ``modules()`` that carry a ``_v``
    attribute are collected and handed to ``comfy.ops.cast_modules_with_vbar``;
    the head slot is then replaced by
    ``(offload_stream, (original_head, collected_modules))``.

    Args:
        queue: List built by ``make_prefetch_queue``, or ``None`` (no-op).
        device: Device forwarded to the stream helpers and the cast call.
        module: Not used by this function.

    Returns:
        None.
    """
    if queue is None:
        return

    finished = queue.pop(0)
    if finished is not None:
        stream, state = finished
        # NOTE(review): assumes the stored stream is never None -- confirm
        # what cast_modules_with_vbar returns for an empty module list.
        stream.wait_stream(comfy.model_management.current_stream(device))
        _, finished_modules = state
        if finished_modules is not None:
            cleanup_prefetched_modules(finished_modules)

    upcoming = queue[0]
    if upcoming is None:
        return

    vbar_modules = [sub for sub in upcoming.modules() if hasattr(sub, "_v")]
    stream = comfy.ops.cast_modules_with_vbar(vbar_modules, None, device, None, True)
    comfy.model_management.sync_stream(device, stream)
    queue[0] = (stream, (upcoming, vbar_modules))
| 55 | + |
def make_prefetch_queue(queue, device, transformer_options):
    """Build and register a prefetch queue, or return ``None`` when disabled.

    ``None`` is returned when any of the following holds (checked in order):

    * ``transformer_options["prefetch_dynamic_vbars"]`` is missing or falsy;
    * ``comfy.model_management.NUM_STREAMS`` equals 0;
    * ``comfy.model_management.is_device_cpu(device)`` is true;
    * ``comfy.model_management.device_supports_non_blocking(device)`` is
      false.

    Otherwise a new list ``[None] + queue + [None]`` is appended to
    ``PREFETCH_QUEUES`` and returned.

    Args:
        queue: List of entries to wrap.
        device: Device checked against the model-management helpers.
        transformer_options: Mapping consulted for ``prefetch_dynamic_vbars``.

    Returns:
        The padded list, or ``None``.
    """
    # The option check comes first so nothing else is evaluated when the
    # feature is switched off.
    if not transformer_options.get("prefetch_dynamic_vbars", False):
        return None

    mm = comfy.model_management
    if mm.NUM_STREAMS == 0 or mm.is_device_cpu(device):
        return None
    if not mm.device_supports_non_blocking(device):
        return None

    padded = [None] + queue + [None]
    PREFETCH_QUEUES.append(padded)
    return padded
0 commit comments