guard torch.cuda.synchronize for CPU/MPS eval paths#271
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces checks for CUDA availability before performing synchronization and cache clearing operations to ensure the code runs correctly on non-GPU environments. Review feedback suggests using the existing device_type variable for consistency and warns that calling empty_cache() inside a batch loop can significantly degrade performance. Additionally, one synchronization call was identified as potentially redundant because the subsequent data transfer to the host likely triggers an implicit synchronization.
| if torch.cuda.is_available(): | ||
| torch.cuda.synchronize() | ||
| torch.cuda.empty_cache() |
There was a problem hiding this comment.
Using the local device_type variable (defined at line 101) is more consistent with the rest of the function than calling torch.cuda.is_available() again.
Additionally, calling torch.cuda.empty_cache() inside a batch loop is generally discouraged as it forces a synchronization and causes the allocator to release memory back to the driver, which can significantly slow down execution. Unless this is strictly necessary to mitigate specific memory issues during evaluation, consider removing it or moving it outside the loop.
| if torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| torch.cuda.empty_cache() | |
| if device_type == "cuda": | |
| torch.cuda.synchronize() | |
| torch.cuda.empty_cache() |
| if torch.cuda.is_available(): | ||
| torch.cuda.synchronize() |
There was a problem hiding this comment.
This synchronization is likely redundant. Since _predict_exp_for_adata returns a result that is subsequently used as a DataFrame or numpy array (e.g., in compute_gene_overlap_cross_pert), the data transfer from the GPU to the host has already performed an implicit synchronization. You may consider removing this block entirely to simplify the code.
Per review feedback: reuse the device_type variable defined above instead of calling torch.cuda.is_available() a second time.
|
Thanks for the review. Pushed a22a2bf addressing the consistency point — switched On the other two points:
Let me know. |
Two unguarded
torch.cuda.synchronize()calls (one paired withtorch.cuda.empty_cache()) currently raiseRuntimeError: No CUDA GPUs are availablewhen running the eval paths on a CPU- or MPS-only machine, even when the model itself can run on CPU:src/state/_cli/_emb/_eval.py:121-122— inside the embedding eval batch loopsrc/state/emb/nn/eval_utils.py:79— between_predict_exp_for_adataand the DE gene-overlap metricThis wraps each in
if torch.cuda.is_available():, matching the existing pattern already used insrc/state/tx/callbacks/model_flops_utilization.py(lines 131 and 148).Net change: 5 lines across 2 files. No behavior change on CUDA hosts; eval now runs on CPU/MPS hosts that previously crashed.