Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,46 +165,80 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no

#pragma unroll
for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) {
uint4 row_data_v[kRowUnroll];
const uint4* row_v[kRowUnroll];
int32_t idx_v[kRowUnroll];
int32_t cache_idx_v[kRowUnroll];
bool row_valid_v[kRowUnroll];
Copy link
Copy Markdown
Contributor

@spcyppt spcyppt Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you ensure all the changes only affect rocm?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole block (lines 162-228) is under if is_rocm jinja guard

#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1;
cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1;
row_valid_v[inner_i] = valid;
}


#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && L_start + input_row_idx < Ls[i];
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid);
valid = valid && (idx_v[inner_i] != -1);
bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && row_valid_v[inner_i]);
bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1);
if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) {
row_v[inner_i] = reinterpret_cast<const uint4*>(&lxu_cache_weights[static_cast<int64_t>(cache_idx_v[inner_i])][0]);
} else
if (valid) {
} else if (final_valid) {
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[static_cast<int64_t>(idx_v[inner_i]) * D_bytes]);
} else {
row_v[inner_i] = reinterpret_cast<const uint4*>(&weights[0]);
}
row_valid_v[inner_i] = final_valid;
}
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || (defined(USE_ROCM) && defined(__gfx950__))
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
bool final_valid = row_valid_v[inner_i];
if constexpr (PackedMode) {
// Store row data with uint4_loads_per_row offset
cp_async_zfill_cg<sizeof(uint4)>(
&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx],
&row_v[inner_i][row_load_idx],
final_valid);
} else {
cp_async_zfill_cg<sizeof(uint4)>(
&buffers[warp_idx][i][input_row_idx][row_load_idx],
&row_v[inner_i][row_load_idx],
final_valid);
}
}
{% if weighted %}
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please guard the changes to be ROCM only? We see small regression in NVIDIA.

Copy link
Copy Markdown
Contributor Author

@avbokovoy avbokovoy Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole block (lines 162-228) is under if is_rocm jinja guard

uint32_t i = outer_i + inner_i;
bool final_valid = row_valid_v[inner_i] && (idx_v[inner_i] != -1);
if (row_load_idx == 0) {
// Use only one thread to load the index weight to prevent a race
// condition when writing to the shared memory
buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] =
final_valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
}
}
{% endif %}
#else
// Maintain pipelining of load and store operations when async shared memory/lds
// loads are not available.
uint4 row_data_v[kRowUnroll];
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
row_data_v[inner_i] = row_v[inner_i][row_load_idx];
}
uint4 zeros = {0, 0, 0, 0};
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
uint32_t i = outer_i + inner_i;
bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1);
uint4 data = valid ? row_data_v[inner_i] : zeros;
bool final_valid = row_valid_v[inner_i];
uint4 data = final_valid ? row_data_v[inner_i] : zeros;
if constexpr (PackedMode) {
// Store row data with uint4_loads_per_row offset
buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data;
Expand All @@ -216,10 +250,11 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
// Use only one thread to load the index weight to prevent a race
// condition when writing to the shared memory
buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] =
valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
final_valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
}
{% endif %}
}
#endif
}
{%- endif %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ __device__ __forceinline__ void cp_async_wait() {
#if __CUDA_ARCH__ >= 800

asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \
(ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__)

__builtin_amdgcn_s_waitcnt(0);
#endif
}

Expand All @@ -179,13 +183,27 @@ __device__ __forceinline__ void cp_async_wait<0>() {
#if __CUDA_ARCH__ >= 800

asm volatile("cp.async.wait_all;\n" ::);
#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR < 7 || \
(ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2)) && defined(__gfx950__)

__builtin_amdgcn_s_waitcnt(0);
#endif
}

template<typename T>
__device__ __forceinline__ uint32_t hip_cvta_to_shared_address(const T* ptr) {
// First get the address as a size_t to handle all pointer sizes
size_t addr = reinterpret_cast<size_t>(ptr);

// Extract the lower 32 bits which represent the shared memory offset
// This is safe because shared memory addresses are always within 32-bit range
return static_cast<uint32_t>(addr & 0xFFFFFFFF);
}

/// Partial specialization
template <int SizeInBytes>
__device__ __forceinline__ void
cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) {
cp_async_zfill_cg(__shared__ void* smem_ptr, void const* global_ptr, bool pred_guard) {
#if __CUDA_ARCH__ >= 800
static_assert(
SizeInBytes == 16,
Expand All @@ -199,6 +217,36 @@ cp_async_zfill_cg(void* smem_ptr, void const* global_ptr, bool pred_guard) {
"n"(SizeInBytes),
"r"(src_in_bytes));

// if ROCm version >= 7.2 and MI350
#elif defined(USE_ROCM) && (ROCM_VERSION_MAJOR > 7 || \
(ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR >= 2)) && defined(__gfx950__)
static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0};
static_assert(
SizeInBytes == 16,
"cp_async_zfill_cg() function is implemented for 16B inputs only");
// Due to LLVM bug, we can't use SizeInBytes directly
// in __builtin_amdgcn_global_load_lds intrinsic until
// ROCm 7.11:
// https://github.com/llvm/llvm-project/pull/175767
//
// Make sure you modify this #if branch if SizeInBytes
// support range is extended
const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile;
__builtin_amdgcn_global_load_lds(const_cast<void*>(src_ptr), smem_ptr, 16, 0, 0);
// if MI350
#elif defined(USE_ROCM) && defined(__gfx950__)
static __device__ __constant__ uint4 zero_tile = {0, 0, 0, 0};
static_assert(
SizeInBytes == 16,
"cp_async_zfill_cg() function is implemented for 16B inputs only");

uint32_t smem =
__builtin_amdgcn_readfirstlane(hip_cvta_to_shared_address(smem_ptr));
const void *src_ptr = (pred_guard) ? global_ptr : &zero_tile;
asm volatile("s_mov_b32 m0, %0\n"
"global_load_lds_dwordx4 %1, off\n" ::"s"(smem),
"v"(static_cast<const uint32_t *>(src_ptr))
:);
#else
static_assert(SizeInBytes == 16, "");
using AccessType = uint4;
Expand Down
Loading