Implement asynchronous LDS loads for MI350 #5348
Conversation
|
|
||
| asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); | ||
| #elif defined(USE_ROCM) && \ | ||
| (ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__) |
There was a problem hiding this comment.
Is this supposed to be supported for rocm version < 7.2?
If so, it should be
(ROCM_VERSION_MAJOR < 7 || (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2))?
There was a problem hiding this comment.
Indeed. Your comment + some further tweaks with adjustment to future lds intrinsic API are addressed in 2c739ab
|
|
||
| asm volatile("cp.async.wait_all;\n" ::); | ||
| #elif defined(USE_ROCM) && \ | ||
| (ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__) |
| const uint4* row_v[kRowUnroll]; | ||
| int32_t idx_v[kRowUnroll]; | ||
| int32_t cache_idx_v[kRowUnroll]; | ||
| bool row_valid_v[kRowUnroll]; |
There was a problem hiding this comment.
could you ensure all the changes only affect rocm?
There was a problem hiding this comment.
The whole block (lines 162-228) is under if is_rocm jinja guard
| } | ||
| {% if weighted %} | ||
| #pragma unroll | ||
| for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { |
There was a problem hiding this comment.
Please guard the changes to be ROCM only? We see small regression in NVIDIA.
There was a problem hiding this comment.
The whole block (lines 162-228) is under if is_rocm jinja guard
cp_async_zfill_cg is async on Ampere+ and gfx950 but synchronous elsewhere. Inlining the sync fallback into the per-iteration row-load loop kills load pipelining (load->store dependency forces N waitcnts instead of one) and adds wave divergence on mixed-validity warps. Measured up to -19% BW on MI300 (gfx942) for weighted L=20/L=50. Wrap the row-store section in a #if matching the helper's dispatch: gfx950/Ampere keep the fused cp_async_zfill_cg loop; everything else gets the original two-loop pattern (load all -> masked store). Helper and gfx950 paths untouched. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…ents about pipelining of memory ops
This PR implements direct HBM->LDS stores in tbe inference kernel. There are 2 major changes:
Due to pre-7.2 ROCm features, we are forced to used assembly inline to get 16B loads to work, so manual synchronization was added. In case of ROCm >= 7.2, we use proper intrinsics to handle memory synchronization.
This change brings ~10% performance boost on average for weighted and unweighted cases. We may try to push it further by doing async loads for indices weights.