Skip to content

guard torch.cuda.synchronize for CPU/MPS eval paths#271

Open
LiudengZhang wants to merge 2 commits into
ArcInstitute:mainfrom
LiudengZhang:fix/guard-cuda-synchronize-for-cpu-eval
Open

guard torch.cuda.synchronize for CPU/MPS eval paths#271
LiudengZhang wants to merge 2 commits into
ArcInstitute:mainfrom
LiudengZhang:fix/guard-cuda-synchronize-for-cpu-eval

Conversation

@LiudengZhang
Copy link
Copy Markdown

Two unguarded torch.cuda.synchronize() calls (one paired with torch.cuda.empty_cache()) currently raise RuntimeError: No CUDA GPUs are available when running the eval paths on a CPU- or MPS-only machine, even when the model itself can run on CPU:

  • src/state/_cli/_emb/_eval.py:121-122 — inside the embedding eval batch loop
  • src/state/emb/nn/eval_utils.py:79 — between _predict_exp_for_adata and the DE gene-overlap metric

This wraps each in if torch.cuda.is_available():, matching the existing pattern already used in src/state/tx/callbacks/model_flops_utilization.py (lines 131 and 148).

Net change: 5 lines across 2 files. No behavior change on CUDA hosts; eval now runs on CPU/MPS hosts that previously crashed.

@LiudengZhang LiudengZhang requested a review from a team as a code owner May 15, 2026 18:40
Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces checks for CUDA availability before performing synchronization and cache clearing operations to ensure the code runs correctly on non-GPU environments. Review feedback suggests using the existing device_type variable for consistency and warns that calling empty_cache() inside a batch loop can significantly degrade performance. Additionally, one synchronization call was identified as potentially redundant because the subsequent data transfer to the host likely triggers an implicit synchronization.

Comment thread src/state/_cli/_emb/_eval.py Outdated
Comment on lines +121 to +123
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using the local device_type variable (defined at line 101) is more consistent with the rest of the function than calling torch.cuda.is_available() again.

Additionally, calling torch.cuda.empty_cache() inside a batch loop is generally discouraged as it forces a synchronization and causes the allocator to release memory back to the driver, which can significantly slow down execution. Unless this is strictly necessary to mitigate specific memory issues during evaluation, consider removing it or moving it outside the loop.

Suggested change
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
if device_type == "cuda":
torch.cuda.synchronize()
torch.cuda.empty_cache()

Comment on lines +79 to +80
if torch.cuda.is_available():
torch.cuda.synchronize()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This synchronization is likely redundant. Since _predict_exp_for_adata returns a result that is subsequently used as a DataFrame or numpy array (e.g., in compute_gene_overlap_cross_pert), the data transfer from the GPU to the host has already performed an implicit synchronization. You may consider removing this block entirely to simplify the code.

Per review feedback: reuse the device_type variable defined above
instead of calling torch.cuda.is_available() a second time.
@LiudengZhang
Copy link
Copy Markdown
Author

Thanks for the review. Pushed a22a2bf addressing the consistency point — switched _eval.py:121 to use the existing device_type variable rather than re-calling torch.cuda.is_available().

On the other two points:

  • empty_cache() inside the batch loop (_eval.py:123) — agreed it's a known perf anti-pattern, but it pre-dates this PR; this PR only adds the CPU/MPS guard around the existing block. Happy to do that as a follow-up if maintainers want it cleaned up here.
  • torch.cuda.synchronize() at eval_utils.py:80 — the whole purpose of this PR is to guard that sync so CPU/MPS eval paths don't crash. Removing the sync entirely (rather than guarding it) is a different change than what this PR proposes. Happy to revisit if maintainers prefer that route, but I'd suggest keeping the guarded sync for now since the CPU/MPS bug is the motivating issue.

Let me know.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant