Skip to content

Commit 8cef3cf

Browse files
lsteinclaude
andcommitted
feat(multi-gpu): surface per-session GPU number in logs and UI
Help users track which CUDA device is processing each session: - Model-load log: "Loaded model ... onto cuda device #N in ..s" - Denoise progress bars: "Denoising (#N)" across all architectures (SD1.5/SDXL, FLUX, FLUX2, Z-Image, Anima, SD3, CogView4) - Progress preview circle: GPU number centered in the ring, via a new `device` field on InvocationProgressEvent (resolved from the worker's thread-local session device) - Session Queue: new "GPU #" column between STATUS and TIME, backed by a `device` column on session_queue (migration_32) recorded when a worker claims an item Adds TorchDevice.get_session_device_label()/get_session_device_index() helpers and a frontend getCudaDeviceIndex() parser (with tests). Shows the number on CUDA only; CPU/MPS show nothing. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent d521ba4 commit 8cef3cf

25 files changed

Lines changed: 221 additions & 22 deletions

File tree

invokeai/app/invocations/anima_denoise.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor
608608

609609
if driver is not None:
610610
user_step = 0
611-
pbar = tqdm(total=total_steps, desc="Denoising (Anima)")
611+
pbar = tqdm(total=total_steps, desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}")
612612
for it in driver.iterations():
613613
timestep = torch.tensor(
614614
[it.sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
@@ -655,7 +655,9 @@ def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> tor
655655
pbar.close()
656656
else:
657657
# Built-in Euler implementation (default for Anima)
658-
for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"):
658+
for step_idx in tqdm(
659+
range(total_steps), desc=f"Denoising (Anima){TorchDevice.get_session_device_label()}"
660+
):
659661
sigma_curr = sigmas[step_idx]
660662
sigma_prev = sigmas[step_idx + 1]
661663

invokeai/app/invocations/cogview4_denoise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def _run_diffusion(
294294
assert isinstance(transformer, CogView4Transformer2DModel)
295295

296296
# Denoising loop
297-
for step_idx in tqdm(range(total_steps)):
297+
for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"):
298298
t_curr = timesteps[step_idx]
299299
sigma_curr = sigmas[step_idx]
300300
sigma_prev = sigmas[step_idx + 1]

invokeai/app/invocations/sd3_denoise.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,10 @@ def _run_diffusion(
284284
assert isinstance(transformer, SD3Transformer2DModel)
285285

286286
# 6. Denoising loop
287-
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
287+
for step_idx, (t_curr, t_prev) in tqdm(
288+
list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True))),
289+
desc=f"Denoising{TorchDevice.get_session_device_label()}",
290+
):
288291
# Expand the latents if we are doing CFG.
289292
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
290293
# Expand the timestep to match the latent model input.

invokeai/app/invocations/z_image_denoise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
569569
# Use diffusers scheduler for stepping
570570
# Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
571571
# This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
572-
pbar = tqdm(total=total_steps, desc="Denoising")
572+
pbar = tqdm(total=total_steps, desc=f"Denoising{TorchDevice.get_session_device_label()}")
573573
for step_index in range(num_scheduler_steps):
574574
sched_timestep = scheduler.timesteps[step_index]
575575
# Convert scheduler timestep (0-1000) to normalized sigma (0-1)
@@ -686,7 +686,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
686686
pbar.close()
687687
else:
688688
# Original Euler implementation (default, optimized for Z-Image)
689-
for step_idx in tqdm(range(total_steps)):
689+
for step_idx in tqdm(range(total_steps), desc=f"Denoising{TorchDevice.get_session_device_label()}"):
690690
sigma_curr = sigmas[step_idx]
691691
sigma_prev = sigmas[step_idx + 1]
692692

invokeai/app/services/events/events_common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ class InvocationProgressEvent(InvocationEventBase):
138138
image: ProgressImage | None = Field(
139139
default=None, description="An image representing the current state of the progress"
140140
)
141+
device: str | None = Field(
142+
default=None,
143+
description="The device processing this session, e.g. 'cuda:1' (set only when running on a CUDA GPU)",
144+
)
141145

142146
@classmethod
143147
def build(
@@ -148,6 +152,13 @@ def build(
148152
percentage: float | None = None,
149153
image: ProgressImage | None = None,
150154
) -> "InvocationProgressEvent":
155+
# This is emitted from the session-processor worker thread, which pins its CUDA device via
156+
# TorchDevice.set_session_device(). Resolve that here so the UI can label progress by GPU.
157+
from invokeai.backend.util.devices import TorchDevice
158+
159+
session_device = TorchDevice.get_session_device()
160+
device = str(session_device) if session_device is not None and session_device.type == "cuda" else None
161+
151162
return cls(
152163
queue_id=queue_item.queue_id,
153164
item_id=queue_item.item_id,
@@ -161,6 +172,7 @@ def build(
161172
percentage=percentage,
162173
image=image,
163174
message=message,
175+
device=device,
164176
)
165177

166178

invokeai/app/services/session_processor/session_processor_default.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,11 @@ def _process(
529529
break
530530

531531
# Get the next session to process. dequeue() atomically claims the item, so concurrent
532-
# workers never receive the same item.
533-
worker.queue_item = self._invoker.services.session_queue.dequeue()
532+
# workers never receive the same item. Pass this worker's device so the item is
533+
# tagged with the GPU that ran it (None in single-device/legacy mode).
534+
worker.queue_item = self._invoker.services.session_queue.dequeue(
535+
device=str(worker.device) if worker.device is not None else None
536+
)
534537

535538
if worker.queue_item is None:
536539
# The queue was empty, wait for next polling interval or event to try again

invokeai/app/services/session_queue/session_queue_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class SessionQueueBase(ABC):
3131
"""Base class for session queue"""
3232

3333
@abstractmethod
34-
def dequeue(self) -> Optional[SessionQueueItem]:
35-
"""Dequeues the next session queue item."""
34+
def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]:
35+
"""Dequeues the next session queue item, recording the processing device (e.g. 'cuda:1') if given."""
3636
pass
3737

3838
@abstractmethod

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ class SessionQueueItem(BaseModel):
262262
retried_from_item_id: Optional[int] = Field(
263263
default=None, description="The item_id of the queue item that this item was retried from"
264264
)
265+
device: Optional[str] = Field(
266+
default=None,
267+
description="The device that processed this queue item, e.g. 'cuda:1' (set only when running on a CUDA GPU)",
268+
)
265269
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
266270
workflow: Optional[WorkflowWithoutID] = Field(
267271
default=None, description="The workflow associated with this queue item"

invokeai/app/services/session_queue/session_queue_sqlite.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ async def enqueue_batch(
216216
self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id)
217217
return enqueue_result
218218

219-
def dequeue(self) -> Optional[SessionQueueItem]:
219+
def dequeue(self, device: Optional[str] = None) -> Optional[SessionQueueItem]:
220220
# Hold the dequeue lock across the select-then-claim so concurrent workers (multi-GPU)
221221
# cannot select and claim the same pending item. `_set_queue_item_status` already no-ops
222222
# if the item was concurrently moved to a terminal state (e.g. canceled), so we only need
@@ -242,7 +242,8 @@ def dequeue(self) -> Optional[SessionQueueItem]:
242242
if result is None:
243243
return None
244244
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
245-
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
245+
# Record the claiming worker's device so the UI can label the item by GPU.
246+
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress", device=device)
246247
return queue_item
247248

248249
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
@@ -299,6 +300,7 @@ def _set_queue_item_status(
299300
error_type: Optional[str] = None,
300301
error_message: Optional[str] = None,
301302
error_traceback: Optional[str] = None,
303+
device: Optional[str] = None,
302304
) -> SessionQueueItem:
303305
with self._db.transaction() as cursor:
304306
cursor.execute(
@@ -320,10 +322,10 @@ def _set_queue_item_status(
320322
cursor.execute(
321323
"""--sql
322324
UPDATE session_queue
323-
SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ?
325+
SET status = ?, status_sequence = COALESCE(status_sequence, 0) + 1, error_type = ?, error_message = ?, error_traceback = ?, device = COALESCE(?, device)
324326
WHERE item_id = ?
325327
""",
326-
(status, error_type, error_message, error_traceback, item_id),
328+
(status, error_type, error_message, error_traceback, device, item_id),
327329
)
328330

329331
queue_item = self.get_queue_item(item_id)

invokeai/app/services/shared/sqlite/sqlite_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29
3535
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_30 import build_migration_30
3636
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_31 import build_migration_31
37+
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_32 import build_migration_32
3738
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
3839

3940

@@ -85,6 +86,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
8586
migrator.register_migration(build_migration_29())
8687
migrator.register_migration(build_migration_30())
8788
migrator.register_migration(build_migration_31())
89+
migrator.register_migration(build_migration_32())
8890
migrator.run_migrations()
8991

9092
return db

0 commit comments

Comments
 (0)