[AIROCMLIR-498] Attention scheduling improvements #2267
base: develop
Changes from all commits
0ceb202
e42fe55
d963949
004dfb3
6eab4b9
b2d7ff7
e2ffd8f
dcceeec
1bb557c
b3fecc6
7063f8d
fedc9eb
c298f92
da06455
6018bd8
b4ac474
5cad4fb
2366068
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -240,6 +240,9 @@ void rock::buildKernelPipeline(OpPassManager &pm, | |
| funcPm.addPass(rock::createRockThreadwiseGemmLoweringPass()); | ||
| funcPm.addPass(rock::createRockAnalyzeMemoryUsePass()); | ||
| funcPm.addPass(rock::createRockSugarToLoopsPass()); | ||
| // Re-run the pipeline pass to remove back-to-back LDS barriers | ||
| // that may appear after SugarToLoops unrolls TransformingForOps. | ||
| funcPm.addPass(rock::createRockPipelinePass()); | ||
|
Comment on lines
+243
to
+245
Member
I know we add LDSBarriers conservatively in gridwiseToBlockwise, expecting rock-pipeline to take care of them. (1) This barrier may not be necessary if the loop body starts with a barrier and there's a barrier after exiting the loop (2) For loop-carried deps, possibly (3) Can be eliminated if there is a barrier at the exit of the loop (4) Exit barrier |
||
| funcPm.addPass(rock::createRockCleanMathPass()); | ||
| math::MathExtendToSupportedTypesOptions extendToLLVMTypesOptions; | ||
| extendToLLVMTypesOptions.extraTypeStrs = {"f16"}; | ||
|
|
||
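The cleanup that the re-run of the pipeline pass is meant to perform (dropping back-to-back LDS barriers) can be illustrated with a toy model. This is only a sketch: the function name `dropBackToBackBarriers` and the string-based op list are hypothetical stand-ins, not rocMLIR's IR or the actual logic of `createRockPipelinePass`.

```cpp
#include <string>
#include <vector>

// Toy model (assumption): a kernel body is a flat list of op names, and
// "lds_barrier" marks an LDS barrier. The real pass works on MLIR ops.
std::vector<std::string>
dropBackToBackBarriers(const std::vector<std::string> &ops) {
  std::vector<std::string> out;
  for (const std::string &op : ops) {
    // A barrier that directly follows another barrier orders nothing new,
    // because no memory access happens between the two.
    if (op == "lds_barrier" && !out.empty() && out.back() == "lds_barrier")
      continue;
    out.push_back(op);
  }
  return out;
}
```

In this model, unrolling a loop whose body both starts and ends with a barrier produces adjacent barriers at every iteration boundary, and the filter keeps one of each pair.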
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -928,8 +928,13 @@ struct BlockwiseReduceRewritePattern | |||||||||||||||||||||
| } | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| if (rMethod == ReduceMethod::Sum) { | ||||||||||||||||||||||
| // Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the | ||||||||||||||||||||||
| // true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0 | ||||||||||||||||||||||
| // and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the | ||||||||||||||||||||||
| // redundant `v_add_f32 v, 0, v` that +0.0 generates via | ||||||||||||||||||||||
| // llvm.vector.reduce.fadd. | ||||||||||||||||||||||
|
Comment on lines
+931
to
+935
Member
If this PR takes too long to merge, move these changes into a separate PR and also create tests to make sure it doesn't generate
Comment on lines
+931
to
+935
|
||||||||||||||||||||||
| // Use -0.0 (negative zero) instead of +0.0. In IEEE 754, -0.0 is the | |
| // true additive identity: fadd(-0.0, x) = x for ALL x (including -0.0 | |
| // and NaN). LLVM can fold `fadd -0.0, x → x`, eliminating the | |
| // redundant `v_add_f32 v, 0, v` that +0.0 generates via | |
| // llvm.vector.reduce.fadd. | |
| // Use -0.0 (negative zero) instead of +0.0. LLVM can fold | |
| // `fadd -0.0, x → x`, eliminating the redundant | |
| // `v_add_f32 v, 0, v` that +0.0 generates via | |
| // llvm.vector.reduce.fadd. (Note: IEEE 754 still propagates NaNs, | |
| // i.e., x + NaN = NaN for any x.) |
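The difference between the two zeros as a sum-reduction seed can be checked in isolation. The sketch below assumes default IEEE 754 round-to-nearest arithmetic without fast-math flags; the helper names are hypothetical and not part of the PR.

```cpp
#include <cmath>

// True when a and b compare equal and also carry the same sign bit,
// which distinguishes +0.0 from -0.0.
bool sameValueAndSign(float a, float b) {
  return a == b && std::signbit(a) == std::signbit(b);
}

// Adds a candidate identity element to x, as a reduction seed would.
float addSeed(float seed, float x) { return seed + x; }

// -0.0 preserves every input, including -0.0 itself.
bool negativeZeroKeepsNegativeZero() {
  return sameValueAndSign(addSeed(-0.0f, -0.0f), -0.0f);
}

// +0.0 does not: (+0.0) + (-0.0) is +0.0, so the sign of x is lost.
bool positiveZeroKeepsNegativeZero() {
  return sameValueAndSign(addSeed(0.0f, -0.0f), -0.0f);
}
```

Because `+0.0 + x` is not equal to `x` for `x = -0.0`, a compiler cannot drop that add without extra assumptions, whereas an add with `-0.0` can be folded away. NaN inputs still propagate in both cases.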
This also seems like an independent change compared to scheduling VTile
Agree this needs a heuristic. The branchless approach is O(N) LDS reads vs O(log N) in the tree, so for small rTidCount (2-4) it's clearly better, but for larger values (8, 16) the extra LDS reads may outweigh the branch elimination benefit.
Suggestion: benchmark both approaches for representative configs with rTidCount = 2, 4, 8, 16 on target architectures to find the empirical crossover point, then add a threshold and keep the old tree path as a fallback.
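To make the read-count trade-off concrete, here is a host-side sketch of both reduction shapes. It assumes a non-empty, power-of-two number of partial values, and counts only the reads issued by one thread; `ReduceResult`, `reduceBranchless`, `reduceTree`, and `preferBranchless` are hypothetical names, and the threshold is a placeholder to be filled in from benchmarks, not a measured value.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct ReduceResult {
  float sum;
  std::size_t readsPerThread; // reads issued by a single thread
};

// Branchless shape: every thread walks all N partials, so N reads each.
ReduceResult reduceBranchless(const std::vector<float> &partials) {
  float acc = -0.0f;
  std::size_t reads = 0;
  for (float v : partials) {
    acc += v;
    ++reads;
  }
  return {acc, reads};
}

// Tree shape: log2(N) steps, each active thread reads one partner value.
// Takes the partials by value because the tree updates them in place.
ReduceResult reduceTree(std::vector<float> partials) {
  assert(!partials.empty() && "sketch assumes a power-of-two, non-empty size");
  std::size_t steps = 0;
  for (std::size_t stride = partials.size() / 2; stride > 0; stride /= 2) {
    for (std::size_t i = 0; i < stride; ++i)
      partials[i] += partials[i + stride];
    ++steps;
  }
  return {partials[0], steps};
}

// Hypothetical heuristic: pick the branchless path below a tuned threshold.
bool preferBranchless(std::size_t rTidCount, std::size_t threshold) {
  return rTidCount <= threshold;
}
```

For eight partials the branchless path issues 8 reads per thread and the tree issues 3, which is the gap the benchmark would need to weigh against the branches and barriers the tree requires.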
Copilot
AI
Mar 31, 2026
`initVal` is used to initialize `accReg` in the branchless reduction path, but it is declared inside the preceding `if (threadViewShape[rIterDim] > 1)` block. As written, this won’t compile (and even conceptually, the branchless reduction should be able to run when `rIterDim <= 1`). Move the `initVal` definition outside the conditional (or recompute it in the branchless block), or remove the `FillOp` entirely since the `i==0` iteration overwrites the accumulator.
| FillOp::create(rewriter, loc, accReg, initVal); |
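The scoping fix described in this comment can be shown with a plain C++ stand-in. This is a sketch only: `reduceSketch` and its parameters are hypothetical, and the vectors replace the per-thread view and LDS partials of the real rewrite pattern.

```cpp
#include <vector>

// Hypothetical stand-in for the pattern body. threadValues plays the role of
// the per-thread iteration dimension; partials plays the role of LDS values.
float reduceSketch(const std::vector<float> &threadValues,
                   const std::vector<float> &partials) {
  // Hoisted above the conditional so that both paths below can use it.
  // Declared inside the if-block, it would be out of scope further down.
  const float initVal = -0.0f;

  float threadAcc = initVal;
  if (threadValues.size() > 1) {
    for (float v : threadValues)
      threadAcc += v;
  } else if (threadValues.size() == 1) {
    threadAcc = threadValues[0];
  }

  // Second use of initVal, outside the conditional above.
  float accReg = initVal;
  for (float p : partials)
    accReg += p;

  return threadAcc + accReg;
}
```

The second path still works when the first conditional is skipped, which matches the point that the branchless reduction should not depend on the iteration dimension being larger than one.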
`BlockwiseLoadTileOp::getEffects` doesn’t account for the new split-phase load types. For `GemmLoadTileType::GlobalReadOnly` the op does not write to `destLDS`, and for `GemmLoadTileType::LDSWriteFromRegs` the op should not read from `source` at all (it should read from `destRegisters` and write to `destLDS`). As written, MemoryEffects will incorrectly report global/LDS accesses, which can mislead scheduling and optimization passes that rely on effects. Please add explicit cases for `GlobalReadOnly`/`LDSWriteFromRegs` (and ensure `LDSReadOnly` remains LDS-read + regs-write only).
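The per-type effects requested here can be written down as a small table. The sketch below does not use MLIR's `MemoryEffects` API: `Access`, `Effect`, `effectsFor`, and `touches` are hypothetical, the enum lists only the three load types named in the comment (the real enum may have more), and the register write for `GlobalReadOnly` and the `destLDS` read for `LDSReadOnly` are assumptions inferred from the comment's wording.

```cpp
#include <string>
#include <vector>

// Only the three split-phase types named in the review (assumption: the
// real GemmLoadTileType has additional members not modeled here).
enum class GemmLoadTileType { GlobalReadOnly, LDSWriteFromRegs, LDSReadOnly };

enum class Access { Read, Write };

struct Effect {
  Access access;
  std::string operand; // operand name as used in the review comment
};

// One explicit case per load type, so no type inherits another's effects.
std::vector<Effect> effectsFor(GemmLoadTileType type) {
  switch (type) {
  case GemmLoadTileType::GlobalReadOnly:
    // Reads global memory; assumed to land in registers, never in LDS.
    return {{Access::Read, "source"}, {Access::Write, "destRegisters"}};
  case GemmLoadTileType::LDSWriteFromRegs:
    // Reads registers and writes LDS; the global source is untouched.
    return {{Access::Read, "destRegisters"}, {Access::Write, "destLDS"}};
  case GemmLoadTileType::LDSReadOnly:
    // LDS read plus register write only.
    return {{Access::Read, "destLDS"}, {Access::Write, "destRegisters"}};
  }
  return {};
}

// True if any effect in the list refers to the given operand.
bool touches(const std::vector<Effect> &effects, const std::string &operand) {
  for (const Effect &e : effects)
    if (e.operand == operand)
      return true;
  return false;
}
```

A test in this style could assert that `GlobalReadOnly` never touches `destLDS` and that `LDSWriteFromRegs` never touches `source`, which are the two misreports the comment calls out.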