Skip to content

Commit fbd25aa

Browse files
committed
Backend changes
1 parent 72f9c60 commit fbd25aa

14 files changed

Lines changed: 1196 additions & 49 deletions

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,11 +1141,13 @@ def Rock_GlobalPrefetchOp
11411141
def Rock_GlobalLoadOp
11421142
: Rock_Op<
11431143
"global_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1144-
AllElementTypesMatch<["source", "result"]>]>,
1144+
AllElementTypesMatch<["source", "result"]>,
1145+
AttrSizedOperandSegments]>,
11451146
Arguments<(
11461147
ins Arg<MemRefOf<SupportedMemoryElems>, "source memory">:$source,
11471148
I1:$valid, Variadic<Index>:$sourceCoord, UnitAttr:$needs64BitIdx,
1148-
UnitAttr:$canReadOffEnd)>,
1149+
UnitAttr:$canReadOffEnd, Optional<I64>:$pagePtr,
1150+
OptionalAttr<I64Attr>:$pageSize)>,
11491151
Results<(outs AnyType:$result)> {
11501152
let summary = "Load from global memory, applying bounds checks";
11511153
let description = [{
@@ -1162,9 +1164,17 @@ def Rock_GlobalLoadOp
11621164
If `canReadOffEnd` is present, then the `valid` input is ignored and the
11631165
bounds-checked implementation of this operation is always used. This
11641166
is used to simplify the implementation of certain utility kernels.
1167+
1168+
When `pagePtr` and `pageSize` are provided, this
1169+
operation loads from paged memory. The `source` memref is still used for
1170+
type information (element type, etc.), but the actual load uses `pagePtr`
1171+
as the base address. `sourceCoord[0]` is treated as the offset-in-page in
1172+
elements (a single flat index, not multi-dimensional coordinates).
11651173
}];
11661174
let assemblyFormat = [{
1167-
$source `[` $sourceCoord `]` `if` $valid attr-dict
1175+
$source `[` $sourceCoord `]` `if` $valid
1176+
(`paged` $pagePtr^)?
1177+
attr-dict
11681178
`:` type($source) `->` type($result)
11691179
}];
11701180
let hasVerifier = 1;
@@ -1200,7 +1210,8 @@ def Rock_GlobalLoadToLDSOp
12001210
}];
12011211
let assemblyFormat = [{
12021212
$source `[` $sourceCoord `]`
1203-
`->` $dest `[` $destCoord `]` `if` $valid attr-dict
1213+
`->` $dest `[` $destCoord `]` `if` $valid
1214+
attr-dict
12041215
`:` type($source) `->` type($dest)
12051216
}];
12061217
let hasVerifier = 1;
@@ -1302,14 +1313,21 @@ def Rock_ThreadwiseReadIntoOp
13021313
[](::mlir::Type lhs, ::mlir::TypeRange rhs) -> bool {
13031314
return llvm::all_of(rhs, [lhs](::mlir::Type t) { return t == lhs; });
13041315
}
1305-
}]>]>,
1306-
Arguments<(ins Arg<MemRefOf<NativeMemoryOpTypes>, "source view">:$source,
1316+
}]>]> {
1317+
1318+
let arguments =
1319+
(ins Arg<MemRefOf<NativeMemoryOpTypes>, "source view">:$source,
13071320
Arg<MemRefOf<NativeMemoryOpTypes>, "destination registers">:$dest,
13081321
Variadic<VectorOfNonZeroRankOf<[I1]>>:$dynamicValidities,
13091322
TransformMapArrayAttr:$extraViews, Variadic<Index>:$extraIndices,
13101323
UnitAttr:$forceUnroll, UnitAttr:$useIndexDiffs,
1311-
OptionalAttr<Rock_LDSTransposeConfigAttr>:$ldsTransposeConfig)>,
1312-
Results<(outs Optional<VectorOfNonZeroRankOf<[I1]>>:$validityRecord)> {
1324+
OptionalAttr<Rock_LDSTransposeConfigAttr>:$ldsTransposeConfig,
1325+
Optional<MemRefOf<[I64]>>:$ldsPagePtrs,
1326+
Optional<Index>:$firstPageIndex, OptionalAttr<IndexAttr>:$pageSize,
1327+
OptionalAttr<IndexAttr>:$numPagesPerBatch);
1328+
1329+
let results = (outs Optional<VectorOfNonZeroRankOf<[I1]>>:$validityRecord);
1330+
13131331
let summary = "Read values from transformed source into destination";
13141332

13151333
let description = [{
@@ -1341,8 +1359,8 @@ def Rock_ThreadwiseReadIntoOp
13411359
L is the length of `%dest`, V is the maximum vectorization computed
13421360
for MAPS.
13431361

1344-
The input to extraViews: (the transforms on %source) must have the form
1345-
(extraIdx0, ... , extraIdxN, iteration_number)
1362+
The input to extraViews (the transforms on %source) must have the form
1363+
(extraIdx0, ... , extraIdxN, iteration_number).
13461364

13471365
Primarily, extraIndices would be used to pass in tid and bid. This would need
13481366
to have a matching view in [extraViews]source. The extraIndices could be used
@@ -1368,6 +1386,14 @@ def Rock_ThreadwiseReadIntoOp
13681386

13691387
This operation is also used during fusion to represent loads from additional
13701388
arguments like bias tensors.
1389+
1390+
When `ldsPagePtrs`, `firstPageIndex`, and `pageSize` are provided, this
1391+
operation reads from paged memory where data is scattered across
1392+
non-contiguous pages accessed via a page table. The page pointers are
1393+
pre-loaded into LDS by the calling code to avoid redundant global memory
1394+
accesses across threads. During lowering, the transform maps compute the
1395+
flat position for each element, determining the page index and offset
1396+
within page.
13711397
}];
13721398

13731399
let builders = [
@@ -1377,28 +1403,38 @@ def Rock_ThreadwiseReadIntoOp
13771403
"ValueRange":$extraIndices, "bool":$forceUnroll,
13781404
"bool":$useIndexDiffs),
13791405
[{
1380-
build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr);
1406+
build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr);
13811407
}]>,
13821408
// Builder with explicit ldsTransposeConfig support
13831409
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayAttr":$extraViews,
13841410
"ValueRange":$extraIndices, "bool":$forceUnroll,
13851411
"bool":$useIndexDiffs,
13861412
"LDSTransposeConfigAttr":$ldsTransposeConfig),
13871413
[{
1388-
build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, ldsTransposeConfig);
1414+
build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, ldsTransposeConfig, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr);
13891415
}]>,
13901416
// Builder with validityRecord but without ldsTransposeConfig
13911417
OpBuilder<(ins "TypeRange":$validityRecord, "Value":$source,
13921418
"Value":$dest, "ValueRange":$dynamicValidities,
13931419
"ArrayAttr":$extraViews, "ValueRange":$extraIndices,
13941420
"bool":$forceUnroll, "bool":$useIndexDiffs),
13951421
[{
1396-
build($_builder, $_state, validityRecord, source, dest, dynamicValidities, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr);
1422+
build($_builder, $_state, validityRecord, source, dest, dynamicValidities, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr);
1423+
}]>,
1424+
// Builder with validityRecord and ldsTransposeConfig
1425+
OpBuilder<(ins "TypeRange":$validityRecord, "Value":$source,
1426+
"Value":$dest, "ValueRange":$dynamicValidities,
1427+
"ArrayAttr":$extraViews, "ValueRange":$extraIndices,
1428+
"bool":$forceUnroll, "bool":$useIndexDiffs,
1429+
"LDSTransposeConfigAttr":$ldsTransposeConfig),
1430+
[{
1431+
build($_builder, $_state, validityRecord, source, dest, dynamicValidities, extraViews, extraIndices, forceUnroll, useIndexDiffs, ldsTransposeConfig, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr);
13971432
}]>];
13981433

