Skip to content

[ROCm] support warpSize 32 and 64 in a single fbgemm_gpu build#5804

Open
jeffdaily wants to merge 1 commit into
pytorch:mainfrom
jeffdaily:jeffdaily/warpSize-finish-port
Open

[ROCm] support warpSize 32 and 64 in a single fbgemm_gpu build#5804
jeffdaily wants to merge 1 commit into
pytorch:mainfrom
jeffdaily:jeffdaily/warpSize-finish-port

Conversation

@jeffdaily
Copy link
Copy Markdown
Contributor

Continues the warpSize-32 enablement from #5739 which fixed device-side kWarpSize. This change makes a single ROCm fbgemm_gpu wheel run correctly on both wave-32 and wave-64 AMD archs, including a mixed-arch build:

PYTORCH_ROCM_ARCH="gfx90a;gfx1100"

Four root causes were fixed.

  1. TBE kernel templates bake kThreadGroupSize into the mangled name. On a mixed-arch build, the host pass instantiates kernels with the ROCm 64 placeholder; the gfx1100 device pass instantiates with 32. The linker can't find the gfx1100 device symbol.

    The kernel templates are reparameterised: instead of taking kThreadGroupSize as a template integer (which carries warpSize into the mangling), they take kSubwarpDivisor. The kernel body declares constexpr int32_t kThreadGroupSize = kWarpSize / kSubwarpDivisor; so the per-arch device pass picks the right value via the already-correct device-side kWarpSize. The mangled name carries the divisor literal, which is wave-size-free, so host and every device pass agree on the symbol. For a multi-arch wheel this is ~30% fewer symbols than emitting separate wave32 and wave64 instantiations (one symbol per bracket serves both waves; per-arch device code only generates the matching impl).

  2. Host-side launch configurations use kWarpSize, which on ROCm is a placeholder 64 on the host pass and produces wrong block dims on a mixed-arch wheel running on gfx1100.

    A new kWarpSizeHost() inline function resolves to constexpr 32 on CUDA and at::cuda::warp_size() (cached device-properties lookup) on ROCm. Host-side dim3 / quotient sites switch to kWarpSizeHost. Device code keeps kWarpSize (per-arch correct since support warpSize 32 and 64 in the same build (#5739) #5739).

    The host-side dispatch macros (DISPATCH_OPTIMAL_KERNEL, DISPATCH_NON_VEC_BLOCKING_KERNEL, DISPATCH_VEC_BLOCKING_KERNEL, DISPATCH_OPTIMAL_FORWARD_KERNEL) emit a wave-specific _WAVE32/_WAVE64 pair and select at runtime via kWarpSizeHost(). Single-wave-size builds emit only the matching table.

    A new CMake-time mechanism in cmake/Hip.cmake derives FBGEMM_HAS_WAVE32 and FBGEMM_HAS_WAVE64 from PYTORCH_ROCM_ARCH and passes them through to the codegen. Single-arch wheels (gfx1100-only or gfx90a-only) emit only the matching wave's bracket table and only the matching wave's kernel instantiations, so wheel size does not grow vs the pre-port ROCm wheel. The multi-arch wheel pays the cost only when explicitly asked for.

  3. The TBE LRU/LFU/LXU cache hardcoded associativity (ways per set) to 64 on ROCm and guarded the populate kernels with TORCH_CHECK(warp_size()==64). Cache associativity must equal the device warp size, because one warp cooperatively scans the ways of a single set. Unlike the training kernels, the cache kernels are not templated on warpSize -- they use the device-pass kWarpSize constant for row indexing -- so the multi-arch fat binary already contains per-arch-correct device code; only the host-allocated cache geometry was wrong.

    _apply_cache_state now derives a per-instance cache associativity from the running device's warp size (torch.cuda.get_device_properties(dev).warp_size) instead of the hardcoded DEFAULT_ASSOC, so the cache-state tensors match the device kernel's per-arch indexing. The wave-64-only TORCH_CHECK guards are removed. DEFAULT_ASSOC remains as the CPU/no-device fallback.

  4. warpReduceAllSum on ROCm dispatched unconditionally to rocm::wave_reduce, whose inner dpp_reduction is gated on defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) and compiles to a no-op on every other arch. On gfx1100 (and any gfx9 not in that list, e.g. gfx908) wave_reduce therefore returned readlane(warpSize-1) of un-reduced data instead of the warp sum, silently corrupting the grad_indice_weights reduction (and every other warpReduceAllSum use). The CDNA row_bcast DPP controls cannot be reused on RDNA, so warpReduceAllSum now takes the rocm::wave_reduce path only on the archs where dpp_reduction is actually implemented (gfx942/gfx90a/gfx950) and falls back to the portable shfl_xor butterfly everywhere else -- correct for any warp size and the same path CUDA uses. (The two gates must stay in sync.)

Deferred follow-ups: the experimental gen_ai / moe / attention warpSize cleanup is a separate change; gfx90a runtime regression should be confirmed on a wave-64 host.

Test plan (gfx1100, warpSize 32):

cd fbgemm_gpu
rm -rf _skbuild dist
PYTORCH_ROCM_ARCH=gfx1100 BUILD_ROCM_VERSION=7.2 \
    python setup.py bdist_wheel --build-target default --build-variant rocm
pip install --force-reinstall dist/fbgemm_gpu_nightly_rocm-*.whl
python -c "import fbgemm_gpu"

HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/training/forward_test.py
# 13 passed, 3 skipped (uvm_cache paths work on wave32).
HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/cache/
# all pass (lxu_cache, cache_test, cache_config, linearize,
# cache_overflow).

HIP_VISIBLE_DEVICES=0 python -m pytest \
    test/tbe/training/backward_dense_test.py
# 1 passed -- the per_sample_weights gradcheck that exercises the
# warpReduceAllSum fix (item 4); previously a Jacobian mismatch.
HIP_VISIBLE_DEVICES=0 python -m pytest \
    test/tbe/training/backward_adagrad_test.py \
    test/tbe/training/backward_sgd_test.py \
    test/tbe/training/backward_none_test.py
# backward_adagrad 11 passed/4 skipped, backward_sgd 5 passed/3 skipped,
# backward_none 1 passed/2 skipped.

# Multi-arch build (compile-only on this host; builds clean and imports
# on gfx1100; gfx90a runtime regression to be run on a wave64 host):
rm -rf _skbuild dist
PYTORCH_ROCM_ARCH="gfx90a;gfx1100" BUILD_ROCM_VERSION=7.2 \
    python setup.py bdist_wheel --build-target default --build-variant rocm

Authored with assistance from Claude (Anthropic).

Continues the warpSize-32 enablement from pytorch#5739 which
fixed device-side kWarpSize. This change makes a single ROCm
fbgemm_gpu wheel run correctly on both wave-32 and wave-64 AMD archs,
including a mixed-arch build:

    PYTORCH_ROCM_ARCH="gfx90a;gfx1100"

Four root causes were fixed.

1. TBE kernel templates bake kThreadGroupSize into the mangled name.
   On a mixed-arch build, the host pass instantiates kernels with the
   ROCm 64 placeholder; the gfx1100 device pass instantiates with 32.
   The linker can't find the gfx1100 device symbol.

   The kernel templates are reparameterised: instead of taking
   kThreadGroupSize as a template integer (which carries warpSize
   into the mangling), they take kSubwarpDivisor. The kernel body
   declares `constexpr int32_t kThreadGroupSize = kWarpSize /
   kSubwarpDivisor;` so the per-arch device pass picks the right
   value via the already-correct device-side kWarpSize. The mangled
   name carries the divisor literal, which is wave-size-free, so
   host and every device pass agree on the symbol. For a multi-arch
   wheel this is ~30% fewer symbols than emitting separate wave32
   and wave64 instantiations (one symbol per bracket serves both
   waves; per-arch device code only generates the matching impl).

2. Host-side launch configurations use kWarpSize, which on ROCm is a
   placeholder 64 on the host pass and produces wrong block dims on
   a mixed-arch wheel running on gfx1100.

   A new kWarpSizeHost() inline function resolves to constexpr 32 on
   CUDA and at::cuda::warp_size() (cached device-properties lookup)
   on ROCm. Host-side dim3 / quotient sites switch to kWarpSizeHost.
   Device code keeps kWarpSize (per-arch correct since pytorch#5739).

   The host-side dispatch macros (DISPATCH_OPTIMAL_KERNEL,
   DISPATCH_NON_VEC_BLOCKING_KERNEL, DISPATCH_VEC_BLOCKING_KERNEL,
   DISPATCH_OPTIMAL_FORWARD_KERNEL) emit a wave-specific
   _WAVE32/_WAVE64 pair and select at runtime via kWarpSizeHost().
   Single-wave-size builds emit only the matching table.

   A new CMake-time mechanism in cmake/Hip.cmake derives
   FBGEMM_HAS_WAVE32 and FBGEMM_HAS_WAVE64 from PYTORCH_ROCM_ARCH and
   passes them through to the codegen. Single-arch wheels (gfx1100-only
   or gfx90a-only) emit only the matching wave's bracket table and only
   the matching wave's kernel instantiations, so wheel size does not
   grow vs the pre-port ROCm wheel. The multi-arch wheel pays the cost
   only when explicitly asked for.

3. The TBE LRU/LFU/LXU cache hardcoded associativity (ways per set) to
   64 on ROCm and guarded the populate kernels with
   TORCH_CHECK(warp_size()==64). Cache associativity must equal the
   device warp size, because one warp cooperatively scans the ways of a
   single set. Unlike the training kernels, the cache kernels are not
   templated on warpSize -- they use the device-pass kWarpSize constant
   for row indexing -- so the multi-arch fat binary already contains
   per-arch-correct device code; only the host-allocated cache geometry
   was wrong.

   _apply_cache_state now derives a per-instance cache associativity
   from the running device's warp size
   (torch.cuda.get_device_properties(dev).warp_size) instead of the
   hardcoded DEFAULT_ASSOC, so the cache-state tensors match the device
   kernel's per-arch indexing. The wave-64-only TORCH_CHECK guards are
   removed. DEFAULT_ASSOC remains as the CPU/no-device fallback.

4. warpReduceAllSum on ROCm dispatched unconditionally to
   rocm::wave_reduce, whose inner dpp_reduction is gated on
   `defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__)`
   and compiles to a no-op on every other arch. On gfx1100 (and any
   gfx9 not in that list, e.g. gfx908) wave_reduce therefore returned
   readlane(warpSize-1) of un-reduced data instead of the warp sum,
   silently corrupting the grad_indice_weights reduction (and every
   other warpReduceAllSum use). The CDNA row_bcast DPP controls cannot
   be reused on RDNA, so warpReduceAllSum now takes the rocm::wave_reduce
   path only on the archs where dpp_reduction is actually implemented
   (gfx942/gfx90a/gfx950) and falls back to the portable shfl_xor
   butterfly everywhere else -- correct for any warp size and the same
   path CUDA uses. (The two gates must stay in sync.)

Deferred follow-ups: the experimental gen_ai / moe / attention warpSize
cleanup is a separate change; gfx90a runtime regression should be
confirmed on a wave-64 host.

Test plan (gfx1100, warpSize 32):

    cd fbgemm_gpu
    rm -rf _skbuild dist
    PYTORCH_ROCM_ARCH=gfx1100 BUILD_ROCM_VERSION=7.2 \
        python setup.py bdist_wheel --build-target default --build-variant rocm
    pip install --force-reinstall dist/fbgemm_gpu_nightly_rocm-*.whl
    python -c "import fbgemm_gpu"

    HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/training/forward_test.py
    # 13 passed, 3 skipped (uvm_cache paths work on wave32).
    HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/cache/
    # all pass (lxu_cache, cache_test, cache_config, linearize,
    # cache_overflow).

    HIP_VISIBLE_DEVICES=0 python -m pytest \
        test/tbe/training/backward_dense_test.py
    # 1 passed -- the per_sample_weights gradcheck that exercises the
    # warpReduceAllSum fix (item 4); previously a Jacobian mismatch.
    HIP_VISIBLE_DEVICES=0 python -m pytest \
        test/tbe/training/backward_adagrad_test.py \
        test/tbe/training/backward_sgd_test.py \
        test/tbe/training/backward_none_test.py
    # backward_adagrad 11 passed/4 skipped, backward_sgd 5 passed/3 skipped,
    # backward_none 1 passed/2 skipped.

    # Multi-arch build (compile-only on this host; builds clean and imports
    # on gfx1100; gfx90a runtime regression to be run on a wave64 host):
    rm -rf _skbuild dist
    PYTORCH_ROCM_ARCH="gfx90a;gfx1100" BUILD_ROCM_VERSION=7.2 \
        python setup.py bdist_wheel --build-target default --build-variant rocm

Authored with assistance from Claude (Anthropic).
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented Jun 1, 2026

@q10 has imported this pull request. If you are a Meta employee, you can view this in D107177824.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant