From a4e700c3d90dfdd7445afe3e275d1aabdb971adc Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Fri, 29 May 2026 03:33:12 +0000 Subject: [PATCH] [ROCm] support warpSize 32 and 64 in a single fbgemm_gpu build Continues the warpSize-32 enablement from pytorch/FBGEMM#5739 which fixed device-side kWarpSize. This change makes a single ROCm fbgemm_gpu wheel run correctly on both wave-32 and wave-64 AMD archs, including a mixed-arch build: PYTORCH_ROCM_ARCH="gfx90a;gfx1100" Four root causes were fixed. 1. TBE kernel templates bake kThreadGroupSize into the mangled name. On a mixed-arch build, the host pass instantiates kernels with the ROCm 64 placeholder; the gfx1100 device pass instantiates with 32. The linker can't find the gfx1100 device symbol. The kernel templates are reparameterised: instead of taking kThreadGroupSize as a template integer (which carries warpSize into the mangling), they take kSubwarpDivisor. The kernel body declares `constexpr int32_t kThreadGroupSize = kWarpSize / kSubwarpDivisor;` so the per-arch device pass picks the right value via the already-correct device-side kWarpSize. The mangled name carries the divisor literal, which is wave-size-free, so host and every device pass agree on the symbol. For a multi-arch wheel this is ~30% fewer symbols than emitting separate wave32 and wave64 instantiations (one symbol per bracket serves both waves; per-arch device code only generates the matching impl). 2. Host-side launch configurations use kWarpSize, which on ROCm is a placeholder 64 on the host pass and produces wrong block dims on a mixed-arch wheel running on gfx1100. A new kWarpSizeHost() inline function resolves to constexpr 32 on CUDA and at::cuda::warp_size() (cached device-properties lookup) on ROCm. Host-side dim3 / quotient sites switch to kWarpSizeHost. Device code keeps kWarpSize (per-arch correct since #5739). The host-side dispatch macros (DISPATCH_OPTIMAL_KERNEL, DISPATCH_NON_VEC_BLOCKING_KERNEL, DISPATCH_VEC_BLOCKING_KERNEL, DISPATCH_OPTIMAL_FORWARD_KERNEL) emit a wave-specific _WAVE32/_WAVE64 pair and select at runtime via kWarpSizeHost(). Single-wave-size builds emit only the matching table. A new CMake-time mechanism in cmake/Hip.cmake derives FBGEMM_HAS_WAVE32 and FBGEMM_HAS_WAVE64 from PYTORCH_ROCM_ARCH and passes them through to the codegen. Single-arch wheels (gfx1100-only or gfx90a-only) emit only the matching wave's bracket table and only the matching wave's kernel instantiations, so wheel size does not grow vs the pre-port ROCm wheel. The multi-arch wheel pays the cost only when explicitly asked for. 3. The TBE LRU/LFU/LXU cache hardcoded associativity (ways per set) to 64 on ROCm and guarded the populate kernels with TORCH_CHECK(warp_size()==64). Cache associativity must equal the device warp size, because one warp cooperatively scans the ways of a single set. Unlike the training kernels, the cache kernels are not templated on warpSize -- they use the device-pass kWarpSize constant for row indexing -- so the multi-arch fat binary already contains per-arch-correct device code; only the host-allocated cache geometry was wrong. _apply_cache_state now derives a per-instance cache associativity from the running device's warp size (torch.cuda.get_device_properties(dev).warp_size) instead of the hardcoded DEFAULT_ASSOC, so the cache-state tensors match the device kernel's per-arch indexing. The wave-64-only TORCH_CHECK guards are removed. DEFAULT_ASSOC remains as the CPU/no-device fallback. 4. warpReduceAllSum on ROCm dispatched unconditionally to rocm::wave_reduce, whose inner dpp_reduction is gated on `defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__)` and compiles to a no-op on every other arch. On gfx1100 (and any gfx9 not in that list, e.g. gfx908) wave_reduce therefore returned readlane(warpSize-1) of un-reduced data instead of the warp sum, silently corrupting the grad_indice_weights reduction (and every other warpReduceAllSum use). The CDNA row_bcast DPP controls cannot be reused on RDNA, so warpReduceAllSum now takes the rocm::wave_reduce path only on the archs where dpp_reduction is actually implemented (gfx942/gfx90a/gfx950) and falls back to the portable shfl_xor butterfly everywhere else -- correct for any warp size and the same path CUDA uses. (The two gates must stay in sync.) Deferred follow-ups: the experimental gen_ai / moe / attention warpSize cleanup is a separate change; gfx90a runtime regression should be confirmed on a wave-64 host. Test plan (gfx1100, warpSize 32): cd fbgemm_gpu rm -rf _skbuild dist PYTORCH_ROCM_ARCH=gfx1100 BUILD_ROCM_VERSION=7.2 \ python setup.py bdist_wheel --build-target default --build-variant rocm pip install --force-reinstall dist/fbgemm_gpu_nightly_rocm-*.whl python -c "import fbgemm_gpu" HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/training/forward_test.py # 13 passed, 3 skipped (uvm_cache paths work on wave32). HIP_VISIBLE_DEVICES=0 python -m pytest test/tbe/cache/ # all pass (lxu_cache, cache_test, cache_config, linearize, # cache_overflow). HIP_VISIBLE_DEVICES=0 python -m pytest \ test/tbe/training/backward_dense_test.py # 1 passed -- the per_sample_weights gradcheck that exercises the # warpReduceAllSum fix (item 4); previously a Jacobian mismatch. HIP_VISIBLE_DEVICES=0 python -m pytest \ test/tbe/training/backward_adagrad_test.py \ test/tbe/training/backward_sgd_test.py \ test/tbe/training/backward_none_test.py # backward_adagrad 11 passed/4 skipped, backward_sgd 5 passed/3 skipped, # backward_none 1 passed/2 skipped. # Multi-arch build (compile-only on this host; builds clean and imports # on gfx1100; gfx90a runtime regression to be run on a wave64 host): rm -rf _skbuild dist PYTORCH_ROCM_ARCH="gfx90a;gfx1100" BUILD_ROCM_VERSION=7.2 \ python setup.py bdist_wheel --build-target default --build-variant rocm Authored with assistance from Claude (Anthropic). --- fbgemm_gpu/CMakeLists.txt | 9 +- fbgemm_gpu/cmake/Hip.cmake | 16 +++ .../codegen/genscript/jinja_environment.py | 121 +++++++++++++++++- .../codegen/genscript/scripts_argsparse.py | 5 + ...mbedding_forward_quantized_split_lookup.cu | 8 +- ...ward_quantized_split_nbit_host_template.cu | 2 +- ...embedding_backward_split_host_template.cpp | 6 +- ..._backward_split_indice_weights_template.cu | 69 ++++++++-- ...ding_backward_split_kernel_cta_template.cu | 22 ++-- ...ing_backward_split_kernel_warp_template.cu | 40 +++--- .../embedding_backward_split_template.cu | 76 ++++++++--- ...embedding_forward_split_kernel_template.cu | 22 ++-- .../embedding_forward_split_template.cu | 83 +++++++++--- ...bedding_optimizer_split_kernel_template.cu | 27 +++- .../embedding_optimizer_split_template.cu | 30 ++++- .../utils/embedding_bounds_check_v1.cu | 4 +- .../utils/embedding_bounds_check_v2.cu | 4 +- ...t_table_batched_embeddings_ops_training.py | 36 ++++-- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 47 ++++++- .../embedding_inplace_update.cu | 8 +- .../src/input_combine_ops/input_combine.cu | 6 +- ...atched_dense_vec_jagged_2d_mul_backward.cu | 4 +- ...batched_dense_vec_jagged_2d_mul_forward.cu | 2 +- fbgemm_gpu/src/jagged_tensor_ops/common.cuh | 4 +- .../layout_transform_ops.cu | 4 +- .../permute_multi_embedding_ops.cu | 4 +- .../permute_pooled_embedding_ops.cu | 2 +- .../permute_pooled_embedding_ops_split.cu | 2 +- .../quantize_fused_8bit_rowwise.cu | 2 +- .../sparse_block_bucketize_features.cu | 8 +- ...rse_block_bucketize_features_2d_weights.cu | 2 +- .../sparse_compute_frequency_sequence.cu | 4 +- .../sparse_expand_into_jagged_permute.cu | 4 +- .../sparse_ops/sparse_reorder_batched_ad.cu | 6 +- .../split_embeddings_cache/lfu_cache_find.cu | 4 +- .../lfu_cache_populate.cu | 13 +- .../lfu_cache_populate_byte.cu | 4 +- .../split_embeddings_cache/lru_cache_find.cu | 4 +- .../lru_cache_populate.cu | 13 +- .../lru_cache_populate_byte.cu | 8 +- .../src/split_embeddings_cache/lxu_cache.cu | 23 +--- .../ssd_split_embeddings_cache_cuda.cu | 16 +-- .../test/tbe/cache/cache_overflow_test.py | 14 +- fbgemm_gpu/test/tbe/cache/lxu_cache_test.py | 31 ++++- 44 files changed, 594 insertions(+), 225 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index e2181cd0aa..a30360e0a4 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -182,8 +182,15 @@ endif() set(CMAKE_CODEGEN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/codegen) macro(RUN_GEN_SCRIPT SCRIPT) + set(rocm_flag) if(FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM) - set(rocm_flag --is_rocm) + list(APPEND rocm_flag --is_rocm) + if(FBGEMM_HAS_WAVE32) + list(APPEND rocm_flag --has_wave32) + endif() + if(FBGEMM_HAS_WAVE64) + list(APPEND rocm_flag --has_wave64) + endif() endif() BLOCK_PRINT( diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 445dd99d24..8e306dd35b 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -34,6 +34,22 @@ if(PYTORCH_ROCM_ARCH STREQUAL "") endif() message("Building FBGEMM for GPU arch: ${PYTORCH_ROCM_ARCH}") +# Derive the wave-size set in scope for this build from PYTORCH_ROCM_ARCH. +# gfx9* archs (CDNA: gfx90a, gfx940/941/942, gfx950) run wave64; gfx10/11/12 +# archs (RDNA3/4: gfx1100, gfx1200, ...) run wave32. The codegen uses these +# to emit only the host-side dispatch branches needed by this wheel, so +# single-arch wheels stay free of the wrong wave's bracket table. +set(FBGEMM_HAS_WAVE32 OFF) +set(FBGEMM_HAS_WAVE64 OFF) +foreach(fbgemm_rocm_arch ${PYTORCH_ROCM_ARCH}) + if(fbgemm_rocm_arch MATCHES "^gfx9") + set(FBGEMM_HAS_WAVE64 ON) + else() + set(FBGEMM_HAS_WAVE32 ON) + endif() +endforeach() +message("FBGEMM wave-size set: WAVE32=${FBGEMM_HAS_WAVE32} WAVE64=${FBGEMM_HAS_WAVE64}") + ADD_DEFINITIONS(-DNDEBUG) # USE_ROCM flag is used inside FBGEMM_GPU C++ code ADD_DEFINITIONS(-DUSE_ROCM) diff --git a/fbgemm_gpu/codegen/genscript/jinja_environment.py b/fbgemm_gpu/codegen/genscript/jinja_environment.py index e3aaec473b..e22ad6b072 100644 --- a/fbgemm_gpu/codegen/genscript/jinja_environment.py +++ b/fbgemm_gpu/codegen/genscript/jinja_environment.py @@ -64,9 +64,14 @@ # larger max embedding dimension. env.globals["legacy_max_embedding_dim"] = 1024 -# An optimization for ROCm +# An optimization for ROCm: wave64 archs (CDNA) want a larger items-per-warp. env.globals["items_per_warp"] = 128 if args.is_rocm is False else 256 +# Per-wave-size items_per_warp values; used by codegen helpers that emit +# kernel instantiations and host dispatch tables that adapt per-wave. +env.globals["items_per_warp32"] = 128 +env.globals["items_per_wave64"] = 256 + # The fixed max vectors per thread for different kernels. The numbers were # derived from empirical studies env.globals["fixed_max_vecs_per_thread"] = {"backward": 2, "backward_indice_weights": 6} @@ -74,6 +79,20 @@ env.globals["dense"] = False env.globals["is_rocm"] = args.is_rocm +# Wave-size set in scope for this build. CUDA is always wave32. On ROCm the +# values come from cmake/Hip.cmake parsing PYTORCH_ROCM_ARCH. The host +# dispatcher emits a runtime warp_size branch only when both are present. +# If a ROCm build has neither flag set (older CMake / direct codegen invoke), +# fall back to wave64 only to preserve pre-port behavior. +if args.is_rocm: + env.globals["has_wave32"] = args.has_wave32 + env.globals["has_wave64"] = args.has_wave64 or ( + not args.has_wave32 and not args.has_wave64 + ) +else: + env.globals["has_wave32"] = True + env.globals["has_wave64"] = False + ################################################################################ # Helper functions in Jinja Environment @@ -189,7 +208,16 @@ def dispatch_non_vec_blocking_kernel( Generate code for kernel dispatching for kernels that do not use vector blocking (i.e., an entire embedding row can fit in the allocated Vec4T buffer) + + Each branch emits a constexpr ``kSubwarpDivisor`` literal (the divisor + that, applied to the per-arch ``kWarpSize`` in device code or to + ``kWarpSizeHost`` on host, yields the kernel's thread-group size) and a + matching ``kThreadGroupSize`` (constexpr on CUDA, runtime on ROCm). The + consumer uses ``kSubwarpDivisor`` as the kernel template argument so the + mangled name is warpSize-free; ``kThreadGroupSize`` is the value to set + block dims with. """ + warp_size = items_per_warp // 4 blob = "" for ( kFixedMaxVecsPerThread, @@ -201,18 +229,21 @@ def dispatch_non_vec_blocking_kernel( use_subwarp_shuffle, use_vec_blocking=False, ): + kSubwarpDivisor = warp_size // kThreadGroupSize formats = { "max_D_val": kFixedMaxVecsPerThread * kThreadGroupSize * 4, "kFixedMaxVecsPerThread": kFixedMaxVecsPerThread, - "kThreadGroupSize": kThreadGroupSize, + "kSubwarpDivisor": kSubwarpDivisor, "kUseVecBlocking": kUseVecBlocking, } d_blob = """if (MAX_D <= {max_D_val}) { \\ [[ maybe_unused ]] const int max_vecs_per_thread = \\ {kFixedMaxVecsPerThread}; \\ constexpr int kFixedMaxVecsPerThread = {kFixedMaxVecsPerThread}; \\ - [[ maybe_unused ]] constexpr int kThreadGroupSize = \\ - {kThreadGroupSize}; \\ + [[ maybe_unused ]] constexpr int kSubwarpDivisor = \\ + {kSubwarpDivisor}; \\ + [[ maybe_unused ]] const int kThreadGroupSize = \\ + kWarpSizeHost() / kSubwarpDivisor; \\ [[ maybe_unused ]] constexpr bool kUseVecBlocking = \\ {kUseVecBlocking}; \\ return __VA_ARGS__(); \\ @@ -230,6 +261,8 @@ def dispatch_vec_blocking_kernel( """ Generate code for kernel dispatching for kernels that use vector blocking (i.e., an entire embedding row cannot fit in the allocated Vec4T buffer) + + Vec blocking always uses the full warp, so ``kSubwarpDivisor = 1``. """ formats = { "max_D_val": fixed_max_vecs_per_thread * items_per_warp, @@ -240,7 +273,8 @@ def dispatch_vec_blocking_kernel( [[ maybe_unused ]] const int max_vecs_per_thread = \\ (MAX_D + {items_per_warp} - 1) / {items_per_warp}; \\ constexpr int kFixedMaxVecsPerThread = {fixed_max_vecs_per_thread}; \\ - [[ maybe_unused ]] constexpr int kThreadGroupSize = kWarpSize; \\ + [[ maybe_unused ]] constexpr int kSubwarpDivisor = 1; \\ + [[ maybe_unused ]] const int kThreadGroupSize = kWarpSizeHost(); \\ [[ maybe_unused ]] constexpr bool kUseVecBlocking = true; \\ return __VA_ARGS__(); \\ } \\ @@ -270,6 +304,79 @@ def dispatch_optimal_kernel( return blob +def _enabled_waves() -> list[tuple[int, int]]: + """Return (items_per_warp, warp_size) pairs for each enabled wave size.""" + waves: list[tuple[int, int]] = [] + if env.globals["has_wave64"]: + waves.append((env.globals["items_per_wave64"], 64)) + if env.globals["has_wave32"]: + waves.append((env.globals["items_per_warp32"], 32)) + if not waves: + # Defensive fallback: codegen invoked without --has_wave* on ROCm. + waves.append((env.globals["items_per_warp"], env.globals["items_per_warp"] // 4)) + return waves + + +def get_max_vecs_template_configs_union( + fixed_max_vecs_per_thread: int, + use_subwarp_shuffle: bool, + use_vec_blocking: bool, +) -> list[tuple[int, int, str]]: + """ + Returns the union of (kFixedMaxVecsPerThread, kSubwarpDivisor, + kUseVecBlocking) tuples needed by every wave size in scope for this build + (driven by ``has_wave32`` / ``has_wave64``). Templates use the result to + emit explicit instantiations: one kernel symbol per tuple, which serves + every enabled wave size because ``kSubwarpDivisor`` (not warpSize) is the + template parameter — the per-arch ``kThreadGroupSize`` falls out of + ``kWarpSize / kSubwarpDivisor`` in the device pass. + """ + seen: set[tuple[int, int, str]] = set() + configs: list[tuple[int, int, str]] = [] + for items_per_warp_local, warp_size in _enabled_waves(): + for kFixedMaxVecs, kThreadGroupSize, kUseVecBlocking in get_max_vecs_template_configs( + items_per_warp_local, + fixed_max_vecs_per_thread, + use_subwarp_shuffle, + use_vec_blocking, + ): + kSubwarpDivisor = warp_size // kThreadGroupSize + key = (kFixedMaxVecs, kSubwarpDivisor, kUseVecBlocking) + if key not in seen: + seen.add(key) + configs.append(key) + return configs + + +def get_max_vecs_template_configs_union_forward( + max_forward_embedding_dim: int, + use_subwarp_shuffle: bool, + use_vec_blocking: bool, +) -> list[tuple[int, int, str]]: + """ + Like :func:`get_max_vecs_template_configs_union`, but the + ``fixed_max_vecs_per_thread`` value depends on wave size: forward kernels + use ``max_forward_embedding_dim // items_per_warp``, which differs between + wave32 and wave64 because they have different ``items_per_warp`` values. + """ + seen: set[tuple[int, int, str]] = set() + configs: list[tuple[int, int, str]] = [] + for items_per_warp_local, warp_size in _enabled_waves(): + fixed_max_vecs = max_forward_embedding_dim // items_per_warp_local + for kFixedMaxVecs, kThreadGroupSize, kUseVecBlocking in get_max_vecs_template_configs( + items_per_warp_local, + fixed_max_vecs, + use_subwarp_shuffle, + use_vec_blocking, + ): + kSubwarpDivisor = warp_size // kThreadGroupSize + key = (kFixedMaxVecs, kSubwarpDivisor, kUseVecBlocking) + if key not in seen: + seen.add(key) + configs.append(key) + return configs + + def is_valid_forward_config( nobag: bool, weighted: bool, @@ -346,6 +453,10 @@ def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str: generate_optimized_grad_sum_loop_access ) env.globals["get_max_vecs_template_configs"] = get_max_vecs_template_configs +env.globals["get_max_vecs_template_configs_union"] = get_max_vecs_template_configs_union +env.globals["get_max_vecs_template_configs_union_forward"] = ( + get_max_vecs_template_configs_union_forward +) env.globals["dispatch_optimal_kernel"] = dispatch_optimal_kernel env.globals["dispatch_non_vec_blocking_kernel"] = dispatch_non_vec_blocking_kernel env.globals["dispatch_vec_blocking_kernel"] = dispatch_vec_blocking_kernel diff --git a/fbgemm_gpu/codegen/genscript/scripts_argsparse.py b/fbgemm_gpu/codegen/genscript/scripts_argsparse.py index 171a96b926..b1650e9c3e 100644 --- a/fbgemm_gpu/codegen/genscript/scripts_argsparse.py +++ b/fbgemm_gpu/codegen/genscript/scripts_argsparse.py @@ -21,6 +21,11 @@ ) parser.add_argument("--opensource", action="store_false", dest="is_fbcode") parser.add_argument("--is_rocm", action="store_true") +# CMake-derived wave-size set for the current ROCm build. CUDA builds ignore +# these (always wave32). Both unset on a ROCm build defaults to wave64-only +# to preserve pre-port behavior. +parser.add_argument("--has_wave32", action="store_true") +parser.add_argument("--has_wave64", action="store_true") args: argparse.Namespace _: list[str] diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu index c03383aa39..016bf99abf 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu @@ -176,8 +176,8 @@ Tensor pruned_hashmap_lookup_cuda( (int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel< index_t, hash_t>), - nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSizeHost()), + dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(indices, index_t, 1, 32), @@ -238,8 +238,8 @@ Tensor pruned_array_lookup_cuda( index_t, remap_t>), nbit::div_round_up( - offsets.size(0), kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + offsets.size(0), kForwardMaxThreads / kWarpSizeHost()), + dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(indices, index_t, 1, 32), diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index 0edd97a0c6..c6be4ab862 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -72,7 +72,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no FBGEMM_LAUNCH_KERNEL( \ ({{ func_name }}), \ nbit::div_round_up(T * nbit::div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock), \ - dim3(kWarpSize, kWarpsPerBlock), \ + dim3(kWarpSizeHost(), kWarpsPerBlock), \ 0, \ at::cuda::getCurrentCUDAStream(), \ PTA_B(dev_weights, uint8_t, 1, 64), \ diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 3fe516891f..f9514c5288 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -16,6 +16,7 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/utils/cuda_prelude.cuh" #include "fbgemm_gpu/config/feature_gates.h" using Tensor = at::Tensor; @@ -959,7 +960,10 @@ class {{ autograd_func }} : TORCH_CHECK_EQ(grad_outputs.size(), 1); #ifdef USE_ROCM - constexpr int32_t BT_block_size = 64; + // BT_block_size matches the active device's warp size on ROCm (32 on + // wave32 archs like gfx1100, 64 on wave64 archs like gfx90a). Multi-arch + // wheels must read this at runtime, hence the const int (not constexpr). + const int32_t BT_block_size = kWarpSizeHost(); constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index d978faaff9..5d51fab664 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -30,26 +30,79 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; -#define DISPATCH_NON_VEC_BLOCKING_KERNEL(MAX_D, ...) \ - [&] { \ +{%- if has_wave64 %} +#define DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, ...) \ + [&] { \ {{ dispatch_non_vec_blocking_kernel( - items_per_warp, + items_per_wave64, fixed_max_vecs_per_thread["backward_indice_weights"], use_subwarp_shuffle=False, ) -}} }() +#define DISPATCH_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, ...) \ + [&] { \ + {{ + dispatch_vec_blocking_kernel( + items_per_wave64, + fixed_max_vecs_per_thread["backward_indice_weights"], + ) + -}} + }() +{%- endif %} -#define DISPATCH_VEC_BLOCKING_KERNEL(MAX_D, ...) \ - [&] { \ +{%- if has_wave32 %} +#define DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, ...) \ + [&] { \ + {{ + dispatch_non_vec_blocking_kernel( + items_per_warp32, + fixed_max_vecs_per_thread["backward_indice_weights"], + use_subwarp_shuffle=False, + ) + -}} + }() +#define DISPATCH_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, ...) \ + [&] { \ {{ dispatch_vec_blocking_kernel( - items_per_warp, + items_per_warp32, fixed_max_vecs_per_thread["backward_indice_weights"], ) -}} }() +{%- endif %} + +#define DISPATCH_NON_VEC_BLOCKING_KERNEL(MAX_D, ...) \ + [&] { \ +{%- if has_wave32 and has_wave64 %} + if (kWarpSizeHost() == 64) { \ + return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \ + } else { \ + return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \ + } \ +{%- elif has_wave64 %} + return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \ +{%- else %} + return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \ +{%- endif %} + }() + +#define DISPATCH_VEC_BLOCKING_KERNEL(MAX_D, ...) \ + [&] { \ +{%- if has_wave32 and has_wave64 %} + if (kWarpSizeHost() == 64) { \ + return DISPATCH_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \ + } else { \ + return DISPATCH_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \ + } \ +{%- elif has_wave64 %} + return DISPATCH_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \ +{%- else %} + return DISPATCH_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \ +{%- endif %} + }() {%- for vbe in ([True, False]) %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -548,8 +601,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( cache_t, index_t, kFixedMaxVecsPerThread>), - div_round_up(total_B, kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + div_round_up(total_B, kForwardMaxThreads / kWarpSizeHost()), + dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(grad_output_reshaped, grad_t, 2, 64), diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 49505a3fba..f13b7923f0 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -82,7 +82,7 @@ template < typename {{ ph_name + "_ph_t" }}, {%- endfor %} int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, + int32_t kSubwarpDivisor, bool kUseVecBlocking> __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} @@ -165,6 +165,9 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { + // kThreadGroupSize derived per-arch from the device-pass kWarpSize. + // The template's mangled name carries kSubwarpDivisor, not kThreadGroupSize. + constexpr int32_t kThreadGroupSize = kWarpSize / kSubwarpDivisor; #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = ((1L << kThreadGroupSize) - 1) << @@ -440,7 +443,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( index_type, ph_type_combo, kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking ) %} @@ -459,7 +462,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {{ ph_type_combo[ph_name].primitive_type }}, {%- endfor %} {{ kFixedMaxVecsPerThread }}, - {{ kThreadGroupSize }}, + {{ kSubwarpDivisor }}, {{ kUseVecBlocking }} > ( const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, @@ -546,7 +549,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row ); {%- endmacro %} -{%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} +{%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kSubwarpDivisor, kUseVecBlocking) %} {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} @@ -559,7 +562,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row index_type, ph_type_combo, kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking) }} {%- endfor %} @@ -575,7 +578,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {{ bulk_template_instantiations( fixed_max_vecs_per_thread["backward"], - 'kWarpSize', + '1', 'true' ) }} @@ -583,9 +586,8 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- else %} {%- macro instantiate_templates(use_subwarp_shuffle) %} -{%- for (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) - in get_max_vecs_template_configs( - items_per_warp, +{%- for (kFixedMaxVecsPerThread, kSubwarpDivisor, kUseVecBlocking) + in get_max_vecs_template_configs_union( fixed_max_vecs_per_thread["backward"], use_subwarp_shuffle, use_vec_blocking=True, @@ -594,7 +596,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {{ bulk_template_instantiations( kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking, ) }} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 92a0f9712b..f35aaf5854 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -75,7 +75,7 @@ template < typename {{ ph_name + "_ph_t"}}, {%- endfor %} int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, + int32_t kSubwarpDivisor, bool kUseVecBlocking> __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} @@ -151,6 +151,10 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { + // kThreadGroupSize derived per-arch from the device-pass kWarpSize. + // The template's mangled name carries kSubwarpDivisor, not kThreadGroupSize, + // so host and every per-arch device pass agree on the wrapper symbol. + constexpr int32_t kThreadGroupSize = kWarpSize / kSubwarpDivisor; {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; {%- else %} @@ -359,7 +363,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( index_type, ph_type_combo, kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking ) %} @@ -379,7 +383,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {{ ph_type_combo[ph_name].primitive_type }}, {%- endfor %} {{ kFixedMaxVecsPerThread }}, - {{ kThreadGroupSize }}, + {{ kSubwarpDivisor }}, {{ kUseVecBlocking }} > ( const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, @@ -457,7 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row ); {%- endmacro %} -{%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} +{%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kSubwarpDivisor, kUseVecBlocking) %} {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} @@ -470,7 +474,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row index_type, ph_type_combo, kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking ) }} @@ -487,7 +491,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {{ bulk_template_instantiations( fixed_max_vecs_per_thread["backward"], - 'kWarpSize', + '1', 'true' ) }} @@ -495,9 +499,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- else %} {%- macro instantiate_templates(use_subwarp_shuffle) %} -{%- for (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) - in get_max_vecs_template_configs( - items_per_warp, +{%- for (kFixedMaxVecsPerThread, kSubwarpDivisor, kUseVecBlocking) + in get_max_vecs_template_configs_union( fixed_max_vecs_per_thread["backward"], use_subwarp_shuffle, use_vec_blocking=True, @@ -506,7 +509,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {{ bulk_template_instantiations( kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking, ) }} @@ -558,7 +561,7 @@ template < typename cache_t, typename index_t, int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, + int32_t kSubwarpDivisor, bool kUseVecBlocking, int32_t embedding_dim, int32_t weight_decay_mode_v> @@ -684,7 +687,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd cache_type, index_type, kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking, kEmbeddingDim, kWeighDecayMode @@ -697,7 +700,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ cache_type }}, {{ index_type }}, {{ kFixedMaxVecsPerThread }}, - {{ kThreadGroupSize }}, + {{ kSubwarpDivisor }}, {{ kUseVecBlocking }}, {{ kEmbeddingDim }}, {{ kWeighDecayMode }} @@ -760,7 +763,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd ); {%- endmacro %} -{%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} +{%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kSubwarpDivisor, kUseVecBlocking) %} {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} @@ -773,7 +776,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd cache_type, index_type, kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking, kEmbeddingDim, kWeighDecayMode @@ -788,9 +791,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endmacro %} {%- macro hip_instantiate_templates(use_subwarp_shuffle) %} -{%- for (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) - in get_max_vecs_template_configs( - items_per_warp, +{%- for (kFixedMaxVecsPerThread, kSubwarpDivisor, kUseVecBlocking) + in get_max_vecs_template_configs_union( fixed_max_vecs_per_thread["backward"], use_subwarp_shuffle, use_vec_blocking=True, @@ -799,7 +801,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ hip_bulk_template_instantiations( kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking, ) }} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 0d37e15435..b343d2a1d5 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -66,7 +66,7 @@ template < typename {{ ph_name + "_ph_t" }}, {%- endfor %} int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, + int32_t kSubwarpDivisor, bool kUseVecBlocking> __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} @@ -159,7 +159,7 @@ template < typename {{ ph_name + "_ph_t" }}, {%- endfor %} int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, + int32_t kSubwarpDivisor, bool kUseVecBlocking> __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} @@ -244,7 +244,7 @@ template < typename cache_t, typename index_t, int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, + int32_t kSubwarpDivisor, bool kUseVecBlocking, int32_t embedding_dim, int32_t weight_decay_mode_v> @@ -376,15 +376,17 @@ using namespace embedding_ops; {%- if is_experimental_optimizer %} /* - For the experimental optimizers, kThreadGroupSize, kFixedMaxVecsPerThread, - and kUseVecBlocking are fixed to kWarpSize, {{ fixed_max_vecs_per_thread["backward"] }}, - and true. + For the experimental optimizers, kSubwarpDivisor, kFixedMaxVecsPerThread, + and kUseVecBlocking are fixed to 1 (full warp), {{ fixed_max_vecs_per_thread["backward"] }}, + and true. kThreadGroupSize falls out of kWarpSizeHost() on host (or kWarpSize + in the kernel body) and equals the full per-arch warp size. */ #define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \ [&] { \ const int max_vecs_per_thread = \ (max_D + {{ items_per_warp }} - 1) / {{ items_per_warp }}; \ - constexpr int kThreadGroupSize = kWarpSize; \ + constexpr int kSubwarpDivisor = 1; \ + const int kThreadGroupSize = kWarpSizeHost(); \ constexpr int kFixedMaxVecsPerThread = \ {{ fixed_max_vecs_per_thread["backward"] }}; \ constexpr bool kUseVecBlocking = true; \ @@ -397,32 +399,76 @@ using namespace embedding_ops; For the non-experimental optimizers, we determine the kernel template instantiation that is best optimized for MAX_D and invoke it. + On ROCm multi-arch builds (both has_wave32 and has_wave64), the wave32 + and wave64 bracket tables are emitted as separate _WAVE{32,64} macros and + the unified macro picks at runtime via kWarpSizeHost(). Single-wave builds + emit only the matching table. + Please see dispatch_optimal_kernel in codegen/embedding_common_code_generator.py for more details */ +{%- if has_wave64 %} #ifdef FBGEMM_USE_SUBWARP_SHUFFLE -#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \ +#define DISPATCH_OPTIMAL_KERNEL_WAVE64(MAX_D, ...) \ [&] { \ {{ dispatch_optimal_kernel( - items_per_warp, + items_per_wave64, fixed_max_vecs_per_thread["backward"], use_subwarp_shuffle=True) -}} }() - #else -#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \ +#define DISPATCH_OPTIMAL_KERNEL_WAVE64(MAX_D, ...) \ [&] { \ {{ dispatch_optimal_kernel( - items_per_warp, + items_per_wave64, fixed_max_vecs_per_thread["backward"], use_subwarp_shuffle=False) -}} }() +#endif +{%- endif %} +{%- if has_wave32 %} +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE +#define DISPATCH_OPTIMAL_KERNEL_WAVE32(MAX_D, ...) \ + [&] { \ + {{ + dispatch_optimal_kernel( + items_per_warp32, + fixed_max_vecs_per_thread["backward"], + use_subwarp_shuffle=True) + -}} + }() +#else +#define DISPATCH_OPTIMAL_KERNEL_WAVE32(MAX_D, ...) \ + [&] { \ + {{ + dispatch_optimal_kernel( + items_per_warp32, + fixed_max_vecs_per_thread["backward"], + use_subwarp_shuffle=False) + -}} + }() #endif +{%- endif %} + +#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \ + [&] { \ +{%- if has_wave32 and has_wave64 %} + if (kWarpSizeHost() == 64) { \ + return DISPATCH_OPTIMAL_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \ + } else { \ + return DISPATCH_OPTIMAL_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \ + } \ +{%- elif has_wave64 %} + return DISPATCH_OPTIMAL_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \ +{%- else %} + return DISPATCH_OPTIMAL_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \ +{%- endif %} + }() {%- endif %} @@ -1073,7 +1119,7 @@ Tensor {{ embedding_cuda_op }}( {{ ph_name + "_ph_t" }}, {%- endfor %} kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking>; // Compute shared memory size for cta_per_row @@ -1210,7 +1256,7 @@ Tensor {{ embedding_cuda_op }}( {{ ph_name + "_ph_t" }}, {%- endfor %} kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking>; // Compute shared memory size for warp_per_row @@ -1280,7 +1326,7 @@ Tensor {{ embedding_cuda_op }}( cache_t, index_t, kFixedMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, kUseVecBlocking, {{ kDimSize }}, {{ kWeightDecayMode }}>; diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index ce71aea376..642eae9b0c 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -552,7 +552,7 @@ template < {%- if not nobag %} size_t kMaxVecsPerThread, {%- endif %} - size_t kThreadGroupSize> + size_t kSubwarpDivisor> __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel( @@ -619,6 +619,9 @@ batch_index_select_dim0_codegen_forward_kernel( // If 2D, shape is [B][total_D] pta::PackedTensorAccessor64 output ) { +// kThreadGroupSize derived per-arch from the device-pass kWarpSize. +// The template's mangled name carries kSubwarpDivisor, not kThreadGroupSize. +constexpr size_t kThreadGroupSize = kWarpSize / kSubwarpDivisor; // shfl_sync_mask is implicitly used by SHFL_SYNC #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = @@ -844,7 +847,7 @@ batch_index_select_dim0_codegen_forward_kernel( index_type, use_cache, kMaxVecsPerThread, - kThreadGroupSize) + kSubwarpDivisor) %} template __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} @@ -863,7 +866,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- if not nobag %} {{ kMaxVecsPerThread }}, {%- endif %} - {{ kThreadGroupSize }} + {{ kSubwarpDivisor }} > ( const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, {%- if not dense %} @@ -916,7 +919,7 @@ batch_index_select_dim0_codegen_forward_kernel pta::PackedTensorAccessor64<{{ output_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> output); {%- endmacro %} -{%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %} +{%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kSubwarpDivisor) %} {%- set max_vecs_per_thread = kMaxVecsPerThread %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} @@ -929,7 +932,7 @@ batch_index_select_dim0_codegen_forward_kernel index_type, use_cache, max_vecs_per_thread, - kThreadGroupSize) + kSubwarpDivisor) }} {%- endfor %} {%- endfor %} @@ -945,10 +948,9 @@ batch_index_select_dim0_codegen_forward_kernel legacy_max_embedding_dim if has_experimental else max_embedding_dim %} {%- for use_cache in (["true", "false"] if not dense else ["NULL"]) %} -{%- for (kMaxVecsPerThread, kThreadGroupSize, use_blocking) - in get_max_vecs_template_configs( - items_per_warp, - fixed_max_vecs_per_thread=max_forward_embedding_dim // items_per_warp, +{%- for (kMaxVecsPerThread, kSubwarpDivisor, use_blocking) + in get_max_vecs_template_configs_union_forward( + max_forward_embedding_dim, use_subwarp_shuffle=use_subwarp_shuffle, use_vec_blocking=False, ) @@ -959,7 +961,7 @@ batch_index_select_dim0_codegen_forward_kernel bulk_template_instantiations( use_cache, kMaxVecsPerThread, - kThreadGroupSize + kSubwarpDivisor ) }} {%- endif %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index 7318e45d1a..932193b2d3 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -153,7 +153,7 @@ template < {%- if not nobag %} size_t kMaxVecsPerThread, {%- endif %} - size_t kThreadGroupSize = kWarpSize + size_t kSubwarpDivisor > __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} @@ -232,32 +232,73 @@ batch_index_select_dim0_codegen_forward_kernel( max_forward_embedding_dim ) %} - {%- set fixed_max_vecs_per_thread = max_forward_embedding_dim // items_per_warp%} + {%- set fixed_max_vecs_per_thread_wave32 = max_forward_embedding_dim // items_per_warp32 %} + {%- set fixed_max_vecs_per_thread_wave64 = max_forward_embedding_dim // items_per_wave64 %} +{%- if has_wave64 %} #ifdef FBGEMM_USE_SUBWARP_SHUFFLE -#define {{ dispatch_macro_name }}(MAX_D, ...) \ - [&] { \ +#define {{ dispatch_macro_name }}_WAVE64(MAX_D, ...) \ + [&] { \ {{ dispatch_non_vec_blocking_kernel( - items_per_warp, - fixed_max_vecs_per_thread, + items_per_wave64, + fixed_max_vecs_per_thread_wave64, use_subwarp_shuffle=True) -}} - return; \ + return; \ }() - #else -#define {{ dispatch_macro_name }}(MAX_D, ...) \ - [&] { \ +#define {{ dispatch_macro_name }}_WAVE64(MAX_D, ...) \ + [&] { \ {{ dispatch_non_vec_blocking_kernel( - items_per_warp, - fixed_max_vecs_per_thread, + items_per_wave64, + fixed_max_vecs_per_thread_wave64, use_subwarp_shuffle=False) -}} - return; \ + return; \ }() - #endif +{%- endif %} +{%- if has_wave32 %} +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE +#define {{ dispatch_macro_name }}_WAVE32(MAX_D, ...) \ + [&] { \ + {{ + dispatch_non_vec_blocking_kernel( + items_per_warp32, + fixed_max_vecs_per_thread_wave32, + use_subwarp_shuffle=True) + -}} + return; \ + }() +#else +#define {{ dispatch_macro_name }}_WAVE32(MAX_D, ...) \ + [&] { \ + {{ + dispatch_non_vec_blocking_kernel( + items_per_warp32, + fixed_max_vecs_per_thread_wave32, + use_subwarp_shuffle=False) + -}} + return; \ + }() +#endif +{%- endif %} +#define {{ dispatch_macro_name }}(MAX_D, ...) \ + [&] { \ +{%- if has_wave32 and has_wave64 %} + if (kWarpSizeHost() == 64) { \ + {{ dispatch_macro_name }}_WAVE64(MAX_D, __VA_ARGS__); \ + } else { \ + {{ dispatch_macro_name }}_WAVE32(MAX_D, __VA_ARGS__); \ + } \ +{%- elif has_wave64 %} + {{ dispatch_macro_name }}_WAVE64(MAX_D, __VA_ARGS__); \ +{%- else %} + {{ dispatch_macro_name }}_WAVE32(MAX_D, __VA_ARGS__); \ +{%- endif %} + return; \ + }() {% endmacro %} {#- @@ -632,8 +673,8 @@ batch_index_select_dim0_codegen_forward_cuda( output_t, index_t, kEmbeddingSize / 4>), - div_round_up(total_B, kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + div_round_up(total_B, kForwardMaxThreads / kWarpSizeHost()), + dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(dev_weights, emb_t, 1, 64), @@ -677,13 +718,13 @@ batch_index_select_dim0_codegen_forward_cuda( FBGEMM_LAUNCH_KERNEL( ({{ nobag_kernel }} {%- if dense or is_index_select %} - + {%- else %} - + {%- endif %} ), - div_round_up(total_B, kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + div_round_up(total_B, kForwardMaxThreads / kWarpSizeHost()), + dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(dev_weights, emb_t, 1, 64), @@ -752,7 +793,7 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} index_t, kMaxVecsPerThread, - kThreadGroupSize>), + kSubwarpDivisor>), grid, dim3(kThreadGroupSize, kForwardMaxThreads / kThreadGroupSize), 0, diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_kernel_template.cu b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_kernel_template.cu index 527abe6b92..7e6281c63b 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_kernel_template.cu @@ -9,6 +9,14 @@ // clang-format off #include "gen_embedding_optimizer_{{ optimizer }}_split_device_kernel.cuh" +// Template parameter kSubwarpDivisor: kThreadGroupSize is derived per-arch +// inside the kernel body as (kWarpSize / kSubwarpDivisor). This keeps the +// kernel's mangled name free of warpSize, so the host pass and every +// per-arch device pass agree on the symbol. The actual kThreadGroupSize +// used at runtime is set by the host launcher to (kWarpSizeHost() / +// kSubwarpDivisor), which matches the device-side value because kWarpSize +// (device) and kWarpSizeHost() (host) report the same warp size for the +// active arch. kSubwarpDivisor = 1 for the full-warp case. template < typename emb_t, typename cache_t, @@ -16,7 +24,7 @@ template < typename {{ ph_name + "_ph_t"}}, {%- endfor %} size_t kMaxVecsPerThread, - int32_t kThreadGroupSize = kWarpSize, + int32_t kSubwarpDivisor, int32_t VEC_WIDTH > __global__ __launch_bounds__(kMaxThreads) @@ -37,6 +45,7 @@ void split_{{ optimizer }}_update_kernel( at::PhiloxCudaState stochastic_rounding_philox_args, {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} ) { + constexpr int32_t kThreadGroupSize = kWarpSize / kSubwarpDivisor; const auto run_id = blockIdx.x * blockDim.y + threadIdx.y; if (run_id >= grad_dev_indices.size(0)) { return; @@ -109,16 +118,22 @@ void split_{{ optimizer }}_update_kernel( {%- for cache_type in ['float', 'at::Half'] %} {%- for ph_type_combo in args.placeholder_type_combos %} +{#- Emit instantiations for the union of (kMaxVecsPerThread, kSubwarpDivisor) + needed by every wave size in scope. Wave32 needs more brackets than wave64 + to cover the same max_D range (smaller kThreadGroupSize per arch), so the + wave32 set is a superset of the wave64 set; we iterate it whenever wave32 + is enabled, and just iterate the wave64 set otherwise. -#} +{%- set _items_per_warp_eff = items_per_warp32 if has_wave32 else items_per_wave64 %} {%- set tuples = [] %} -{%- for kMaxElemPerThread in range(1, legacy_max_embedding_dim // (items_per_warp // 4) + 1) %} +{%- for kMaxElemPerThread in range(1, legacy_max_embedding_dim // (_items_per_warp_eff // 4) + 1) %} {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} {%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max if not nobag else "NULL" %} {%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %} - {%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")" if use_subwarp else "kWarpSize")) %} + {%- set temp = tuples.append((t0, t1 if use_subwarp else 1)) %} {%- endif %} {%- endfor %} -{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} +{%- for (kMaxVecsPerThread, kSubwarpDivisor) in tuples | unique %} template __global__ __launch_bounds__(kMaxThreads) void split_{{ optimizer }}_update_kernel < {{ emb_type }}, @@ -127,7 +142,7 @@ void split_{{ optimizer }}_update_kernel {{ ph_type_combo[ph_name] }}, {%- endfor %} {{ kMaxVecsPerThread }}, - {{ kThreadGroupSize }}, + {{ kSubwarpDivisor }}, 4 // VEC_WIDTH >( pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, @@ -151,7 +166,7 @@ void split_{{ optimizer }}_update_kernel replace("cache_t", cache_type) }}); -{%- endfor %} // for (kMaxVecsPerThread, kThreadGroupSize) +{%- endfor %} // for (kMaxVecsPerThread, kSubwarpDivisor) {%- endfor %} // for ph_type_combo {%- endfor %} // for cache_type {%- endfor %} // for emb_type diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu index 8a3a81b79e..b9734c25dd 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu @@ -21,7 +21,7 @@ template < typename {{ ph_name + "_ph_t" }}, {%- endfor %} size_t kMaxVecsPerThread, - int32_t kThreadGroupSize = kWarpSize, + int32_t kSubwarpDivisor, int32_t VEC_WIDTH > __global__ __launch_bounds__(kMaxThreads) void @@ -162,17 +162,21 @@ void split_embedding_{{ optimizer }}_update( at::check_generator(gen) ->philox_cuda_state(4); } - {%- for kMaxElemPerThread in range(1, legacy_max_embedding_dim // (items_per_warp // 4) + 1) %} + {%- macro emit_optimizer_bracket_chain(items_per_warp_local) %} + {%- for kMaxElemPerThread in range(1, legacy_max_embedding_dim // (items_per_warp_local // 4) + 1) %} {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} - if (max_D <= {{ items_per_warp // 4 * kMaxElemPerThread }}) { + if (max_D <= {{ items_per_warp_local // 4 * kMaxElemPerThread }}) { // hipcc can't use max in constexpr constexpr int kMaxVecsPerThread = {{ kMaxElemPerThread }} / 4 >= 1 ? {{ kMaxElemPerThread }} / 4 : 1; - // If max_D is small, use fewer number of threads than kWarpSize. + // kSubwarpDivisor is the literal integer divider of warpSize; + // the kernel template's mangled name carries this, not + // warpSize itself, so host and per-arch device passes agree. #ifdef FBGEMM_USE_SUBWARP_SHUFFLE - constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); + constexpr int kSubwarpDivisor = std::max(4 / {{ kMaxElemPerThread }}, 1); #else - constexpr int kThreadGroupSize = kWarpSize; + constexpr int kSubwarpDivisor = 1; #endif + const int kThreadGroupSize = kWarpSizeHost() / kSubwarpDivisor; DISPATCH_PLACEHOLDER_TYPES( {%- for ph_name in args.placeholder_tensor_names %} @@ -188,7 +192,7 @@ void split_embedding_{{ optimizer }}_update( {{ ph_name + "_ph_t" }}, {%- endfor %} kMaxVecsPerThread, - kThreadGroupSize, + kSubwarpDivisor, 4>), div_round_up(grad_dev_indices.numel(), kMaxThreads / kThreadGroupSize), dim3(kThreadGroupSize, kMaxThreads / kThreadGroupSize, 1), @@ -215,6 +219,18 @@ void split_embedding_{{ optimizer }}_update( } {%- endif %} {%- endfor %} + {%- endmacro %} + {%- if has_wave32 and has_wave64 %} + if (kWarpSizeHost() == 64) { + {{ emit_optimizer_bracket_chain(items_per_wave64) }} + } else { + {{ emit_optimizer_bracket_chain(items_per_warp32) }} + } + {%- elif has_wave64 %} + {{ emit_optimizer_bracket_chain(items_per_wave64) }} + {%- else %} + {{ emit_optimizer_bracket_chain(items_per_warp32) }} + {%- endif %} } // DISPATCH_EMB_CACHE_TYPES ); } diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu index 4bbc079937..f60610077b 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu @@ -195,8 +195,8 @@ void _bounds_check_indices_cuda_v1( : bounds_check_indices_kernel_v1); FBGEMM_LAUNCH_DSA_KERNEL( bounds_check_kernel, - div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize), - dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), + div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSizeHost()), + dim3(fbgemm_gpu::kWarpSizeHost(), kNumThreads / fbgemm_gpu::kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(rows_per_table, int64_t, 1, 32), diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu index e2fd64ba8f..1857a7a04b 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu @@ -237,7 +237,7 @@ void _bounds_check_indices_cuda_v2( constexpr size_t kNumThreads = 1024; auto grid_dim = - min(div_round_up(total_B, kNumThreads / fbgemm_gpu::kWarpSize), + min(div_round_up(total_B, kNumThreads / fbgemm_gpu::kWarpSizeHost()), get_max_thread_blocks_()); if (prefetch_pipeline) { // Limit the grid size to PREFETCH_KERNEL_MAX_BLOCKS if running this kernel @@ -259,7 +259,7 @@ void _bounds_check_indices_cuda_v2( bounds_check_kernel, \ grid_dim, \ dim3( \ - fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), \ + fbgemm_gpu::kWarpSizeHost(), kNumThreads / fbgemm_gpu::kWarpSizeHost()), \ 0, \ at::cuda::getCurrentCUDAStream(), \ PTA_B(rows_per_table, int64_t, 1, 32), \ diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index d5d9dd8c27..95b749c2f3 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -3750,6 +3750,20 @@ def _apply_cache_state( self.timestep = 1 self.timesteps_prefetched = [] + # Cache associativity (ways per set) must equal the device warp size: + # the cache kernels use one warp to cooperatively scan the ways of a + # single set. CDNA (gfx9xx) is warp 64, RDNA (gfx11xx) is warp 32, and + # NVIDIA is warp 32. Query the actual device rather than assuming, so a + # single multi-arch ROCm build is correct on both wave sizes. The + # module-level DEFAULT_ASSOC is only a fallback for the CPU/no-device + # path. + if self.use_cpu: + self.cache_assoc = DEFAULT_ASSOC + else: + self.cache_assoc = torch.cuda.get_device_properties( + self.current_device + ).warp_size + self.max_prefetch_depth = MAX_PREFETCH_DEPTH self.lxu_cache_locations_list = [] self.lxu_cache_locations_empty = torch.empty( @@ -3832,24 +3846,24 @@ def _apply_cache_state( assert free_memory > 0 cache_sets = ( int(cache_state.total_cache_hash_size * cache_load_factor) - + DEFAULT_ASSOC + + self.cache_assoc - 1 - ) // DEFAULT_ASSOC + ) // self.cache_assoc cache_sets = 1 if cache_sets == 0 else cache_sets - cache_size = cache_sets * DEFAULT_ASSOC * element_size * self.max_D_cache + cache_size = cache_sets * self.cache_assoc * element_size * self.max_D_cache if cache_size > free_memory: cache_sets = ( int(1.0 * free_memory / self.max_D_cache / element_size) - + DEFAULT_ASSOC + + self.cache_assoc - 1 - ) // DEFAULT_ASSOC + ) // self.cache_assoc cache_load_factor = ( - 1.0 * cache_sets * DEFAULT_ASSOC / int(cache_state.total_cache_hash_size) + 1.0 * cache_sets * self.cache_assoc / int(cache_state.total_cache_hash_size) ) assert cache_sets > 0 if cache_algorithm == CacheAlgorithm.LFU: assert cache_sets < 2**24 - 1 - cache_size = cache_sets * DEFAULT_ASSOC * element_size * self.max_D_cache + cache_size = cache_sets * self.cache_assoc * element_size * self.max_D_cache self.log( f"Using on-device cache with admission algorithm " f"{cache_algorithm}, {cache_sets} sets, " @@ -3882,14 +3896,14 @@ def _apply_cache_state( self.register_buffer( "lxu_cache_state", torch.zeros( - cache_sets, DEFAULT_ASSOC, device=self.current_device, dtype=torch.int64 + cache_sets, self.cache_assoc, device=self.current_device, dtype=torch.int64 ).fill_(-1), ) # Cache itself, not auxiliary size self.register_buffer( "lxu_cache_weights", torch.zeros( - cache_sets * DEFAULT_ASSOC, + cache_sets * self.cache_assoc, self.max_D_cache, device=self.current_device, dtype=dtype, @@ -3903,7 +3917,7 @@ def _apply_cache_state( size=( (self.total_cache_hash_size + 1,) if cache_algorithm == CacheAlgorithm.LFU - else (cache_sets, DEFAULT_ASSOC) + else (cache_sets, self.cache_assoc) ), device=self.current_device, dtype=torch.int64, @@ -4047,7 +4061,7 @@ def _init_uvm_cache_counter(self, cache_sets: int, persistent: bool) -> None: "lxu_cache_locking_counter", torch.zeros( cache_sets, - DEFAULT_ASSOC, + self.cache_assoc, device=self.current_device, dtype=torch.int32, ), diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index f9a42cfc49..1591e7a11e 100755 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -60,16 +60,16 @@ namespace fbgemm_gpu { #define DIV_ROUND_UP(a, b) (a + b - 1) / b -// Warp size +// Warp size — device-pass constant // // Device code: per-arch constexpr. __GFX9__ is defined by HIP-Clang during // device compilation for gfx9xx targets (warpSize 64); it is undefined for // warpSize 32 targets. HIP-Clang runs the device-code backend once per // --offload-arch, so the same source produces correct per-arch device code. // -// Host code on ROCm: warpSize is only known at runtime. This 64 is a -// stop-gap constexpr. Host-side launches that must size thread blocks for -// the active device should call at::cuda::warp_size() directly. +// Host code on ROCm: warpSize is only known at runtime. The 64 below is a +// stop-gap so __global__ kernel bodies parse on the host pass; never rely +// on it in host-side computation. Use kWarpSizeHost (below) instead. #if !defined(USE_ROCM) static constexpr int32_t kWarpSize = 32; #elif defined(__GFX9__) @@ -80,6 +80,32 @@ static constexpr int32_t kWarpSize = 32; static constexpr int32_t kWarpSize = 64; #endif +// Host-side warp size +// +// Use this in host code anywhere kWarpSize would be wrong on a ROCm +// multi-arch build (block-dim computations, grid sizing, etc.). It is: +// * CUDA: a constexpr function returning 32. Usable as a template +// argument and in static_assert. +// * ROCm: an inline function returning at::cuda::warp_size() — a runtime +// query of the active device. NOT constexpr; do not use as a template +// argument. Cheap (a cached device-properties array lookup). +// +// Always invoke with parentheses (kWarpSizeHost()) so the same call shape +// works under both back-ends and inside namespace-qualified call sites. +// +// Do not use in __device__ code: on ROCm the at::cuda::warp_size() call +// is not callable from device. Use kWarpSize there, which is per-arch +// correct via the device pass. +#if defined(USE_ROCM) +inline int32_t kWarpSizeHost() { + return at::cuda::warp_size(); +} +#else +inline constexpr int32_t kWarpSizeHost() { + return 32; +} +#endif + // Max thread num in one thread block static constexpr int32_t kMaxThreads = 1024; @@ -162,11 +188,22 @@ DEVICE_INLINE uint32_t ballot_sync( } /// Sums a register value across all warp threads +// +// The rocm::wave_reduce DPP path is only correct on the CDNA archs whose +// hand-tuned assembly / row_bcast DPP controls it relies on (gfx942, gfx90a, +// gfx950); on other archs dpp_reduction compiles to a no-op and wave_reduce +// would return an un-reduced lane value. RDNA (wave 32) also lacks the +// row_bcast DPP controls. Everywhere except those CDNA archs, use the portable +// shfl_xor butterfly, which is correct for any ReduceWidth and arch (and is the +// CUDA path). __shfl_xor on ROCm adjusts internally for the device warp size. +// The arch macros are device-pass-only, so per-arch device code in a multi-arch +// fat binary selects the right path automatically. template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#ifdef USE_ROCM +#if defined(USE_ROCM) && \ + (defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__)) return rocm::wave_reduce< rocm::reduce_op::sum, // Sum reduction T, // Data type diff --git a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu index 90ca71d03c..75c383f5eb 100644 --- a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu +++ b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu @@ -237,7 +237,7 @@ void embedding_inplace_update_cuda( } TORCH_CHECK_EQ(N, update_table_idx.numel()); - const int32_t warpsPerBlock = kMaxThreads / kWarpSize; + const int32_t warpsPerBlock = kMaxThreads / kWarpSizeHost(); auto lxu_cache_weights_value = lxu_cache_weights.value_or( at::empty({0, 0}, dev_weights.options().dtype(at::kByte))); @@ -250,7 +250,7 @@ void embedding_inplace_update_cuda( FBGEMM_LAUNCH_KERNEL( (embedding_inplace_update_kernel_1), nbit::div_round_up(N, warpsPerBlock), // number of blocks needed - dim3(kWarpSize, warpsPerBlock), // shape of each block + dim3(kWarpSizeHost(), warpsPerBlock), // shape of each block 0, at::cuda::getCurrentCUDAStream(), @@ -305,7 +305,7 @@ void embedding_inplace_update_single_placement_cuda( } TORCH_CHECK_EQ(N, update_table_idx.numel()); - const int32_t warpsPerBlock = kMaxThreads / kWarpSize; + const int32_t warpsPerBlock = kMaxThreads / kWarpSizeHost(); auto lxu_cache_weights_value = lxu_cache_weights.value_or( at::empty({0, 0}, dev_weights.options().dtype(at::kByte))); @@ -318,7 +318,7 @@ void embedding_inplace_update_single_placement_cuda( FBGEMM_LAUNCH_KERNEL( (embedding_inplace_update_kernel_2), nbit::div_round_up(N, warpsPerBlock), // number of blocks needed - dim3(kWarpSize, warpsPerBlock), // shape of each block + dim3(kWarpSizeHost(), warpsPerBlock), // shape of each block 0, at::cuda::getCurrentCUDAStream(), diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine.cu b/fbgemm_gpu/src/input_combine_ops/input_combine.cu index 15505b02d0..439a4425d1 100644 --- a/fbgemm_gpu/src/input_combine_ops/input_combine.cu +++ b/fbgemm_gpu/src/input_combine_ops/input_combine.cu @@ -141,16 +141,16 @@ std::tuple tbe_input_combine_with_length_cuda( #else constexpr uint32_t VEC_WIDTH = 8; #endif - constexpr uint32_t NUM_WARPS_PER_BLOCK = kMaxThreads / kWarpSize; + const uint32_t NUM_WARPS_PER_BLOCK = kMaxThreads / kWarpSizeHost(); const auto num_warps_per_list = - div_round_up(max_list_size, kWarpSize * VEC_WIDTH); + div_round_up(max_list_size, kWarpSizeHost() * VEC_WIDTH); const auto num_blocks = div_round_up(num_warps_per_list * num_lists, NUM_WARPS_PER_BLOCK); FBGEMM_LAUNCH_KERNEL( (tbe_input_combine_with_length_kernel), num_blocks, - dim3(kWarpSize, NUM_WARPS_PER_BLOCK), + dim3(kWarpSizeHost(), NUM_WARPS_PER_BLOCK), 0, at::cuda::getCurrentCUDAStream(), combined_indices.data_ptr(), diff --git a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu index 99081ee539..8378e75901 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu @@ -114,7 +114,7 @@ std::tuple batched_dense_vec_jagged_2d_mul_backward( "dense_vec_jagged_2d_bmm_backward_kernel_2", [&] { int block_dim_x = std::min( - div_round_up(max_L, kWarpSize) * kWarpSize, kMaxThreads); + div_round_up(max_L, kWarpSizeHost()) * kWarpSizeHost(), kMaxThreads); int block_dim_y = kMaxThreads / block_dim_x; FBGEMM_LAUNCH_KERNEL( @@ -129,7 +129,7 @@ std::tuple batched_dense_vec_jagged_2d_mul_backward( PTA_B(v_grad, scalar_t, 2, 32)); block_dim_x = std::min( - div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads); + div_round_up(D, kWarpSizeHost()) * kWarpSizeHost(), kMaxThreads); block_dim_y = kMaxThreads / block_dim_x; FBGEMM_LAUNCH_KERNEL( diff --git a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu index 1f481e88d6..761f73d656 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu @@ -71,7 +71,7 @@ Tensor batched_dense_vec_jagged_2d_mul_forward( if (B > 0 && D > 0) { const int block_dim_x = - std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads); + std::min(div_round_up(D, kWarpSizeHost()) * kWarpSizeHost(), kMaxThreads); const int block_dim_y = kMaxThreads / block_dim_x; AT_DISPATCH_INDEX_TYPES( diff --git a/fbgemm_gpu/src/jagged_tensor_ops/common.cuh b/fbgemm_gpu/src/jagged_tensor_ops/common.cuh index 33938bec62..16d8694da8 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/common.cuh +++ b/fbgemm_gpu/src/jagged_tensor_ops/common.cuh @@ -223,8 +223,8 @@ inline std::tuple> check_shape_and_partition_( dense_tensor.numel() / (outer_dense_size * inner_dense_size); const int threads_x = - inner_dense_size >= kWarpSize / 2 ? kWarpSize : inner_dense_size; - const int threads_y = kMaxThreads / kWarpSize; + inner_dense_size >= kWarpSizeHost() / 2 ? kWarpSizeHost() : inner_dense_size; + const int threads_y = kMaxThreads / kWarpSizeHost(); const dim3 blocks( div_round_up(outer_dense_size * jagged_folded_size, threads_y)); diff --git a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu index e1b65068d4..c0b2d30daf 100644 --- a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu +++ b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu @@ -138,11 +138,11 @@ Tensor recat_embedding_grad_output_mixed_D_batch_cuda( const auto dim_sum = grad_output.size(1); const dim3 threads( - fbgemm_gpu::kWarpSize, fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize); + fbgemm_gpu::kWarpSizeHost(), fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSizeHost()); const dim3 blocks( fbgemm_gpu::div_round_up( (B_local * dim_num), - fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize)); + fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSizeHost())); FBGEMM_DISPATCH_FLOAT_AND_HALF( grad_output.scalar_type(), "recat_embedding_gradients", [&] { diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu index c2a428f829..07d7af3f9c 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu @@ -272,9 +272,9 @@ std::vector permute_multi_embedding_function_gpu( // blocks. The grid z dimension is also used by batch_size in case it's // greater than 65535. const int32_t warp_per_block = - fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize; + fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSizeHost(); const int32_t max_grid_dim = 32768; // The CUDA maximum is 65535, not 1<(batch_size), max_grid_dim), diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu index 6373ca0f55..92ff6c6541 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu @@ -106,7 +106,7 @@ Tensor permute_pooled_embs_gpu_impl( // We are launching ( div_round_up(T, warp_per_block), B ) blocks. // The grid z dimension is also used by B in case it's greater than 65535. const int32_t warp_per_block = - fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize; + fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSizeHost(); const int32_t max_grid_dim_y = 32768; // The CUDA maximum is 65535, not a power of 2. const dim3 threads(fbgemm_gpu::kMaxThreads); diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu index 863c8234f0..14e0f32265 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu @@ -106,7 +106,7 @@ Tensor permute_pooled_embs_split_gpu_impl( // We are launching ( div_round_up(T, warp_per_block), B ) blocks. // The grid z dimension is also used by B in case it's greater than 65535. const int32_t warp_per_block = - fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize; + fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSizeHost(); const int32_t max_grid_dim_y = 32768; // The CUDA maximum is 65535, not a power of 2. const dim3 threads(fbgemm_gpu::kMaxThreads); diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu index b1ad10b8e1..d9c82068a6 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu @@ -599,7 +599,7 @@ DLL_PUBLIC at::Tensor _fused8bitrowwise_to_float_mixed_dim_gpu( return output; } constexpr int threads_per_block = 256; - const dim3 blockDim(kWarpSize, threads_per_block / kWarpSize); + const dim3 blockDim(kWarpSizeHost(), threads_per_block / kWarpSizeHost()); const dim3 gridDim( cuda_calc_xblock_count(num_tables * batch_size, blockDim.y)); FBGEMM_DISPATCH_FLOAT_AND_HALF( diff --git a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu index 10d3c7b7e0..067b935799 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu @@ -409,7 +409,7 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel( // Uses ballot_sync + popc to count preceding elements with same bucket, // which preserves the original ordering within each bucket. // All threads execute the same instructions (no branch divergence). -// Note: my_size is limited to kWarpSize (32) because we use warp-level ballot +// Note: my_size is limited to kWarpSizeHost() (32) because we use warp-level ballot // operations and store per-bucket counts in registers. template @@ -899,7 +899,7 @@ _block_bucketize_sparse_features_cuda( block_bucketize_pos_concat.device(), true); } static_assert(kMaxThreads % kWarpSize == 0); - dim3 block_dims(kWarpSize, kMaxThreads / kWarpSize); + dim3 block_dims(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()); dim3 grid_dims(cuda_calc_xblock_count(lengths_size, block_dims.y)); const auto smem_adjust_threshold = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; @@ -1171,8 +1171,8 @@ DLL_PUBLIC Tensor populate_bucketized_permute_cuda( const static bool warp_kernel_enabled = config::is_feature_enabled( config::FeatureGateName::BUCKETIZED_PERMUTE_WARP_KERNEL); - if (my_size <= kWarpSize && warp_kernel_enabled) { - const auto warps_per_block = kMaxThreads / kWarpSize; + if (my_size <= kWarpSizeHost() && warp_kernel_enabled) { + const auto warps_per_block = kMaxThreads / kWarpSizeHost(); const auto num_blocks = cuda_calc_xblock_count(lengths_size, warps_per_block); FBGEMM_LAUNCH_KERNEL( diff --git a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features_2d_weights.cu b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features_2d_weights.cu index f0b65332aa..e1ab60e50e 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features_2d_weights.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features_2d_weights.cu @@ -668,7 +668,7 @@ _block_bucketize_sparse_features_2d_weights_cuda( block_bucketize_pos_concat.device(), true); } static_assert(kMaxThreads % kWarpSize == 0); - dim3 block_dims(kWarpSize, kMaxThreads / kWarpSize); + dim3 block_dims(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()); dim3 grid_dims(cuda_calc_xblock_count(lengths_size, block_dims.y)); const auto smem_adjust_threshold = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; diff --git a/fbgemm_gpu/src/sparse_ops/sparse_compute_frequency_sequence.cu b/fbgemm_gpu/src/sparse_ops/sparse_compute_frequency_sequence.cu index 6b51fee0ee..34b5b0ba23 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_compute_frequency_sequence.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_compute_frequency_sequence.cu @@ -39,8 +39,8 @@ DLL_PUBLIC void compute_frequency_sequence( input.scalar_type(), "compute_frequency_sequence_kernel_1", [&] { FBGEMM_LAUNCH_KERNEL( (compute_frequency_sequence_kernel), - cuda_calc_xblock_count(input.numel(), kWarpSize), - kWarpSize, + cuda_calc_xblock_count(input.numel(), kWarpSizeHost()), + kWarpSizeHost(), 0, at::cuda::getCurrentCUDAStream(), input.data_ptr(), diff --git a/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu b/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu index d0e59eedd7..fce5231bd6 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_expand_into_jagged_permute.cu @@ -53,8 +53,8 @@ DLL_PUBLIC Tensor expand_into_jagged_permute_cuda( Tensor output_permute = at::empty({output_size}, permute.options()); // number of table per block - constexpr int32_t T_blocks = kMaxThreads / kWarpSize; - dim3 threads(kWarpSize, T_blocks); + const int32_t T_blocks = kMaxThreads / kWarpSizeHost(); + dim3 threads(kWarpSizeHost(), T_blocks); const auto blocks = cuda_calc_xblock_count(permute_size, T_blocks); AT_DISPATCH_INDEX_TYPES( permute.scalar_type(), "expand_into_jagged_permute_kernel", [&] { diff --git a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu index f86ed1d614..8a6adefa8f 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu @@ -390,7 +390,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu( if (B == 1) { // for B = 1 broadcast case constexpr auto NUM_WARPS = 16; - const dim3 threads(NUM_WARPS * kWarpSize); // 16 x 32 + const dim3 threads(NUM_WARPS * kWarpSizeHost()); // 16 x 32 const dim3 blocks(cuda_calc_xblock_count( reordered_cat_ad_offsets.numel() - 1, NUM_WARPS)); // one warp per sample @@ -419,7 +419,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu( } else { // for B > 1 and B < 64 broadcast case constexpr auto NUM_WARPS = 16; - const dim3 threads(NUM_WARPS * kWarpSize); // 16 x 32 + const dim3 threads(NUM_WARPS * kWarpSizeHost()); // 16 x 32 const dim3 blocks(cuda_calc_xblock_count( T * num_ads_in_batch, NUM_WARPS)); // num_ads_in_batch warps for all Bs @@ -469,7 +469,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu( auto maxWarpSize = kMaxThreads / NUM_WARPS; const dim3 threads( NUM_WARPS, - maxWarpSize < kWarpSize ? maxWarpSize : kWarpSize); // 32 x 32 + maxWarpSize < kWarpSizeHost() ? maxWarpSize : kWarpSizeHost()); // 32 x 32 const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS)); #endif FBGEMM_LAUNCH_KERNEL( diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu index 40b387211e..5e000ce877 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_find.cu @@ -124,9 +124,9 @@ std::pair lfu_cache_find_uncached_cuda( FBGEMM_LAUNCH_KERNEL( (lfu_cache_find_uncached_kernel), std::min( - div_round_up(N, kMaxThreads / kWarpSize), + div_round_up(N, kMaxThreads / kWarpSizeHost()), get_max_thread_blocks_for_cache_kernels_()), - dim3(kWarpSize, kMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(unique_indices, index_t, 1, 32), diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu index 19fda6886a..10f17fd49c 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu @@ -200,9 +200,9 @@ void lfu_cache_insert_cuda( FBGEMM_LAUNCH_KERNEL( (lfu_cache_insert_kernel), std::min( - div_round_up(N, kCacheMaxThreads / kWarpSize), + div_round_up(N, kCacheMaxThreads / kWarpSizeHost()), get_max_thread_blocks_for_cache_kernels_()), - dim3(kWarpSize, kCacheMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kCacheMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(weights, emb_t, 1, 64), @@ -248,15 +248,6 @@ DLL_PUBLIC void lfu_cache_populate_cuda( CUDA_DEVICE_GUARD(weights); -#ifdef USE_ROCM - TORCH_CHECK( - at::cuda::warp_size() == 64, - __func__, - ": TBE cache requires warpSize 64 on ROCm (got ", - at::cuda::warp_size(), - "); warpSize 32 devices are not yet supported"); -#endif - TORCH_CHECK( linear_cache_indices.numel() < std::numeric_limits::max()); if (linear_cache_indices.numel() == 0) { diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu index 286c26aa34..e984d12be1 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu @@ -176,9 +176,9 @@ void lfu_cache_insert_byte_cuda( FBGEMM_LAUNCH_KERNEL( (lfu_cache_insert_byte_kernel), std::min( - div_round_up(N, kCacheMaxThreads / kWarpSize), + div_round_up(N, kCacheMaxThreads / kWarpSizeHost()), get_max_thread_blocks_for_cache_kernels_()), - dim3(kWarpSize, kCacheMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kCacheMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(weights, uint8_t, 1, 64), diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu index 56984160ce..f8dca9b074 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu @@ -214,7 +214,7 @@ lru_cache_find_uncached_cuda( constexpr int PREFETCH_KERNEL_MAX_BLOCKS = 8; auto grid_size = std::min( - div_round_up(N, kMaxThreads / kWarpSize), + div_round_up(N, kMaxThreads / kWarpSizeHost()), lock_cache_line ? PREFETCH_KERNEL_MAX_BLOCKS : get_max_thread_blocks_for_cache_kernels_()); @@ -222,7 +222,7 @@ lru_cache_find_uncached_cuda( FBGEMM_LAUNCH_KERNEL( (lru_cache_find_uncached_kernel), grid_size, - dim3(kWarpSize, kMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(unique_indices, index_t, 1, 32), diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu index f64ed21e02..f46c923537 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu @@ -225,12 +225,12 @@ void lru_cache_insert_cuda( auto grid_size = lock_cache_line ? div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO) - : div_round_up(N, kMaxThreads / kWarpSize); + : div_round_up(N, kMaxThreads / kWarpSizeHost()); FBGEMM_LAUNCH_KERNEL( (lru_cache_insert_kernel), grid_size, - dim3(kWarpSize, kMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(weights, emb_t, 1, 64), @@ -301,15 +301,6 @@ DLL_PUBLIC void lru_cache_populate_cuda( CUDA_DEVICE_GUARD(weights); -#ifdef USE_ROCM - TORCH_CHECK( - at::cuda::warp_size() == 64, - __func__, - ": TBE cache requires warpSize 64 on ROCm (got ", - at::cuda::warp_size(), - "); warpSize 32 devices are not yet supported"); -#endif - TORCH_CHECK( linear_cache_indices.numel() < std::numeric_limits::max()); if (linear_cache_indices.numel() == 0) { diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu index 000c105d11..e0ab983b5f 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu @@ -397,9 +397,9 @@ void lru_cache_insert_byte_cuda( FBGEMM_LAUNCH_KERNEL( (lru_cache_insert_byte_kernel), std::min( - div_round_up(N, kMaxThreads / kWarpSize), + div_round_up(N, kMaxThreads / kWarpSizeHost()), get_max_thread_blocks_for_cache_kernels_()), - dim3(kWarpSize, kMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(weights, uint8_t, 1, 64), @@ -462,9 +462,9 @@ void direct_mapped_lru_cache_insert_byte_cuda( FBGEMM_LAUNCH_KERNEL( (direct_mapped_lru_cache_insert_byte_kernel), std::min( - div_round_up(N, kMaxThreads / kWarpSize), + div_round_up(N, kMaxThreads / kWarpSizeHost()), get_max_thread_blocks_for_cache_kernels_()), - dim3(kWarpSize, kMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(weights, uint8_t, 1, 64), diff --git a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu index 7c3f6f5b7f..08225fd18b 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu @@ -223,8 +223,8 @@ void lxu_cache_locking_counter_decrement_cuda( auto count = at::zeros_like(lxu_cache_locking_counter); const int32_t C = lxu_cache_locking_counter.size(0); - TORCH_CHECK(lxu_cache_locking_counter.size(1) == kWarpSize); - auto fd = FixedDivisor(kWarpSize); + TORCH_CHECK(lxu_cache_locking_counter.size(1) == kWarpSizeHost()); + auto fd = FixedDivisor(kWarpSizeHost()); const dim3 blocks( std::min( @@ -244,9 +244,9 @@ void lxu_cache_locking_counter_decrement_cuda( FBGEMM_LAUNCH_KERNEL( lxu_cache_locking_counter_decrement_kernel, std::min( - div_round_up(C, kMaxThreads / kWarpSize), + div_round_up(C, kMaxThreads / kWarpSizeHost()), get_max_thread_blocks_for_cache_kernels_()), - dim3(kWarpSize, kMaxThreads / kWarpSize), + dim3(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(lxu_cache_locking_counter, int32_t, 2, 32), @@ -434,19 +434,6 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda( CUDA_DEVICE_GUARD(linear_cache_indices); -#ifdef USE_ROCM - // Cache kernels use kWarpSize as the stride for indexing cache rows. - // On warpSize 32 ROCm devices (e.g. gfx1100) this produces wrong - // addresses because DEFAULT_ASSOC (kWarpSize on host) is 64. - // D102579845 introduces kCacheAssoc to fix this; until then, guard. - TORCH_CHECK( - at::cuda::warp_size() == 64, - __func__, - ": TBE cache requires warpSize 64 on ROCm (got ", - at::cuda::warp_size(), - "); warpSize 32 devices are not yet supported"); -#endif - const auto lxu_cache_locations = lxu_cache_locations_output.value_or(empty_like( linear_cache_indices, @@ -458,7 +445,7 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda( return lxu_cache_locations; } - const dim3 threads(kWarpSize, kMaxThreads / kWarpSize); + const dim3 threads(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()); const dim3 blocks(div_round_up(N, kMaxThreads)); AT_DISPATCH_INDEX_TYPES( diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index 8e4819637b..5ccffe832c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -328,7 +328,7 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( assigned_cache_slots[n + l] = -1; } else { evicted_indices[n + l] = current_idx; // -1 if not set, >= 0 if valid. - assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot; + assigned_cache_slots[n + l] = cache_set * kWarpSizeHost() + insert_slot; // TODO: Check if we can do contiguous writes here. // Update cache states @@ -343,7 +343,7 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( } // Conflict misses - for (auto l = kWarpSize + threadIdx.x; l < SL; l += kWarpSize) { + for (auto l = kWarpSizeHost() + threadIdx.x; l < SL; l += kWarpSizeHost()) { evicted_indices[n + l] = -1; assigned_cache_slots[n + l] = -1; } @@ -447,8 +447,8 @@ ssd_cache_populate_actions_cuda( FBGEMM_LAUNCH_DSA_KERNEL( ssd_cache_actions_insert_kernel, - div_round_up(N, kMaxThreads / kWarpSize), - dim3(kWarpSize, kMaxThreads / kWarpSize), + div_round_up(N, kMaxThreads / kWarpSizeHost()), + dim3(kWarpSizeHost(), kMaxThreads / kWarpSizeHost()), 0, at::cuda::getCurrentCUDAStream(), PTA_B(lxu_cache_state, int64_t, 2, 32), @@ -558,7 +558,7 @@ std::tuple ssd_generate_row_addrs_cuda( lxu_cache_locations.options().dtype(at::kLong)); const auto post_bwd_evicted_indices = at::empty_like(ssd_row_addrs); - constexpr auto kNumWarps = kMaxThreads / kWarpSize; + const auto kNumWarps = kMaxThreads / kWarpSizeHost(); const auto cache_row_bytes = lxu_cache_weights.size(1) * lxu_cache_weights.element_size(); const auto lxu_cache_weights_addr = @@ -578,7 +578,7 @@ std::tuple ssd_generate_row_addrs_cuda( FBGEMM_LAUNCH_KERNEL( (ssd_generate_row_addrs_kernel), div_round_up(lxu_cache_locations.numel(), kNumWarps), - dim3(kWarpSize, kNumWarps), + dim3(kWarpSizeHost(), kNumWarps), 0, at::cuda::getCurrentCUDAStream(), PTA_B(ssd_row_addrs, int64_t, 1, 32), @@ -684,12 +684,12 @@ void ssd_update_row_addrs_cuda( reinterpret_cast(inserted_ssd_weights_next.data_ptr()); const auto cache_row_bytes = lxu_cache_weights.size(1) * lxu_cache_weights.element_size(); - constexpr auto kNumWarps = kMaxThreads / kWarpSize; + const auto kNumWarps = kMaxThreads / kWarpSizeHost(); FBGEMM_LAUNCH_KERNEL( (ssd_update_row_addrs_kernel), div_round_up(ssd_row_addrs_curr.numel(), kNumWarps), - dim3(kWarpSize, kNumWarps), + dim3(kWarpSizeHost(), kNumWarps), 0, at::cuda::getCurrentCUDAStream(), PTA_B(ssd_row_addrs_curr, int64_t, 1, 32), diff --git a/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py b/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py index 8da8b199f1..c7c24cd93a 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py +++ b/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py @@ -22,6 +22,16 @@ from ..common import assert_torch_equal, MAX_EXAMPLES from .cache_common import assert_cache, generate_cache_tbes, gpu_unavailable, VERBOSITY +# TBE cache associativity equals the device warp size (64 on CDNA, 32 on RDNA +# and NVIDIA). DEFAULT_ASSOC is the static fallback; query the actual device +# so the cache-sizing math is correct on warpSize 32 ROCm devices. +if torch.cuda.is_available(): + WARP_SIZE: int = torch.cuda.get_device_properties( + torch.cuda.current_device() + ).warp_size +else: + WARP_SIZE = DEFAULT_ASSOC + class CacheOverflowTest(unittest.TestCase): @unittest.skipIf(*gpu_unavailable) @@ -45,7 +55,7 @@ def test_cache_int32_overflow(self, stochastic_rounding: bool) -> None: # Weight and cache precisions are fixed to FP16 element_size = 2 # Adjust cache_sets based on free memory - while cache_sets * DEFAULT_ASSOC * D * element_size > free_memory: + while cache_sets * WARP_SIZE * D * element_size > free_memory: cache_sets = cache_sets // 10 # Generate TBEs @@ -62,7 +72,7 @@ def test_cache_int32_overflow(self, stochastic_rounding: bool) -> None: # Accessing the last cache slot last_cache_set = cache_sets - 1 - cache_idx = last_cache_set * DEFAULT_ASSOC + (DEFAULT_ASSOC - 1) + cache_idx = last_cache_set * WARP_SIZE + (WARP_SIZE - 1) if cache_idx * D < (2**31) - 1: logging.warning("test_cache_int32_overflow does not test int32 overflowing") else: diff --git a/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py b/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py index bb9f377d13..bfbab8f043 100644 --- a/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/lxu_cache_test.py @@ -27,20 +27,39 @@ VERBOSITY: Verbosity = Verbosity.verbose +# TBE cache associativity equals the device warp size (64 on CDNA, 32 on +# RDNA and NVIDIA). DEFAULT_ASSOC is the static fallback; query the actual +# device so these cache-layout tests are correct on warpSize 32 ROCm devices. +if torch.cuda.is_available(): + WARP_SIZE: int = torch.cuda.get_device_properties( + torch.cuda.current_device() + ).warp_size +else: + WARP_SIZE = DEFAULT_ASSOC + @optests.generate_opcheck_tests(fast=True) class LXUCacheTest(unittest.TestCase): @unittest.skipIf(*gpu_unavailable) @given( - associativity=st.sampled_from([1, DEFAULT_ASSOC]), + associativity=st.sampled_from([1, WARP_SIZE]), ) @settings(deadline=None) def test_lxu_cache_lookup(self, associativity: int) -> None: max_index: int = 8000 # Use single cache set to avoid dealing with cache set hash algorithm. - lxu_cache_state_gpu = ( - torch.arange(associativity, dtype=torch.int64).unsqueeze(0).cuda() + # The lookup kernel scans a full warp of ways per set, so the cache + # state row must be warpSize wide; `associativity` is how many ways are + # populated and the remainder are the empty sentinel (-1), matching how + # the real cache represents unpopulated ways. (A row narrower than the + # warp would make the kernel read past the row.) + lxu_cache_state_gpu = torch.full( + (1, WARP_SIZE), -1, dtype=torch.int64 + ) + lxu_cache_state_gpu[0, :associativity] = torch.arange( + associativity, dtype=torch.int64 ) + lxu_cache_state_gpu = lxu_cache_state_gpu.cuda() # Testing all miss. linear_cache_indices_0 = ( @@ -106,7 +125,7 @@ def test_lxu_cache_locking_counter_decrement( self, cache_sets: int, ) -> None: - warp_size = DEFAULT_ASSOC + warp_size = WARP_SIZE N = cache_sets * warp_size lxu_cache_locking_counter = torch.randint( low=1, @@ -305,7 +324,7 @@ def duplicate_lookup( cache_sets = int((E * T) * 0.2) lxu_cache_state = torch.zeros( cache_sets, - DEFAULT_ASSOC, + WARP_SIZE, device="cuda", dtype=torch.int64, ).fill_(-1) @@ -336,7 +355,7 @@ def duplicate_lookup( if c not in slots: slots[c] = 0 slot = slots[c] - if slot < DEFAULT_ASSOC: + if slot < WARP_SIZE: lxu_cache_state[c][slot] = idx slots[c] = slot + 1