Skip to content

Commit c0a7e10

Browse files
Implement persistent thread pool for multi-GPU CFG splitting
Replace the per-step thread create/destroy in _calc_cond_batch_multigpu with a persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device() once at startup, preserving compiled kernel caches across diffusion steps.

- Add MultiGPUThreadPool class in comfy/multigpu.py
- Create the pool in CFGGuider.outer_sample(); shut it down in the finally block
- Main thread handles its own device batch directly for zero overhead
- Falls back to sequential execution if no pool is available

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
1 parent da38644 commit c0a7e10

2 files changed

Lines changed: 105 additions & 11 deletions

File tree

comfy/multigpu.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from __future__ import annotations
2+
import queue
3+
import threading
24
import torch
35
import logging
46

@@ -11,6 +13,67 @@
1113
import comfy.model_management
1214

1315

16+
class MultiGPUThreadPool:
    """Persistent thread pool for multi-GPU work distribution.

    Maintains one worker thread per extra GPU device. Each thread calls
    torch.cuda.set_device() once at startup so that compiled kernel caches
    (inductor/triton) stay warm across diffusion steps.
    """

    def __init__(self, devices: list[torch.device]):
        """Spawn one daemon worker thread (with its own work/result queue) per device."""
        self._workers: list[threading.Thread] = []
        self._work_queues: dict[torch.device, queue.Queue] = {}
        self._result_queues: dict[torch.device, queue.Queue] = {}

        for device in devices:
            wq = queue.Queue()
            rq = queue.Queue()
            self._work_queues[device] = wq
            self._result_queues[device] = rq
            # daemon=True: the process can still exit if shutdown() is never
            # reached (e.g. a hard failure mid-sampling).
            t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
            t.start()
            self._workers.append(t)

    def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
        """Worker body: pin the CUDA device once, then run jobs until the None sentinel.

        If the device cannot be selected, the worker stays alive and answers
        every submitted job with the startup error instead of silently dying,
        so get_result() never blocks forever on a dead worker.
        """
        try:
            torch.cuda.set_device(device)
        except Exception as e:
            logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
            # Degraded mode: report the startup failure for each job until the
            # shutdown sentinel arrives.  (Previously a stray unreachable
            # `return` followed this loop; removed.)
            while True:
                item = work_q.get()
                if item is None:
                    return
                result_q.put((None, e))

        while True:
            item = work_q.get()
            if item is None:
                break
            fn, args, kwargs = item
            try:
                result = fn(*args, **kwargs)
                result_q.put((result, None))
            except Exception as e:
                # Ship the exception to the submitting thread; the caller
                # re-raises it after get_result().
                result_q.put((None, e))

    def submit(self, device: torch.device, fn, *args, **kwargs):
        """Queue fn(*args, **kwargs) for execution on the worker pinned to *device*.

        Raises KeyError if *device* was not passed to the constructor.
        """
        self._work_queues[device].put((fn, args, kwargs))

    def get_result(self, device: torch.device):
        """Block until *device*'s worker produces one (result, error) pair and return it."""
        return self._result_queues[device].get()

    @property
    def devices(self) -> list[torch.device]:
        """Devices served by this pool, in construction order."""
        return list(self._work_queues.keys())

    def shutdown(self):
        """Stop all workers by sending one sentinel per queue, then join with a timeout.

        Workers are daemon threads, so a worker that fails to join within the
        timeout cannot prevent process exit.
        """
        for wq in self._work_queues.values():
            wq.put(None)  # sentinel
        for t in self._workers:
            t.join(timeout=5.0)
75+
76+
1477
class GPUOptions:
1578
def __init__(self, device_index: int, relative_speed: float):
1679
self.device_index = device_index

comfy/samplers.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import comfy.patcher_extension
1919
import comfy.hooks
2020
import comfy.context_windows
21+
import comfy.multigpu
2122
import comfy.utils
2223
import scipy.stats
2324
import numpy
24-
import threading
2525

2626

2727
def add_area_dims(area, num_dims):
@@ -509,15 +509,38 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup
509509
raise
510510

511511

512+
def _handle_batch_pooled(device, batch_tuple):
513+
worker_results = []
514+
_handle_batch(device, batch_tuple, worker_results)
515+
return worker_results
516+
512517
results: list[thread_result] = []
513-
threads: list[threading.Thread] = []
518+
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
519+
main_device = output_device
520+
main_batch_tuple = None
521+
522+
# Submit extra GPU work to pool first, then run main device on this thread
523+
pool_devices = []
514524
for device, batch_tuple in device_batched_hooked_to_run.items():
515-
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
516-
threads.append(new_thread)
517-
new_thread.start()
525+
if device == main_device and thread_pool is not None:
526+
main_batch_tuple = batch_tuple
527+
elif thread_pool is not None:
528+
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
529+
pool_devices.append(device)
530+
else:
531+
# Fallback: no pool, run everything on main thread
532+
_handle_batch(device, batch_tuple, results)
518533

519-
for thread in threads:
520-
thread.join()
534+
# Run main device batch on this thread (parallel with pool workers)
535+
if main_batch_tuple is not None:
536+
_handle_batch(main_device, main_batch_tuple, results)
537+
538+
# Collect results from pool workers
539+
for device in pool_devices:
540+
worker_results, error = thread_pool.get_result(device)
541+
if error is not None:
542+
raise error
543+
results.extend(worker_results)
521544

522545
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
523546
if error is not None:
@@ -1187,17 +1210,25 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None,
11871210

11881211
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
11891212

1190-
noise = noise.to(device=device, dtype=torch.float32)
1191-
latent_image = latent_image.to(device=device, dtype=torch.float32)
1192-
sigmas = sigmas.to(device)
1193-
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
1213+
# Create persistent thread pool for extra GPU devices
1214+
if multigpu_patchers:
1215+
extra_devices = [p.load_device for p in multigpu_patchers]
1216+
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices)
11941217

11951218
try:
1219+
noise = noise.to(device=device, dtype=torch.float32)
1220+
latent_image = latent_image.to(device=device, dtype=torch.float32)
1221+
sigmas = sigmas.to(device)
1222+
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
1223+
11961224
self.model_patcher.pre_run()
11971225
for multigpu_patcher in multigpu_patchers:
11981226
multigpu_patcher.pre_run()
11991227
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
12001228
finally:
1229+
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
1230+
if thread_pool is not None:
1231+
thread_pool.shutdown()
12011232
self.model_patcher.cleanup()
12021233
for multigpu_patcher in multigpu_patchers:
12031234
multigpu_patcher.cleanup()

0 commit comments

Comments
 (0)