Skip to content

Commit 94e5d3b

Browse files
authored
Merge branch 'develop' into 498-attention-scheduling
2 parents 38bc812 + a8ae8ac commit 94e5d3b

56 files changed

Lines changed: 3440 additions & 1120 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ set(ROCMLIR_DRIVER_RANDOM_DATA_SEED "none" CACHE STRING "Enable E2E tests using
5757
set(ROCMLIR_GEN_FLAGS "" CACHE BOOL "Set feature flag for rocmlir-gen")
5858
set(ROCMLIR_DRIVER_TEST_GPU_VALIDATION 1 CACHE BOOL "Enable E2E tests with GPU validation")
5959
set(ROCK_E2E_TEST_ENABLED 0 CACHE BOOL "Enable build rock E2E tests")
60+
option(ROCMLIR_BUILD_TUNING_DRIVER "Build rocmlir-tuning-driver (default ON when BUILD_FAT_LIBROCKCOMPILER)" OFF)
6061
set(ROCMLIR_ENABLE_BENCHMARKS "" CACHE STRING "List of enabled benchmarks")
6162

6263
set(ROCMLIR_BIN_DIR "${CMAKE_CURRENT_BINARY_DIR}/bin" CACHE PATH "")
@@ -81,6 +82,8 @@ if( BUILD_FAT_LIBROCKCOMPILER )
8182
set(LLVM_BUILD_LLVM_DYLIB OFF CACHE BOOL "")
8283
# rocm-runner is not supported with static libraries
8384
set(MLIR_ENABLE_ROCM_RUNNER 0 CACHE BOOL "")
85+
# Enable tuning driver by default for fat-lib builds (it can link against static dialect libs + shared HIP)
86+
set(ROCMLIR_BUILD_TUNING_DRIVER ON CACHE BOOL "" FORCE)
8487
set(MLIR_INCLUDE_INTEGRATION_TESTS OFF CACHE BOOL "")
8588
set(ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED 0 CACHE BOOL "Enable build PR-triggered E2E tests for Rock driver")
8689
set(MHAL_ENABLE_HOST_RUNNER OFF CACHE BOOL "Enable MHAL host runner")

external/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,6 @@ class SchedGroup {
139139
// Count of the number of created SchedGroups, used to initialize SGID.
140140
static unsigned NumSchedGroups;
141141

142-
// Try to add and edge from SU A to SU B.
143-
bool tryAddEdge(SUnit *A, SUnit *B);
144-
145142
// Use SGMask to determine whether we can classify MI as a member of this
146143
// SchedGroup object.
147144
bool canAddMI(const MachineInstr &MI) const;
@@ -153,6 +150,9 @@ class SchedGroup {
153150
ScheduleDAGInstrs *DAG;
154151
const SIInstrInfo *TII;
155152

153+
// Try to add and edge from SU A to SU B.
154+
bool tryAddEdge(SUnit *A, SUnit *B);
155+
156156
// Returns true if SU can be added to this SchedGroup.
157157
bool canAddSU(SUnit &SU) const;
158158

@@ -164,7 +164,7 @@ class SchedGroup {
164164
// Add DAG dependencies and track which edges are added, and the count of
165165
// missed edges
166166
int link(SUnit &SU, bool MakePred,
167-
std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
167+
std::list<std::pair<SUnit *, SUnit *>> &AddedEdges);
168168

169169
// Add DAG dependencies from all SUnits in this SchedGroup and this SU.
170170
// Use the predicate to determine whether SU should be a predecessor (P =
@@ -305,8 +305,7 @@ class PipelineSolver {
305305
// current information. One step in the greedy algorithm. Templated against
306306
// the SchedGroup iterator (either reverse or forward).
307307
template <typename T>
308-
void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
309-
T E);
308+
void greedyFind(std::list<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
310309
// Whether or not the current solution is optimal
311310
bool checkOptimal();
312311
// Populate the ready list, prioiritizing fewest missed edges first
@@ -322,15 +321,15 @@ class PipelineSolver {
322321
// Add the edges from the SU to the other SchedGroups in pipeline, and
323322
// return the number of edges missed.
324323
int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
325-
std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
324+
std::list<std::pair<SUnit *, SUnit *>> &AddedEdges);
326325
/// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
327326
/// returns the cost (in terms of missed pipeline edges), and tracks the edges
328327
/// added in \p AddedEdges
329328
template <typename T>
330329
int linkSUnit(SUnit *SU, int SGID,
331-
std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
330+
std::list<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
332331
/// Remove the edges passed via \p AddedEdges
333-
void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
332+
void removeEdges(const std::list<std::pair<SUnit *, SUnit *>> &AddedEdges);
334333
// Convert the passed in maps to arrays for bidirectional iterators
335334
void convertSyncMapsToArrays();
336335

@@ -454,7 +453,7 @@ void PipelineSolver::makePipeline() {
454453

455454
template <typename T>
456455
int PipelineSolver::linkSUnit(
457-
SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
456+
SUnit *SU, int SGID, std::list<std::pair<SUnit *, SUnit *>> &AddedEdges,
458457
T I, T E) {
459458
bool MakePred = false;
460459
int AddedCost = 0;
@@ -472,7 +471,7 @@ int PipelineSolver::linkSUnit(
472471

473472
int PipelineSolver::addEdges(
474473
SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
475-
std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
474+
std::list<std::pair<SUnit *, SUnit *>> &AddedEdges) {
476475

477476
// For IsBottomUp, the first SchedGroup in SyncPipeline contains the
478477
// instructions that are the ultimate successors in the resultant mutation.
@@ -489,7 +488,7 @@ int PipelineSolver::addEdges(
489488
}
490489

491490
void PipelineSolver::removeEdges(
492-
const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
491+
const std::list<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
493492
// Only remove the edges that we have added when testing
494493
// the fit.
495494
for (auto &PredSuccPair : EdgesToRemove) {
@@ -568,7 +567,7 @@ void PipelineSolver::populateReadyList(
568567
assert(CurrSU.second.size() >= 1);
569568

570569
for (; I != E; ++I) {
571-
std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
570+
std::list<std::pair<SUnit *, SUnit *>> AddedEdges;
572571
int CandSGID = *I;
573572
SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
574573
return SG.getSGID() == CandSGID;
@@ -627,7 +626,7 @@ bool PipelineSolver::solveExact() {
627626

628627
int CandSGID = I->first;
629628
int AddedCost = 0;
630-
std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
629+
std::list<std::pair<SUnit *, SUnit *>> AddedEdges;
631630
auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
632631
SchedGroup *Match;
633632
for (auto &SG : SyncPipeline) {
@@ -694,12 +693,13 @@ bool PipelineSolver::solveExact() {
694693

695694
template <typename T>
696695
void PipelineSolver::greedyFind(
697-
std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
696+
std::list<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
698697
SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
699698
int BestNodeCost = -1;
700699
int TempCost;
701700
SchedGroup *BestGroup = nullptr;
702701
int BestGroupID = -1;
702+
std::list<std::pair<SUnit *, SUnit *>> BestEdges;
703703
auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
704704
LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
705705
<< ") in Pipeline # " << CurrSyncGroupIdx << "\n");
@@ -709,7 +709,6 @@ void PipelineSolver::greedyFind(
709709
// first. If we fail to do this for the greedy algorithm, the solution will
710710
// likely not be good in more complex cases.
711711
for (; I != E; ++I) {
712-
std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
713712
int CandSGID = *I;
714713
SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
715714
return SG.getSGID() == CandSGID;
@@ -727,21 +726,36 @@ void PipelineSolver::greedyFind(
727726
LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
728727
continue;
729728
}
730-
TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
729+
730+
std::list<std::pair<SUnit *, SUnit *>> TempEdges;
731+
TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, TempEdges);
731732
LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
733+
732734
if (TempCost < BestNodeCost || BestNodeCost == -1) {
735+
BestEdges = TempEdges;
733736
BestGroup = Match;
734737
BestNodeCost = TempCost;
735738
BestGroupID = CandSGID;
739+
740+
if (BestNodeCost == 0)
741+
break;
736742
}
737-
removeEdges(AddedEdges);
738-
if (BestNodeCost == 0)
739-
break;
743+
744+
removeEdges(TempEdges);
740745
}
741746

742747
if (BestGroupID != -1) {
743748
BestGroup->add(*CurrSU.first);
744-
addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
749+
if (AddedEdges.empty())
750+
AddedEdges = BestEdges;
751+
else
752+
AddedEdges.splice(std::prev(AddedEdges.cend()), BestEdges);
753+
754+
for (const std::pair<SUnit *, SUnit *> &E : BestEdges) {
755+
if (!BestGroup->tryAddEdge(E.first, E.second))
756+
llvm_unreachable("Edges known to be insertable.");
757+
}
758+
745759
LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask"
746760
<< (int)BestGroup->getMask() << "\n");
747761
BestCost += TempCost;
@@ -753,7 +767,7 @@ void PipelineSolver::greedyFind(
753767

754768
bool PipelineSolver::solveGreedy() {
755769
BestCost = 0;
756-
std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
770+
std::list<std::pair<SUnit *, SUnit *>> AddedEdges;
757771

758772
while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
759773
SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
@@ -2379,11 +2393,7 @@ class IGroupLPDAGMutation : public ScheduleDAGMutation {
23792393
unsigned SchedGroup::NumSchedGroups = 0;
23802394

23812395
bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
2382-
if (A != B && DAG->canAddEdge(B, A)) {
2383-
DAG->addEdge(B, SDep(A, SDep::Artificial));
2384-
return true;
2385-
}
2386-
return false;
2396+
return A != B && DAG->addEdge(B, SDep(A, SDep::Artificial));
23872397
}
23882398

23892399
bool SchedGroup::canAddMI(const MachineInstr &MI) const {
@@ -2448,7 +2458,7 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const {
24482458
}
24492459

24502460
int SchedGroup::link(SUnit &SU, bool MakePred,
2451-
std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2461+
std::list<std::pair<SUnit *, SUnit *>> &AddedEdges) {
24522462
int MissedEdges = 0;
24532463
for (auto *A : Collection) {
24542464
SUnit *B = &SU;

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ def Rock_AttentionOp
220220
Optional<TensorOrMemRefOf<[F32, F16, BF16]>>:$lse, I32Attr:$numHeadsQ,
221221
I32Attr:$numHeadsKV, UnitAttr:$qTransposed, UnitAttr:$kTransposed,
222222
UnitAttr:$vTransposed, UnitAttr:$oTransposed, UnitAttr:$causal,
223-
I32Attr:$splitKV, OptionalAttr<Rock_GemmFeaturesAttr>:$features,
223+
I32Attr:$splitKV, OptionalAttr<I32Attr>:$slidingWindowSize,
224+
OptionalAttr<Rock_GemmFeaturesAttr>:$features,
224225
StoreMethodAttr:$storeMethod, OptionalAttr<TypeAttr>:$softmaxType,
225226
OptionalAttr<RockTuningParamAttrInterface>:$params0,
226227
OptionalAttr<RockTuningParamAttrInterface>:$params1,
@@ -253,6 +254,11 @@ def Rock_AttentionOp
253254
- A tensor of shape [G]: per-group/batch offsets, allowing different prefix
254255
lengths for each sequence in the batch
255256

257+
If slidingWindowSize is set, we implement sliding window attention where
258+
only the last `slidingWindowSize` key positions (relative to currentSeqLen)
259+
are attended to. Positions before `max(0, currentSeqLen - slidingWindowSize)`
260+
are masked with -inf. This requires currentSeqLen to be set.
261+
256262
LSE (log-sum-exp) is an optional output typically used for flash decoding.
257263
For flash decoding, you can pass splitKV > 1, the default value is 1, which means flash decoding is disabled.
258264
Flash decoding multiplies the number of blocks by splitKV. Note that "lse" has to be non-null for splitKV > 1.
@@ -278,6 +284,7 @@ def Rock_AttentionOp
278284
` ` `qk` `=` (`tr` $qTransposed^)? $queries `*` (`tr` $kTransposed^)? $keys `:` type($queries) `,` type($keys) `\n`
279285
(`currentSeqLen` `=` `(` $currentSeqLen^ `:` type($currentSeqLen) `)` `\n`)?
280286
(`prefixOffset` `=` `(` $prefixOffset^ `:` type($prefixOffset) `)` `\n`)?
287+
(`slidingWindowSize` `=` $slidingWindowSize^ `\n`)?
281288
(`causal` `\n` $causal^)?
282289
(`lse` `=` $lse^ `:` type($lse) `\n`)?
283290
(`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)?
@@ -583,7 +590,8 @@ def Rock_GridwiseAttentionAccelOp
583590
Optional<MemRefRankOf<[I32], [1]>>:$prefixOffset,
584591
MemRefRankOf<[F32, F16, BF16], [3]>:$out,
585592
Optional<MemRefRankOf<[F32, F16, BF16], [2]>>:$lse, UnitAttr:$causal,
586-
I32Attr:$splitKV, OptionalAttr<Rock_GemmFeaturesAttr>:$features,
593+
I32Attr:$splitKV, OptionalAttr<I32Attr>:$slidingWindowSize,
594+
OptionalAttr<Rock_GemmFeaturesAttr>:$features,
587595
StoreMethodAttr:$storeMethod, I32Attr:$blockSize, I32Attr:$gridSize,
588596
UnitAttr:$disableQBypassLDS, OptionalAttr<IndexAttr>:$prePadG0M,
589597
OptionalAttr<IndexAttr>:$prePadG0N,

0 commit comments

Comments
 (0)