Support Blackwell GPUs (torch 2.7 + DGL 2.4 cu124)#29

Open
jackytamkc wants to merge 1 commit into ZJUFanLab:main from jackytamkc:pr1-blackwell-compat

@jackytamkc

Summary

  • Replace graph.adjacency_matrix().to_dense().shape[0] with graph.num_nodes() at the two pos_weight sites in scniche/trainer/_train.py. Same value, no N×N dense allocation, and avoids the torch-pinned libdgl_sparse_pytorch_<X>.so load that fails on torch 2.7 with the DGL 2.4 cu124 wheel.
  • Add a "Blackwell GPUs" subsection to the README documenting the install order for sm_120 hardware (RTX PRO 6000, RTX 50-series). Existing CUDA 11.3 instructions are kept intact — this is purely additive.
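
To put the "no N×N dense allocation" point in numbers, here is a rough back-of-the-envelope sketch. It assumes the dense adjacency is stored as float32 (4 bytes per entry), which is an assumption and not something stated in this PR; the helper name `dense_adjacency_bytes` is hypothetical.

```python
def dense_adjacency_bytes(num_nodes: int, itemsize: int = 4) -> int:
    """Bytes needed to materialize an N x N dense matrix.

    itemsize=4 assumes float32 entries (an assumption, not taken from scNiche).
    """
    return num_nodes * num_nodes * itemsize


# Print the approximate footprint for a few graph sizes.
for n in (10_000, 100_000, 1_000_000):
    gib = dense_adjacency_bytes(n) / 2**30
    print(f"N={n:>9,}: dense adjacency needs about {gib:,.1f} GiB just to read N")
```

Reading the node count directly avoids this quadratic footprint entirely, since the count is an integer the graph already tracks.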

Background

Blackwell GPUs (compute capability sm_120) require PyTorch built against CUDA 12.8. No DGL wheel is currently published for cu128, but the DGL 2.4 cu124 wheel's CUDA kernels are ABI-compatible with torch 2.7 and JIT-forward to sm_120. The only blocker for scNiche on this stack is DGLGraph.adjacency_matrix(), which tries to dlopen a torch-version-pinned sparse library (libdgl_sparse_pytorch_2.4.so) and crashes against torch 2.7. Both call sites in scNiche only use the result to read N, so swapping to num_nodes() is functionally equivalent and removes the dependency on the sparse library.
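
The equivalence claimed above can be illustrated without DGL installed. The sketch below uses a hypothetical `ToyGraph` stand-in (it is not the DGL API, and `dense_adjacency` is an invented method name) to show that the first dimension of a dense adjacency matrix is the node count, so reading the count directly gives the same integer.

```python
class ToyGraph:
    """Hypothetical stand-in for a graph object; NOT the DGL API."""

    def __init__(self, num_nodes, edges):
        self._n = num_nodes
        self._edges = list(edges)

    def num_nodes(self):
        # Direct lookup: no matrix is built.
        return self._n

    def dense_adjacency(self):
        # Builds the full N x N matrix, analogous to densifying an adjacency.
        dense = [[0.0] * self._n for _ in range(self._n)]
        for src, dst in self._edges:
            dense[src][dst] = 1.0
        return dense


graph = ToyGraph(4, [(0, 1), (1, 2), (2, 3), (3, 0)])

n_old = len(graph.dense_adjacency())  # old pattern: build the matrix, read its first dimension
n_new = graph.num_nodes()             # new pattern: ask the graph for N

print(n_old, n_new)
```

Both paths yield the same N; only the first one pays for the intermediate matrix.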

Backward compatibility

  • DGLGraph.num_nodes() has been part of the public DGL API since 0.5, so existing users on DGL 1.1.0+cu113 are unaffected (same integer, faster).
  • README changes are additive — the existing CUDA 11.3 install path is unchanged.

Test plan

  • Install on Ampere/Ada with the existing cu113 instructions; verify Runner.fit and Runner_batch.fit still produce the same pos_weight and embeddings as before.
  • Install on Blackwell (sm_120) following the new README section; verify both Runner.fit and Runner_batch.fit train end-to-end without the libdgl_sparse_pytorch load error.
  • The pip check warning "dgl requires torch<=2.4.0" is expected on the Blackwell stack and is documented in the README.
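
For the "same pos_weight and embeddings as before" check, a comparison helper along these lines might be enough. This is a minimal sketch: the name `same_values` and the tolerances are assumptions, the inputs are assumed to be flattened to plain sequences of floats, and both runs are assumed to use the same random seed.

```python
import math


def same_values(before, after, rel_tol=1e-6, abs_tol=1e-8):
    """Return True if two flat sequences of floats match within tolerance.

    Hypothetical helper; tolerances are illustrative, not taken from scNiche.
    """
    before, after = list(before), list(after)
    if len(before) != len(after):
        return False
    return all(
        math.isclose(b, a, rel_tol=rel_tol, abs_tol=abs_tol)
        for b, a in zip(before, after)
    )


# Made-up numbers standing in for values saved before and after the change.
baseline = [0.25, 1.5, 3.0]
candidate = [0.25, 1.5, 3.0 + 1e-9]
print(same_values(baseline, candidate))
```

Since the change only swaps how N is read, an exact match is the expected outcome; the tolerance only guards against floating-point noise between runs.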

🤖 Generated with Claude Code

- _train.py: replace `graph.adjacency_matrix().to_dense().shape[0]`
  with `graph.num_nodes()` at the two pos_weight sites. Same value,
  no N*N allocation, and avoids the torch-pinned libdgl_sparse_pytorch
  load that fails on torch 2.7 with the DGL 2.4 cu124 wheel.
- README: add Blackwell / sm_120 install instructions alongside the
  existing cu113 path.

Backward-compatible: num_nodes() works on all supported DGL versions;
README changes are additive.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@ProDong0512

Thank you for your comments. We will consider testing it and adding it to our next version.
