Skip to content

Implement asynchronous LDS loads for MI350 #5348

Open
avbokovoy wants to merge 5 commits into
pytorch:mainfrom
ROCm:abokovoi/async-lds-inference-opt
Open

Implement asynchronous LDS loads for MI350 #5348
avbokovoy wants to merge 5 commits into
pytorch:mainfrom
ROCm:abokovoi/async-lds-inference-opt

Conversation

@avbokovoy
Copy link
Copy Markdown
Contributor

@avbokovoy avbokovoy commented Jan 26, 2026

This PR implements direct HBM->LDS stores in tbe inference kernel. There are 2 major changes:

  1. Rows data isn't loaded in-place, instead we store pointers to global memory and store the actual data w.r.t. the predicate into LDS. In case predicate is false, we pre-allocate small chunk of static device memory of 16B once, fill it with zeros, and fallback to this chunk
  2. HBM->LDS 16B loads are implemented for ROCm >= 7.0 and MI350. We can expand the support range to MI30* through 4B loads, however it doesn't bring any performance benefits because we'll have to introduce an overhead of addresses transposition and 4x more load operations. You can find out the reference implementation here: fe52557.

Due to pre-7.2 ROCm features, we are forced to used assembly inline to get 16B loads to work, so manual synchronization was added. In case of ROCm >= 7.2, we use proper intrinsics to handle memory synchronization.

This change brings ~10% performance boost on average for weighted and unweighted cases. We may try to push it further by doing async loads for indices weights.

@meta-cla meta-cla Bot added the cla signed label Jan 26, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented Jan 26, 2026

@q10 has imported this pull request. If you are a Meta employee, you can view this in D91496421.


asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#elif defined(USE_ROCM) && \
(ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to be supported for rocm version < 7.2?

If so, it should be
(ROCM_VERSION_MAJOR < 7 || (ROCM_VERSION_MAJOR == 7 && ROCM_VERSION_MINOR < 2))?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. Your comment + some further tweaks with adjustment to future lds intrinsic API are addressed in 2c739ab


asm volatile("cp.async.wait_all;\n" ::);
#elif defined(USE_ROCM) && \
(ROCM_VERSION_MAJOR <= 7 && ROCM_VERSION_MINOR < 2) && defined(__gfx950__)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 2c739ab

@avbokovoy avbokovoy requested a review from spcyppt February 19, 2026 11:26
const uint4* row_v[kRowUnroll];
int32_t idx_v[kRowUnroll];
int32_t cache_idx_v[kRowUnroll];
bool row_valid_v[kRowUnroll];
Copy link
Copy Markdown
Contributor

@spcyppt spcyppt Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you ensure all the changes only affect rocm?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole block (lines 162-228) is under if is_rocm jinja guard

}
{% if weighted %}
#pragma unroll
for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please guard the changes to be ROCM only? We see small regression in NVIDIA.

Copy link
Copy Markdown
Contributor Author

@avbokovoy avbokovoy Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole block (lines 162-228) is under if is_rocm jinja guard

aryaman-gupta and others added 2 commits May 18, 2026 11:49
cp_async_zfill_cg is async on Ampere+ and gfx950 but synchronous
elsewhere. Inlining the sync fallback into the per-iteration row-load
loop kills load pipelining (load->store dependency forces N waitcnts
instead of one) and adds wave divergence on mixed-validity warps.
Measured up to -19% BW on MI300 (gfx942) for weighted L=20/L=50.

Wrap the row-store section in a #if matching the helper's dispatch:
gfx950/Ampere keep the fused cp_async_zfill_cg loop; everything else
gets the original two-loop pattern (load all -> masked store).
Helper and gfx950 paths untouched.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants