[ROCm] support warpSize 32 and 64 in a single fbgemm_gpu build#5804
Open
jeffdaily wants to merge 1 commit into
Open
[ROCm] support warpSize 32 and 64 in a single fbgemm_gpu build#5804jeffdaily wants to merge 1 commit into
jeffdaily wants to merge 1 commit into
Conversation
Continues the warpSize-32 enablement from pytorch#5739 which fixed device-side kWarpSize. This change makes a single ROCm fbgemm_gpu wheel run correctly on both wave-32 and wave-64 AMD archs, including a mixed-arch build: PYTORCH_ROCM_ARCH="gfx90a;gfx1100" Four root causes were fixed. 1. TBE kernel templates bake kThreadGroupSize into the mangled name. On a mixed-arch build, the host pass instantiates kernels with the ROCm 64 placeholder; the gfx1100 device pass instantiates with 32. The linker can't find the gfx1100 device symbol. The kernel templates are reparameterised: instead of taking kThreadGroupSize as a template integer (which carries warpSize into the mangling), they take kSubwarpDivisor. The kernel body declares `constexpr int32_t kThreadGroupSize = kWarpSize / kSubwarpDivisor;` so the per-arch device pass picks the right value via the already-correct device-side kWarpSize. The mangled name carries the divisor literal, which is wave-size-free, so host and every device pass agree on the symbol. For a multi-arch wheel this is ~30% fewer symbols than emitting separate wave32 and wave64 instantiations (one symbol per bracket serves both waves; per-arch device code only generates the matching impl). 2. Host-side launch configurations use kWarpSize, which on ROCm is a placeholder 64 on the host pass and produces wrong block dims on a mixed-arch wheel running on gfx1100. A new kWarpSizeHost() inline function resolves to constexpr 32 on CUDA and at::cuda::warp_size() (cached device-properties lookup) on ROCm. Host-side dim3 / quotient sites switch to kWarpSizeHost. Device code keeps kWarpSize (per-arch correct since pytorch#5739). The host-side dispatch macros (DISPATCH_OPTIMAL_KERNEL, DISPATCH_NON_VEC_BLOCKING_KERNEL, DISPATCH_VEC_BLOCKING_KERNEL, DISPATCH_OPTIMAL_FORWARD_KERNEL) emit a wave-specific _WAVE32/_WAVE64 pair and select at runtime via kWarpSizeHost(). Single-wave-size builds emit only the matching table. A new CMake-time mechanism in cmake/Hip.cmake derives FBGEMM_HAS_WAVE32 and FBGEMM_HAS_WAVE64 from PYTORCH_ROCM_ARCH and passes them through to the codegen. Single-arch wheels (gfx1100-only or gfx90a-only) emit only the matching wave's bracket table and only the matching wave's kernel instantiations, so wheel size does not grow vs the pre-port ROCm wheel. The multi-arch wheel pays the cost only when explicitly asked for. 3. The TBE LRU/LFU/LXU cache hardcoded associativity (ways per set) to 64 on ROCm and guarded the populate kernels with TORCH_CHECK(warp_size()==64). Cache associativity must equal the device warp size, because one warp cooperatively scans the ways of a single set. Unlike the training kernels, the cache kernels are not templated on warpSize -- they use the device-pass kWarpSize constant for row indexing -- so the multi-arch fat binary already contains per-arch-correct device code; only the host-allocated cache geometry was wrong. _apply_cache_state now derives a per-instance cache associativity from the running device's warp size (torch.cuda.get_device_properties(dev).warp_size) instead of the hardcoded DEFAULT_ASSOC, so the cache-state tensors match the device kernel's per-arch indexing. The wave-64-only TORCH_CHECK guards are removed. DEFAULT_ASSOC remains as the CPU/no-device fallback. 4. warpReduceAllSum on ROCm dispatched unconditionally to rocm::wave_reduce, whose inner dpp_reduction is gated on `defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__)` and compiles to a no-op on every other arch. On gfx1100 (and any gfx9 not in that list, e.g. gfx908) wave_reduce therefore returned readlane(warpSize-1) of un-reduced data instead of the warp sum, silently corrupting the grad_indice_weights reduction (and every other warpReduceAllSum use). The CDNA row_bcast DPP controls cannot be reused on RDNA, so warpReduceAllSum now takes the rocm::wave_reduce path only on the archs where dpp_reduction is actually implemented (gfx942/gfx90a/gfx950) and falls back to the portable shfl_xor butterfly everywhere else -- correct for any warp size and the same path CUDA uses. (The two gates must stay in sync.) Deferred follow-ups: the experimental gen_ai / moe / attention warpSize cleanup is a separate change; gfx90a runtime regression should be confirmed on a wave-64 host. Test plan (gfx1100, warpSize 32): cd fbgemm_gpu rm -rf _skbuild dist PYTORCH_ROCM_ARCH=gfx1100 BUILD_ROCM_VERSION=7.2 \ python setup.py bdist_wheel --build-target default --build-variant rocm pip install --force-reinstall dist/fbgemm_gpu_nightly_rocm-*.whl python -c "import fbgemm_gpu" HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/training/forward_test.py # 13 passed, 3 skipped (uvm_cache paths work on wave32). HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/cache/ # all pass (lxu_cache, cache_test, cache_config, linearize, # cache_overflow). HIP_VISIBLE_DEVICES=0 python -m pytest \ test/tbe/training/backward_dense_test.py # 1 passed -- the per_sample_weights gradcheck that exercises the # warpReduceAllSum fix (item 4); previously a Jacobian mismatch. HIP_VISIBLE_DEVICES=0 python -m pytest \ test/tbe/training/backward_adagrad_test.py \ test/tbe/training/backward_sgd_test.py \ test/tbe/training/backward_none_test.py # backward_adagrad 11 passed/4 skipped, backward_sgd 5 passed/3 skipped, # backward_none 1 passed/2 skipped. # Multi-arch build (compile-only on this host; builds clean and imports # on gfx1100; gfx90a runtime regression to be run on a wave64 host): rm -rf _skbuild dist PYTORCH_ROCM_ARCH="gfx90a;gfx1100" BUILD_ROCM_VERSION=7.2 \ python setup.py bdist_wheel --build-target default --build-variant rocm Authored with assistance from Claude (Anthropic).
Contributor
|
@q10 has imported this pull request. If you are a Meta employee, you can view this in D107177824. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Continues the warpSize-32 enablement from #5739 which fixed device-side kWarpSize. This change makes a single ROCm fbgemm_gpu wheel run correctly on both wave-32 and wave-64 AMD archs, including a mixed-arch build:
Four root causes were fixed.
TBE kernel templates bake kThreadGroupSize into the mangled name. On a mixed-arch build, the host pass instantiates kernels with the ROCm 64 placeholder; the gfx1100 device pass instantiates with 32. The linker can't find the gfx1100 device symbol.
The kernel templates are reparameterised: instead of taking kThreadGroupSize as a template integer (which carries warpSize into the mangling), they take kSubwarpDivisor. The kernel body declares
constexpr int32_t kThreadGroupSize = kWarpSize / kSubwarpDivisor;so the per-arch device pass picks the right value via the already-correct device-side kWarpSize. The mangled name carries the divisor literal, which is wave-size-free, so host and every device pass agree on the symbol. For a multi-arch wheel this is ~30% fewer symbols than emitting separate wave32 and wave64 instantiations (one symbol per bracket serves both waves; per-arch device code only generates the matching impl).Host-side launch configurations use kWarpSize, which on ROCm is a placeholder 64 on the host pass and produces wrong block dims on a mixed-arch wheel running on gfx1100.
A new kWarpSizeHost() inline function resolves to constexpr 32 on CUDA and at::cuda::warp_size() (cached device-properties lookup) on ROCm. Host-side dim3 / quotient sites switch to kWarpSizeHost. Device code keeps kWarpSize (per-arch correct since support warpSize 32 and 64 in the same build (#5739) #5739).
The host-side dispatch macros (DISPATCH_OPTIMAL_KERNEL, DISPATCH_NON_VEC_BLOCKING_KERNEL, DISPATCH_VEC_BLOCKING_KERNEL, DISPATCH_OPTIMAL_FORWARD_KERNEL) emit a wave-specific _WAVE32/_WAVE64 pair and select at runtime via kWarpSizeHost(). Single-wave-size builds emit only the matching table.
A new CMake-time mechanism in cmake/Hip.cmake derives FBGEMM_HAS_WAVE32 and FBGEMM_HAS_WAVE64 from PYTORCH_ROCM_ARCH and passes them through to the codegen. Single-arch wheels (gfx1100-only or gfx90a-only) emit only the matching wave's bracket table and only the matching wave's kernel instantiations, so wheel size does not grow vs the pre-port ROCm wheel. The multi-arch wheel pays the cost only when explicitly asked for.
The TBE LRU/LFU/LXU cache hardcoded associativity (ways per set) to 64 on ROCm and guarded the populate kernels with TORCH_CHECK(warp_size()==64). Cache associativity must equal the device warp size, because one warp cooperatively scans the ways of a single set. Unlike the training kernels, the cache kernels are not templated on warpSize -- they use the device-pass kWarpSize constant for row indexing -- so the multi-arch fat binary already contains per-arch-correct device code; only the host-allocated cache geometry was wrong.
_apply_cache_state now derives a per-instance cache associativity from the running device's warp size (torch.cuda.get_device_properties(dev).warp_size) instead of the hardcoded DEFAULT_ASSOC, so the cache-state tensors match the device kernel's per-arch indexing. The wave-64-only TORCH_CHECK guards are removed. DEFAULT_ASSOC remains as the CPU/no-device fallback.
warpReduceAllSum on ROCm dispatched unconditionally to rocm::wave_reduce, whose inner dpp_reduction is gated on
defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__)and compiles to a no-op on every other arch. On gfx1100 (and any gfx9 not in that list, e.g. gfx908) wave_reduce therefore returned readlane(warpSize-1) of un-reduced data instead of the warp sum, silently corrupting the grad_indice_weights reduction (and every other warpReduceAllSum use). The CDNA row_bcast DPP controls cannot be reused on RDNA, so warpReduceAllSum now takes the rocm::wave_reduce path only on the archs where dpp_reduction is actually implemented (gfx942/gfx90a/gfx950) and falls back to the portable shfl_xor butterfly everywhere else -- correct for any warp size and the same path CUDA uses. (The two gates must stay in sync.)Deferred follow-ups: the experimental gen_ai / moe / attention warpSize cleanup is a separate change; gfx90a runtime regression should be confirmed on a wave-64 host.
Test plan (gfx1100, warpSize 32):
Authored with assistance from Claude (Anthropic).