Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,15 @@ endif()
set(CMAKE_CODEGEN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/codegen)

macro(RUN_GEN_SCRIPT SCRIPT)
set(rocm_flag)
if(FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
set(rocm_flag --is_rocm)
list(APPEND rocm_flag --is_rocm)
if(FBGEMM_HAS_WAVE32)
list(APPEND rocm_flag --has_wave32)
endif()
if(FBGEMM_HAS_WAVE64)
list(APPEND rocm_flag --has_wave64)
endif()
endif()

BLOCK_PRINT(
Expand Down
16 changes: 16 additions & 0 deletions fbgemm_gpu/cmake/Hip.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ if(PYTORCH_ROCM_ARCH STREQUAL "")
endif()
message("Building FBGEMM for GPU arch: ${PYTORCH_ROCM_ARCH}")

# Derive the wave-size set in scope for this build from PYTORCH_ROCM_ARCH.
# gfx9* archs (CDNA: gfx90a, gfx940/941/942, gfx950) run wave64; gfx10/11/12
# archs (RDNA3/4: gfx1100, gfx1200, ...) run wave32. The codegen uses these
# to emit only the host-side dispatch branches needed by this wheel, so
# single-arch wheels stay free of the wrong wave's bracket table.
set(FBGEMM_HAS_WAVE32 OFF)
set(FBGEMM_HAS_WAVE64 OFF)
foreach(fbgemm_rocm_arch ${PYTORCH_ROCM_ARCH})
if(fbgemm_rocm_arch MATCHES "^gfx9")
set(FBGEMM_HAS_WAVE64 ON)
else()
set(FBGEMM_HAS_WAVE32 ON)
endif()
endforeach()
message("FBGEMM wave-size set: WAVE32=${FBGEMM_HAS_WAVE32} WAVE64=${FBGEMM_HAS_WAVE64}")

ADD_DEFINITIONS(-DNDEBUG)
# USE_ROCM flag is used inside FBGEMM_GPU C++ code
ADD_DEFINITIONS(-DUSE_ROCM)
Expand Down
121 changes: 116 additions & 5 deletions fbgemm_gpu/codegen/genscript/jinja_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,35 @@
# larger max embedding dimension.
env.globals["legacy_max_embedding_dim"] = 1024

# An optimization for ROCm
# An optimization for ROCm: wave64 archs (CDNA) want a larger items-per-warp.
env.globals["items_per_warp"] = 128 if args.is_rocm is False else 256

# Per-wave-size items_per_warp values; used by codegen helpers that emit
# kernel instantiations and host dispatch tables that adapt per-wave.
env.globals["items_per_warp32"] = 128
env.globals["items_per_wave64"] = 256

# The fixed max vectors per thread for different kernels. The numbers were
# derived from empirical studies
env.globals["fixed_max_vecs_per_thread"] = {"backward": 2, "backward_indice_weights": 6}

env.globals["dense"] = False
env.globals["is_rocm"] = args.is_rocm

# Wave-size set in scope for this build. CUDA is always wave32. On ROCm the
# values come from cmake/Hip.cmake parsing PYTORCH_ROCM_ARCH. The host
# dispatcher emits a runtime warp_size branch only when both are present.
# If a ROCm build has neither flag set (older CMake / direct codegen invoke),
# fall back to wave64 only to preserve pre-port behavior.
if args.is_rocm:
env.globals["has_wave32"] = args.has_wave32
env.globals["has_wave64"] = args.has_wave64 or (
not args.has_wave32 and not args.has_wave64
)
else:
env.globals["has_wave32"] = True
env.globals["has_wave64"] = False


################################################################################
# Helper functions in Jinja Environment
Expand Down Expand Up @@ -189,7 +208,16 @@ def dispatch_non_vec_blocking_kernel(
Generate code for kernel dispatching for kernels that do not use vector
blocking (i.e., an entire embedding row can fit in the allocated Vec4T
buffer)

Each branch emits a constexpr ``kSubwarpDivisor`` literal (the divisor
that, applied to the per-arch ``kWarpSize`` in device code or to
``kWarpSizeHost`` on host, yields the kernel's thread-group size) and a
matching ``kThreadGroupSize`` (constexpr on CUDA, runtime on ROCm). The
consumer uses ``kSubwarpDivisor`` as the kernel template argument so the
mangled name is warpSize-free; ``kThreadGroupSize`` is the value to set
block dims with.
"""
warp_size = items_per_warp // 4
blob = ""
for (
kFixedMaxVecsPerThread,
Expand All @@ -201,18 +229,21 @@ def dispatch_non_vec_blocking_kernel(
use_subwarp_shuffle,
use_vec_blocking=False,
):
kSubwarpDivisor = warp_size // kThreadGroupSize
formats = {
"max_D_val": kFixedMaxVecsPerThread * kThreadGroupSize * 4,
"kFixedMaxVecsPerThread": kFixedMaxVecsPerThread,
"kThreadGroupSize": kThreadGroupSize,
"kSubwarpDivisor": kSubwarpDivisor,
"kUseVecBlocking": kUseVecBlocking,
}
d_blob = """if (MAX_D <= {max_D_val}) { \\
[[ maybe_unused ]] const int max_vecs_per_thread = \\
{kFixedMaxVecsPerThread}; \\
constexpr int kFixedMaxVecsPerThread = {kFixedMaxVecsPerThread}; \\
[[ maybe_unused ]] constexpr int kThreadGroupSize = \\
{kThreadGroupSize}; \\
[[ maybe_unused ]] constexpr int kSubwarpDivisor = \\
{kSubwarpDivisor}; \\
[[ maybe_unused ]] const int kThreadGroupSize = \\
kWarpSizeHost() / kSubwarpDivisor; \\
[[ maybe_unused ]] constexpr bool kUseVecBlocking = \\
{kUseVecBlocking}; \\
return __VA_ARGS__(); \\
Expand All @@ -230,6 +261,8 @@ def dispatch_vec_blocking_kernel(
"""
Generate code for kernel dispatching for kernels that use vector blocking
(i.e., an entire embedding row cannot fit in the allocated Vec4T buffer)

Vec blocking always uses the full warp, so ``kSubwarpDivisor = 1``.
"""
formats = {
"max_D_val": fixed_max_vecs_per_thread * items_per_warp,
Expand All @@ -240,7 +273,8 @@ def dispatch_vec_blocking_kernel(
[[ maybe_unused ]] const int max_vecs_per_thread = \\
(MAX_D + {items_per_warp} - 1) / {items_per_warp}; \\
constexpr int kFixedMaxVecsPerThread = {fixed_max_vecs_per_thread}; \\
[[ maybe_unused ]] constexpr int kThreadGroupSize = kWarpSize; \\
[[ maybe_unused ]] constexpr int kSubwarpDivisor = 1; \\
[[ maybe_unused ]] const int kThreadGroupSize = kWarpSizeHost(); \\
[[ maybe_unused ]] constexpr bool kUseVecBlocking = true; \\
return __VA_ARGS__(); \\
} \\
Expand Down Expand Up @@ -270,6 +304,79 @@ def dispatch_optimal_kernel(
return blob


def _enabled_waves() -> list[tuple[int, int]]:
"""Return (items_per_warp, warp_size) pairs for each enabled wave size."""
waves: list[tuple[int, int]] = []
if env.globals["has_wave64"]:
waves.append((env.globals["items_per_wave64"], 64))
if env.globals["has_wave32"]:
waves.append((env.globals["items_per_warp32"], 32))
if not waves:
# Defensive fallback: codegen invoked without --has_wave* on ROCm.
waves.append((env.globals["items_per_warp"], env.globals["items_per_warp"] // 4))
return waves


def get_max_vecs_template_configs_union(
fixed_max_vecs_per_thread: int,
use_subwarp_shuffle: bool,
use_vec_blocking: bool,
) -> list[tuple[int, int, str]]:
"""
Returns the union of (kFixedMaxVecsPerThread, kSubwarpDivisor,
kUseVecBlocking) tuples needed by every wave size in scope for this build
(driven by ``has_wave32`` / ``has_wave64``). Templates use the result to
emit explicit instantiations: one kernel symbol per tuple, which serves
every enabled wave size because ``kSubwarpDivisor`` (not warpSize) is the
template parameter — the per-arch ``kThreadGroupSize`` falls out of
``kWarpSize / kSubwarpDivisor`` in the device pass.
"""
seen: set[tuple[int, int, str]] = set()
configs: list[tuple[int, int, str]] = []
for items_per_warp_local, warp_size in _enabled_waves():
for kFixedMaxVecs, kThreadGroupSize, kUseVecBlocking in get_max_vecs_template_configs(
items_per_warp_local,
fixed_max_vecs_per_thread,
use_subwarp_shuffle,
use_vec_blocking,
):
kSubwarpDivisor = warp_size // kThreadGroupSize
key = (kFixedMaxVecs, kSubwarpDivisor, kUseVecBlocking)
if key not in seen:
seen.add(key)
configs.append(key)
return configs


def get_max_vecs_template_configs_union_forward(
max_forward_embedding_dim: int,
use_subwarp_shuffle: bool,
use_vec_blocking: bool,
) -> list[tuple[int, int, str]]:
"""
Like :func:`get_max_vecs_template_configs_union`, but the
``fixed_max_vecs_per_thread`` value depends on wave size: forward kernels
use ``max_forward_embedding_dim // items_per_warp``, which differs between
wave32 and wave64 because they have different ``items_per_warp`` values.
"""
seen: set[tuple[int, int, str]] = set()
configs: list[tuple[int, int, str]] = []
for items_per_warp_local, warp_size in _enabled_waves():
fixed_max_vecs = max_forward_embedding_dim // items_per_warp_local
for kFixedMaxVecs, kThreadGroupSize, kUseVecBlocking in get_max_vecs_template_configs(
items_per_warp_local,
fixed_max_vecs,
use_subwarp_shuffle,
use_vec_blocking,
):
kSubwarpDivisor = warp_size // kThreadGroupSize
key = (kFixedMaxVecs, kSubwarpDivisor, kUseVecBlocking)
if key not in seen:
seen.add(key)
configs.append(key)
return configs


def is_valid_forward_config(
nobag: bool,
weighted: bool,
Expand Down Expand Up @@ -346,6 +453,10 @@ def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str:
generate_optimized_grad_sum_loop_access
)
env.globals["get_max_vecs_template_configs"] = get_max_vecs_template_configs
env.globals["get_max_vecs_template_configs_union"] = get_max_vecs_template_configs_union
env.globals["get_max_vecs_template_configs_union_forward"] = (
get_max_vecs_template_configs_union_forward
)
env.globals["dispatch_optimal_kernel"] = dispatch_optimal_kernel
env.globals["dispatch_non_vec_blocking_kernel"] = dispatch_non_vec_blocking_kernel
env.globals["dispatch_vec_blocking_kernel"] = dispatch_vec_blocking_kernel
Expand Down
5 changes: 5 additions & 0 deletions fbgemm_gpu/codegen/genscript/scripts_argsparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
)
parser.add_argument("--opensource", action="store_false", dest="is_fbcode")
parser.add_argument("--is_rocm", action="store_true")
# CMake-derived wave-size set for the current ROCm build. CUDA builds ignore
# these (always wave32). Both unset on a ROCm build defaults to wave64-only
# to preserve pre-port behavior.
parser.add_argument("--has_wave32", action="store_true")
parser.add_argument("--has_wave64", action="store_true")

args: argparse.Namespace
_: list[str]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ Tensor pruned_hashmap_lookup_cuda(
(int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<
index_t,
hash_t>),
nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSizeHost()),
dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()),
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(indices, index_t, 1, 32),
Expand Down Expand Up @@ -238,8 +238,8 @@ Tensor pruned_array_lookup_cuda(
index_t,
remap_t>),
nbit::div_round_up(
offsets.size(0), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
offsets.size(0), kForwardMaxThreads / kWarpSizeHost()),
dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()),
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(indices, index_t, 1, 32),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
FBGEMM_LAUNCH_KERNEL( \
({{ func_name }}<index_t, output_t, OutputRowsPerThread, kWarpsPerBlock, InputRowsInFlight, MinNum128BRows, MaxNum128BRows, DeviceOnly, PackedMode>), \
nbit::div_round_up(T * nbit::div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock), \
dim3(kWarpSize, kWarpsPerBlock), \
dim3(kWarpSizeHost(), kWarpsPerBlock), \
0, \
at::cuda::getCurrentCUDAStream(), \
PTA_B(dev_weights, uint8_t, 1, 64), \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/ops_utils.h"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
#include "fbgemm_gpu/utils/cuda_prelude.cuh"
#include "fbgemm_gpu/config/feature_gates.h"

using Tensor = at::Tensor;
Expand Down Expand Up @@ -959,7 +960,10 @@ class {{ autograd_func }} :
TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
// BT_block_size matches the active device's warp size on ROCm (32 on
// wave32 archs like gfx1100, 64 on wave64 archs like gfx90a). Multi-arch
// wheels must read this at runtime, hence the const int (not constexpr).
const int32_t BT_block_size = kWarpSizeHost();
constexpr int32_t max_segment_length_per_warp = 16384;
#else
constexpr int32_t BT_block_size = 32;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,79 @@
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

#define DISPATCH_NON_VEC_BLOCKING_KERNEL(MAX_D, ...) \
[&] { \
{%- if has_wave64 %}
#define DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, ...) \
[&] { \
{{
dispatch_non_vec_blocking_kernel(
items_per_warp,
items_per_wave64,
fixed_max_vecs_per_thread["backward_indice_weights"],
use_subwarp_shuffle=False,
)
-}}
}()
#define DISPATCH_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, ...) \
[&] { \
{{
dispatch_vec_blocking_kernel(
items_per_wave64,
fixed_max_vecs_per_thread["backward_indice_weights"],
)
-}}
}()
{%- endif %}

#define DISPATCH_VEC_BLOCKING_KERNEL(MAX_D, ...) \
[&] { \
{%- if has_wave32 %}
#define DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, ...) \
[&] { \
{{
dispatch_non_vec_blocking_kernel(
items_per_warp32,
fixed_max_vecs_per_thread["backward_indice_weights"],
use_subwarp_shuffle=False,
)
-}}
}()
#define DISPATCH_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, ...) \
[&] { \
{{
dispatch_vec_blocking_kernel(
items_per_warp,
items_per_warp32,
fixed_max_vecs_per_thread["backward_indice_weights"],
)
-}}
}()
{%- endif %}

#define DISPATCH_NON_VEC_BLOCKING_KERNEL(MAX_D, ...) \
[&] { \
{%- if has_wave32 and has_wave64 %}
if (kWarpSizeHost() == 64) { \
return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \
} else { \
return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \
} \
{%- elif has_wave64 %}
return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \
{%- else %}
return DISPATCH_NON_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \
{%- endif %}
}()

#define DISPATCH_VEC_BLOCKING_KERNEL(MAX_D, ...) \
[&] { \
{%- if has_wave32 and has_wave64 %}
if (kWarpSizeHost() == 64) { \
return DISPATCH_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \
} else { \
return DISPATCH_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \
} \
{%- elif has_wave64 %}
return DISPATCH_VEC_BLOCKING_KERNEL_WAVE64(MAX_D, __VA_ARGS__); \
{%- else %}
return DISPATCH_VEC_BLOCKING_KERNEL_WAVE32(MAX_D, __VA_ARGS__); \
{%- endif %}
}()

{%- for vbe in ([True, False]) %}
{%- set vdesc = "_vbe" if vbe else "" %}
Expand Down Expand Up @@ -548,8 +601,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
cache_t,
index_t,
kFixedMaxVecsPerThread>),
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
div_round_up(total_B, kForwardMaxThreads / kWarpSizeHost()),
dim3(kWarpSizeHost(), kForwardMaxThreads / kWarpSizeHost()),
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(grad_output_reshaped, grad_t, 2, 64),
Expand Down
Loading
Loading