13991434
let assemblyFormat = [{
14001435
attr-dict $extraViews `(` $source `)` (`[` $extraIndices^ `]`)? `->` $dest
14011436
(`if` ` ` `[` $dynamicValidities^ `]`)?
1437+
(`paged` $ldsPagePtrs^ `:` type($ldsPagePtrs) `[` $firstPageIndex `]`)?
14021438
`:` type($source) `->` type($dest) (`,` type($validityRecord)^)?
14031439
}];
14041440
let hasVerifier = 1;
@@ -1634,7 +1670,9 @@ def Rock_BlockwiseLoadTileOp
16341670
Rock_BlockwiseMatrixParamsAttr:$matrixParamsB, UnitAttr:$isA,
16351671
Variadic<Index>:$sourceIndices,
16361672
OptionalAttr<Rock_GemmFeaturesAttr>:$features, I32Attr:$blockSize,
1637-
RockAccelTuningParamAttrInterface:$params)> {
1673+
RockAccelTuningParamAttrInterface:$params,
1674+
Optional<MemRefOf<[I64]>>:$pageTable,
1675+
OptionalAttr<IndexAttr>:$pageSize)> {
16381676
let summary =
16391677
"Blockwise load tile from global memory to LDS and/or registers";
16401678
let description = [{
@@ -1651,10 +1689,20 @@ def Rock_BlockwiseLoadTileOp
16511689
`isA` determines if we are loading an A matrix or B matrix. `G`, `M` and `N` are the GEMM sizes.
16521690
`elementTypeA` and `elementTypeB` are used to construct AccelEmitter. They are data types for the Matrix A & B of the GEMMs.
16531691
`elementLoadType` is the element type of the global buffer this BlockwiseLoadTileOp is trying to load.
1654-
`elementType` is the elementType in the registers. Note that it may differ from `elementLoadType` because of input fusions.
1692+
`elementType` is the elementType in the registers. Note that it may differ from `elementLoadType` because of input fusions.
1693+
1694+
For paged attention (when K/V tensors are stored in non-contiguous pages),
1695+
the following optional parameters enable indirect memory access:
1696+
- `pageTable`: The i64 page table memref containing base pointers for each page.
1697+
Shape is [batch, numPages, 1] where each entry is a pointer to a page.
1698+
- `pageSize`: Number of elements per page, used to compute page boundaries.
1699+
When these are provided, the lowering will compute page indices from logical
1700+
positions and load page pointers to resolve physical addresses.
16551701
}];
16561702
let assemblyFormat = [{
1657-
$source (`[` $sourceIndices^ `]`)? (`LDS` `->` $destLDS^)? (`->` $destRegisters^)? attr-dict
1703+
$source (`[` $sourceIndices^ `]`)? (`LDS` `->` $destLDS^)? (`->` $destRegisters^)?
1704+
(`paged` $pageTable^ `:` type($pageTable))?
1705+
attr-dict
16581706
`:` type($source) (`LDS` `->` type($destLDS)^)? (`->` type($destRegisters)^)?
16591707
}];
16601708

mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,11 @@ SmallVector<StringRef> getStringRefsFor(ArrayRef<SmallString<8>> strings);
304304
// type of the block argument, otherwise it returns failure.
305305
FailureOr<Type> getInputFusionElementType(Value transformed);
306306

307+
// Compute the flat position in the virtual paged space by evaluating
308+
// the source transforms on the given coordinates.
309+
FailureOr<Value> computeFlatPosition(OpBuilder &b, Location loc, Value source,
310+
ValueRange indices);
311+
307312
} // end namespace rock
308313
} // end namespace mlir
309314
#endif

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2009,8 +2009,31 @@ static LogicalResult verifyGlobalLoadAndPrefetch(LoadOrPrefetch op) {
20092009
MemRefType sourceType = op.getSource().getType();
20102010
size_t nDims = sourceType.getRank();
20112011

2012-
if (op.getSourceCoord().size() != nDims)
2012+
// Check if this op has paging attributes
2013+
bool isPaged = false;
2014+
if constexpr (std::is_same_v<LoadOrPrefetch, GlobalLoadOp>) {
2015+
isPaged = op.getPagePtr() != nullptr;
2016+
2017+
// Verify paging attributes consistency
2018+
bool hasPagePtr = op.getPagePtr() != nullptr;
2019+
bool hasPageSize = op.getPageSize().has_value();
2020+
if (hasPagePtr != hasPageSize) {
2021+
return op.emitOpError(
2022+
"pagePtr and pageSize must both be set or both be unset");
2023+
}
2024+
2025+
if (hasPageSize && op.getPageSize().value() <= 0) {
2026+
return op.emitOpError("pageSize must be positive");
2027+
}
2028+
}
2029+
2030+
// For paged loads, we expect exactly 1 coordinate (flat offset-in-page)
2031+
if (isPaged && op.getSourceCoord().size() != 1) {
2032+
return op.emitOpError("Expected 1 coordinate for paged load");
2033+
} else if (op.getSourceCoord().size() != nDims) {
20132034
return op.emitOpError("Expected " + Twine(nDims) + " coordinates");
2035+
}
2036+
20142037
Attribute memSpaceAttr = sourceType.getMemorySpace();
20152038
auto gpuMemSpaceAttr = dyn_cast_or_null<gpu::AddressSpaceAttr>(memSpaceAttr);
20162039
if (memSpaceAttr && (!gpuMemSpaceAttr ||
@@ -2328,6 +2351,11 @@ ThreadwiseReadIntoOp::cloneWithExtraIndices(OpBuilder &builder,
23282351
void ThreadwiseReadIntoOp::getEffects(
23292352
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
23302353
getCommonEffects(*this, effects);
2354+
// If this is a paged load, we also read from the LDS page pointer buffer
2355+
if (getLdsPagePtrs()) {
2356+
effects.emplace_back(MemoryEffects::Read::get(),
2357+
&getLdsPagePtrsMutable()[0]);
2358+
}
23312359
}
23322360

23332361
LogicalResult ThreadwiseReadIntoOp::verify() {
@@ -2393,6 +2421,37 @@ LogicalResult ThreadwiseReadIntoOp::verify() {
23932421
"in register-to-register reads produced by input fusion");
23942422
}
23952423
}
2424+
2425+
// Verify paged attention attributes consistency
2426+
bool hasLdsPagePtrs = getLdsPagePtrs() != nullptr;
2427+
bool hasFirstPageIndex = getFirstPageIndex() != nullptr;
2428+
bool hasPageSize = getPageSize().has_value();
2429+
2430+
if (hasLdsPagePtrs != hasFirstPageIndex || hasLdsPagePtrs != hasPageSize) {
2431+
return emitOpError(
2432+
"ldsPagePtrs, firstPageIndex, and pageSize must all be "
2433+
"set together for paged attention, or none should be set");
2434+
}
2435+
2436+
if (hasPageSize && getPageSize().value().getSExtValue() <= 0) {
2437+
return emitOpError("pageSize must be positive");
2438+
}
2439+
2440+
if (getNumPagesPerBatch().has_value() &&
2441+
getNumPagesPerBatch().value().getSExtValue() <= 0) {
2442+
return emitOpError("numPagesPerBatch must be positive");
2443+
}
2444+
2445+
if (hasLdsPagePtrs) {
2446+
MemRefType ldsPagePtrsType = cast<MemRefType>(getLdsPagePtrs().getType());
2447+
if (ldsPagePtrsType.getRank() != 1) {
2448+
return emitOpError("ldsPagePtrs must be a 1D memref");
2449+
}
2450+
if (!ldsPagePtrsType.getElementType().isInteger(64)) {
2451+
return emitOpError("ldsPagePtrs must have i64 element type");
2452+
}
2453+
}
2454+
23962455
return success();
23972456
}
23982457

@@ -2568,6 +2627,34 @@ LogicalResult BlockwiseLoadTileOp::verify() {
25682627
return emitOpError("destRegisters must be set unless loadType is "
25692628
"Default/DirectToLDSDefault");
25702629

2630+
// Verify paged attention attributes consistency
2631+
bool hasPageTable = getPageTable() != nullptr;
2632+
bool hasPageSize = getPageSize().has_value();
2633+
2634+
if (hasPageTable != hasPageSize) {
2635+
return emitOpError(
2636+
"pageTable and pageSize must both be set or both be unset");
2637+
}
2638+
2639+
if (hasPageSize && getPageSize().value().getSExtValue() <= 0) {
2640+
return emitOpError("pageSize must be positive");
2641+
}
2642+
2643+
if (hasPageTable) {
2644+
MemRefType pageTableType = cast<MemRefType>(getPageTable().getType());
2645+
if (pageTableType.getRank() != 3) {
2646+
return emitOpError(
2647+
"pageTable must be a 3D memref with shape [batch, numPages, 1]");
2648+
}
2649+
if (pageTableType.getShape()[2] != 1) {
2650+
return emitOpError(
2651+
"pageTable last dimension must be 1 (shape [batch, numPages, 1])");
2652+
}
2653+
if (!pageTableType.getElementType().isInteger(64)) {
2654+
return emitOpError("pageTable must have i64 element type");
2655+
}
2656+
}
2657+
25712658
return success();
25722659
}
25732660

0 commit comments

Comments
 (0)