Improve ROCm compatibility for TriSplat native extension builds#3
Improve ROCm compatibility for TriSplat native extension builds#3ZJLi2013 wants to merge 8 commits into
Conversation
- diff-triangle-rasterization/setup.py: detect torch.version.hip and switch to hip_rasterizer/*.hip sources; drop NVIDIA-only nvcc flags (--use_fast_math, --expt-relaxed-constexpr) under ROCm. - simple-knn: drop unavailable device_launch_parameters.h and cooperative_groups/reduce.h; normalize CUDA chevron spacing for hipify; switch deprecated .data<T>() to .data_ptr<T>(). - curope/setup.py: skip -gencode arch=compute_*,code=sm_* under ROCm and let PyTorch pick HIP arch flags. - curope/kernels.cu: use scalar_type() instead of deprecated type(). - scripts/env/rebuild_extensions.sh: helper to rebuild all native extensions in-place and patch helper_math.h smoothstep / CUDA-only includes in the downloaded diff-gaussian-rasterization-w-pose source. All changes guarded by torch.version.hip so CUDA builds are unaffected. Co-authored-by: Cursor <cursoragent@cursor.com>
Remove vendored dependency patches from the TriSplat branch so those ROCm fixes can be reviewed in their upstream repositories.
|
Hi @ZJLi2013, thanks a lot for the patch and the very detailed reproduction steps — really appreciate the effort put into this. I've taken a careful look, and the build-level changes themselves are clean: the One thing I want to think through with you before we land it: on its own, this PR only makes the vendored To that end, a few suggestions that would really help me get this over the line confidently:
None of this is a hard blocker, and I'm not going to close the PR — just sharing what would make me most comfortable merging, and happy to keep iterating with you here. If some parts of the pipeline turn out not to work on ROCm, that's also useful information, and we can document it explicitly so AMD users know what to expect. Thanks again for pushing on this. Best, |
Summary
This PR keeps the TriSplat-side ROCm change minimal: it applies the small
curopebuild fixes needed for the vendored CroCo code path to compile with PyTorch ROCm. CroCo AMD support already exists in related work, so this PR only carries the compatibility delta needed inside TriSplat's vendored copy.Changes
-gencodearch flags whentorch.version.hipis present, letting PyTorch's ROCm extension tooling select the HIP architecture flags.--ptxas-options=-vand--use_fast_mathfrom thecuropeextension compile flags.AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), ...)totokens.scalar_type()for PyTorch 2.6 compatibility.Tested On
rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0Reproduction
This checks the TriSplat native extension stack with this PR plus the two ROCm dependency PRs:
The dependency PRs are required for full TriSplat mesh runs because
diff-gaussian-rasterization-w-poseis used bysrc/model/decoder/cuda_splatting.py, anddiff-triangle-rasterizationis used bysrc/model/decoder/cuda_triangle_splatting.py.Results
curopeextension under ROCm.diff-gaussian-rasterization-w-pose#4anddiff-triangle-rasterization#3applied, completed an end-to-end public LLFFroommesh export smoke run on ROCm and produced aDIRECT_triangle_mesh_post.plyartifact.Notes