Skip to content

Improve ROCm compatibility for TriSplat native extension builds#3

Closed
ZJLi2013 wants to merge 8 commits into
ziplab:mainfrom
PhysicalAI-AIM:rocm_support
Closed

Improve ROCm compatibility for TriSplat native extension builds#3
ZJLi2013 wants to merge 8 commits into
ziplab:mainfrom
PhysicalAI-AIM:rocm_support

Conversation

@ZJLi2013

Copy link
Copy Markdown

Summary

This PR keeps the TriSplat-side ROCm change minimal: it applies the small curope build fixes needed for the vendored CroCo code path to compile with PyTorch ROCm. CroCo AMD support already exists in related work, so this PR only carries the compatibility delta needed inside TriSplat's vendored copy.

Changes

  • Skip NVIDIA-only -gencode arch flags when torch.version.hip is present, letting PyTorch's ROCm extension tooling select the HIP architecture flags.
  • Remove NVIDIA-only --ptxas-options=-v and --use_fast_math from the curope extension compile flags.
  • Update AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), ...) to tokens.scalar_type() for PyTorch 2.6 compatibility.

Tested On

  • GPU: AMD Instinct MI300X
  • ROCm: 6.4.3
  • PyTorch: 2.6.0 ROCm Docker build
  • Docker: rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0

Reproduction

This checks the TriSplat native extension stack with this PR plus the two ROCm dependency PRs:

git clone https://github.com/ZJLi2013/TriSplat.git
cd TriSplat
git checkout rocm_support
git submodule update --init --recursive

export PYTORCH_ROCM_ARCH=gfx942

# diff-gaussian-rasterization-w-pose ROCm support:
# https://github.com/rmurai0610/diff-gaussian-rasterization-w-pose/pull/4
pip install --no-build-isolation \
  "git+https://github.com/ZJLi2013/diff-gaussian-rasterization-w-pose.git@rocm_support"

# diff-triangle-rasterization ROCm support:
# https://github.com/trianglesplatting/diff-triangle-rasterization/pull/3
pip install --no-build-isolation \
  "git+https://github.com/ZJLi2013/diff-triangle-rasterization.git@rocm_support"

pip install --no-build-isolation ./submodules/simple-knn

cd src/model/encoder/backbone/croco/curope
pip install . --no-build-isolation
cd ../../../../../..

python - <<'PY'
import torch
import diff_gaussian_rasterization
import diff_triangle_rasterization
import simple_knn._C
import curope
print(torch.__version__, torch.version.hip)
print("native extension imports ok")
PY

The dependency PRs are required for full TriSplat mesh runs because diff-gaussian-rasterization-w-pose is used by src/model/decoder/cuda_splatting.py, and diff-triangle-rasterization is used by src/model/decoder/cuda_triangle_splatting.py.

Results

  • Built and imported the vendored CroCo curope extension under ROCm.
  • With diff-gaussian-rasterization-w-pose#4 and diff-triangle-rasterization#3 applied, completed an end-to-end public LLFF room mesh export smoke run on ROCm and produced a DIRECT_triangle_mesh_post.ply artifact.

Notes

lhmd and others added 8 commits May 26, 2026 10:41
- diff-triangle-rasterization/setup.py: detect torch.version.hip and
  switch to hip_rasterizer/*.hip sources; drop NVIDIA-only nvcc flags
  (--use_fast_math, --expt-relaxed-constexpr) under ROCm.
- simple-knn: drop unavailable device_launch_parameters.h and
  cooperative_groups/reduce.h; normalize CUDA chevron spacing for hipify;
  switch deprecated .data<T>() to .data_ptr<T>().
- curope/setup.py: skip -gencode arch=compute_*,code=sm_* under ROCm and
  let PyTorch pick HIP arch flags.
- curope/kernels.cu: use scalar_type() instead of deprecated type().
- scripts/env/rebuild_extensions.sh: helper to rebuild all native
  extensions in-place and patch helper_math.h smoothstep / CUDA-only
  includes in the downloaded diff-gaussian-rasterization-w-pose source.

All changes guarded by torch.version.hip so CUDA builds are unaffected.

Co-authored-by: Cursor <cursoragent@cursor.com>
Remove vendored dependency patches from the TriSplat branch so those ROCm fixes can be reviewed in their upstream repositories.
@lhmd

lhmd commented May 29, 2026

Copy link
Copy Markdown
Member

Hi @ZJLi2013, thanks a lot for the patch and the very detailed reproduction steps — really appreciate the effort put into this.

I've taken a careful look, and the build-level changes themselves are clean: the torch.version.hip guards keep the CUDA path untouched, dropping the NVIDIA-only nvcc flags under ROCm is the right call, and tokens.scalar_type() is a welcome PyTorch 2.6+ cleanup. I'll keep the PR open while we figure out the best way forward together.

One thing I want to think through with you before we land it: on its own, this PR only makes the vendored curope extension compile under ROCm, but TriSplat's actual pipeline also depends on diff-gaussian-rasterization-w-pose, diff-triangle-rasterization, and simple-knn, which are CUDA-only upstream today. Until your two dependency PRs (diff-gaussian-rasterization-w-pose#4, diff-triangle-rasterization#3) are merged upstream, a stock pip install on an AMD box will still fail, and I'd worry about advertising ROCm support that users can't actually exercise yet. So I'd rather not rush this in and then have to triage AMD issues we can't reproduce — I'd love to coordinate the timing with you instead.

To that end, a few suggestions that would really help me get this over the line confidently:

  1. Coordinate with the upstream dependency PRs. If you can ping the maintainers of diff-gaussian-rasterization-w-pose and diff-triangle-rasterization and get a sense of their merge timeline (or even just a thumbs-up from them), that would be great. Once those are in, this PR becomes immediately useful, and I'd be happy to merge it alongside a README section pointing AMD users to the right setup.
  2. A fuller end-to-end evaluation on AMD. Beyond the current smoke test, it would be very reassuring to see one of the paper's evaluation protocols (e.g. RealEstate10K or DL3DV) run on MI300X, with PSNR / SSIM / LPIPS (and Chamfer / normal consistency for the mesh) side-by-side with our CUDA numbers. Even one scene is fine — the goal is just to confirm there's no silent numerical drift on ROCm.
  3. A small training sanity check on AMD, if feasible. Training is where ROCm tends to diverge from CUDA the most (mixed precision, distributed backend, kernel numerics). If you could reproduce one of our shipped training configs on MI300X, even with a shortened schedule, and share the loss curve plus an eval on the resulting checkpoint, that would give us a much clearer picture of what works and what doesn't on AMD.

None of this is a hard blocker, and I'm not going to close the PR — just sharing what would make me most comfortable merging, and happy to keep iterating with you here. If some parts of the pipeline turn out not to work on ROCm, that's also useful information, and we can document it explicitly so AMD users know what to expect.

Thanks again for pushing on this.

Best,
Weijie

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants