@@ -1141,11 +1141,13 @@ def Rock_GlobalPrefetchOp
11411141def Rock_GlobalLoadOp
11421142 : Rock_Op<
11431143 "global_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1144- AllElementTypesMatch<["source", "result"]>]>,
1144+ AllElementTypesMatch<["source", "result"]>,
1145+ AttrSizedOperandSegments]>,
11451146 Arguments<(
11461147 ins Arg<MemRefOf<SupportedMemoryElems>, "source memory">:$source,
11471148 I1:$valid, Variadic<Index>:$sourceCoord, UnitAttr:$needs64BitIdx,
1148- UnitAttr:$canReadOffEnd)>,
1149+ UnitAttr:$canReadOffEnd, Optional<I64>:$pagePtr,
1150+ OptionalAttr<I64Attr>:$pageSize)>,
11491151 Results<(outs AnyType:$result)> {
11501152 let summary = "Load from global memory, applying bounds checks";
11511153 let description = [{
@@ -1162,9 +1164,17 @@ def Rock_GlobalLoadOp
11621164 If `canReadOffEnd` is present, then the `valid` input is ignored and the
11631165 bounds-checked implementation of this operation is always used. This
11641166 is used to simplify the implementation of certain utility kernels.
1167+
1168+ When `pagePtr` and `pageSize` are provided, this
1169+ operation loads from paged memory. The `source` memref is still used for
1170+ type information (element type, etc.), but the actual load uses `pagePtr`
1171+ as the base address. `sourceCoord[0]` is treated as the offset-in-page in
1172+ elements (a single flat index, not multi-dimensional coordinates).
11651173 }];
11661174 let assemblyFormat = [{
1167- $source `[` $sourceCoord `]` `if` $valid attr-dict
1175+ $source `[` $sourceCoord `]` `if` $valid
1176+ (`paged` $pagePtr^)?
1177+ attr-dict
11681178 `:` type($source) `->` type($result)
11691179 }];
11701180 let hasVerifier = 1;
@@ -1200,7 +1210,8 @@ def Rock_GlobalLoadToLDSOp
12001210 }];
12011211 let assemblyFormat = [{
12021212 $source `[` $sourceCoord `]`
1203- `->` $dest `[` $destCoord `]` `if` $valid attr-dict
1213+ `->` $dest `[` $destCoord `]` `if` $valid
1214+ attr-dict
12041215 `:` type($source) `->` type($dest)
12051216 }];
12061217 let hasVerifier = 1;
@@ -1302,14 +1313,21 @@ def Rock_ThreadwiseReadIntoOp
13021313 [](::mlir::Type lhs, ::mlir::TypeRange rhs) -> bool {
13031314 return llvm::all_of(rhs, [lhs](::mlir::Type t) { return t == lhs; });
13041315 }
1305- }]>]>,
1306- Arguments<(ins Arg<MemRefOf<NativeMemoryOpTypes>, "source view">:$source,
1316+ }]>]> {
1317+
1318+ let arguments =
1319+ (ins Arg<MemRefOf<NativeMemoryOpTypes>, "source view">:$source,
13071320 Arg<MemRefOf<NativeMemoryOpTypes>, "destination registers">:$dest,
13081321 Variadic<VectorOfNonZeroRankOf<[I1]>>:$dynamicValidities,
13091322 TransformMapArrayAttr:$extraViews, Variadic<Index>:$extraIndices,
13101323 UnitAttr:$forceUnroll, UnitAttr:$useIndexDiffs,
1311- OptionalAttr<Rock_LDSTransposeConfigAttr>:$ldsTransposeConfig)>,
1312- Results<(outs Optional<VectorOfNonZeroRankOf<[I1]>>:$validityRecord)> {
1324+ OptionalAttr<Rock_LDSTransposeConfigAttr>:$ldsTransposeConfig,
1325+ Optional<MemRefOf<[I64]>>:$ldsPagePtrs,
1326+ Optional<Index>:$firstPageIndex, OptionalAttr<IndexAttr>:$pageSize,
1327+ OptionalAttr<IndexAttr>:$numPagesPerBatch);
1328+
1329+ let results = (outs Optional<VectorOfNonZeroRankOf<[I1]>>:$validityRecord);
1330+
13131331 let summary = "Read values from transformed source into destination";
13141332
13151333 let description = [{
@@ -1341,8 +1359,8 @@ def Rock_ThreadwiseReadIntoOp
13411359 L is the length of `%dest`, V is the maximum vectorization computed
13421360 for MAPS.
13431361
1344- The input to extraViews: (the transforms on %source) must have the form
1345- (extraIdx0, ... , extraIdxN, iteration_number)
1362+ The input to extraViews (the transforms on %source) must have the form
1363+ (extraIdx0, ... , extraIdxN, iteration_number).
13461364
13471365 Primarily, extraIndices would be used to pass in tid and bid. This would need
13481366 to have a matching view in [extraViews]source. The extraIndices could be used
@@ -1368,6 +1386,14 @@ def Rock_ThreadwiseReadIntoOp
13681386
13691387 This operation is also used during fusion to represent loads from additional
13701388 arguments like bias tensors.
1389+
1390+ When `ldsPagePtrs`, `firstPageIndex`, and `pageSize` are provided, this
1391+ operation reads from paged memory where data is scattered across
1392+ non-contiguous pages accessed via a page table. The page pointers are
1393+ pre-loaded into LDS by the calling code to avoid redundant global memory
1394+ accesses across threads. During lowering, the transform maps compute the
1395+ flat position of each element, which determines the page index and the
1396+ offset within that page.
13711397 }];
13721398
13731399 let builders = [
@@ -1377,28 +1403,38 @@ def Rock_ThreadwiseReadIntoOp
13771403 "ValueRange":$extraIndices, "bool":$forceUnroll,
13781404 "bool":$useIndexDiffs),
13791405 [{
1380- build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr);
1406+ build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr );
13811407 }]>,
13821408 // Builder with explicit ldsTransposeConfig support
13831409 OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayAttr":$extraViews,
13841410 "ValueRange":$extraIndices, "bool":$forceUnroll,
13851411 "bool":$useIndexDiffs,
13861412 "LDSTransposeConfigAttr":$ldsTransposeConfig),
13871413 [{
1388- build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, ldsTransposeConfig);
1414+ build($_builder, $_state, TypeRange{}, source, dest, ValueRange{}, extraViews, extraIndices, forceUnroll, useIndexDiffs, ldsTransposeConfig, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr );
13891415 }]>,
13901416 // Builder with validityRecord but without ldsTransposeConfig
13911417 OpBuilder<(ins "TypeRange":$validityRecord, "Value":$source,
13921418 "Value":$dest, "ValueRange":$dynamicValidities,
13931419 "ArrayAttr":$extraViews, "ValueRange":$extraIndices,
13941420 "bool":$forceUnroll, "bool":$useIndexDiffs),
13951421 [{
1396- build($_builder, $_state, validityRecord, source, dest, dynamicValidities, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr);
1422+ build($_builder, $_state, validityRecord, source, dest, dynamicValidities, extraViews, extraIndices, forceUnroll, useIndexDiffs, /*ldsTransposeConfig=*/nullptr, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr);
1423+ }]>,
1424+ // Builder with validityRecord and ldsTransposeConfig
1425+ OpBuilder<(ins "TypeRange":$validityRecord, "Value":$source,
1426+ "Value":$dest, "ValueRange":$dynamicValidities,
1427+ "ArrayAttr":$extraViews, "ValueRange":$extraIndices,
1428+ "bool":$forceUnroll, "bool":$useIndexDiffs,
1429+ "LDSTransposeConfigAttr":$ldsTransposeConfig),
1430+ [{
1431+ build($_builder, $_state, validityRecord, source, dest, dynamicValidities, extraViews, extraIndices, forceUnroll, useIndexDiffs, ldsTransposeConfig, /*ldsPagePtrs=*/Value{}, /*firstPageIndex=*/Value{}, /*pageSize=*/nullptr, /*numPagesPerBatch=*/nullptr);
13971432 }]>];
13981433
13991434 let assemblyFormat = [{
14001435 attr-dict $extraViews `(` $source `)` (`[` $extraIndices^ `]`)? `->` $dest
14011436 (`if` ` ` `[` $dynamicValidities^ `]`)?
1437+ (`paged` $ldsPagePtrs^ `:` type($ldsPagePtrs) `[` $firstPageIndex `]`)?
14021438 `:` type($source) `->` type($dest) (`,` type($validityRecord)^)?
14031439 }];
14041440 let hasVerifier = 1;
@@ -1634,7 +1670,9 @@ def Rock_BlockwiseLoadTileOp
16341670 Rock_BlockwiseMatrixParamsAttr:$matrixParamsB, UnitAttr:$isA,
16351671 Variadic<Index>:$sourceIndices,
16361672 OptionalAttr<Rock_GemmFeaturesAttr>:$features, I32Attr:$blockSize,
1637- RockAccelTuningParamAttrInterface:$params)> {
1673+ RockAccelTuningParamAttrInterface:$params,
1674+ Optional<MemRefOf<[I64]>>:$pageTable,
1675+ OptionalAttr<IndexAttr>:$pageSize)> {
16381676 let summary =
16391677 "Blockwise load tile from global memory to LDS and/or registers";
16401678 let description = [{
@@ -1651,10 +1689,20 @@ def Rock_BlockwiseLoadTileOp
16511689 `isA` determines if we are loading an A matrix or B matrix. `G`, `M` and `N` are the GEMM sizes.
16521690 `elementTypeA` and `elementTypeB` are used to construct AccelEmitter. They are data types for the Matrix A & B of the GEMMs.
16531691 `elementLoadType` is the element type of the global buffer this BlockwiseLoadTileOp is trying to load.
1654- `elementType` is the elementType in the registers. Note that it may differ from `elementLoadType` because of input fusions.
1692+ `elementType` is the elementType in the registers. Note that it may differ from `elementLoadType` because of input fusions.
1693+
1694+ For paged attention (when K/V tensors are stored in non-contiguous pages),
1695+ the following optional parameters enable indirect memory access:
1696+ - `pageTable`: The i64 page table memref containing base pointers for each page.
1697+ Shape is [batch, numPages, 1] where each entry is a pointer to a page.
1698+ - `pageSize`: Number of elements per page, used to compute page boundaries.
1699+ When these are provided, the lowering will compute page indices from logical
1700+ positions and load page pointers to resolve physical addresses.
16551701 }];
16561702 let assemblyFormat = [{
1657- $source (`[` $sourceIndices^ `]`)? (`LDS` `->` $destLDS^)? (`->` $destRegisters^)? attr-dict
1703+ $source (`[` $sourceIndices^ `]`)? (`LDS` `->` $destLDS^)? (`->` $destRegisters^)?
1704+ (`paged` $pageTable^ `:` type($pageTable))?
1705+ attr-dict
16581706 `:` type($source) (`LDS` `->` type($destLDS)^)? (`->` type($destRegisters)^)?
16591707 }];
16601708
0 commit comments