Skip to content

Commit 034180c

Browse files
committed
Attend to copilot review comments and fix some LIT tests
1 parent fbd25aa commit 034180c

6 files changed

Lines changed: 62 additions & 21 deletions

File tree

mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,12 @@ class LoweringBlockwiseLoadTileOp final
176176
Value numPagesPerBatchVal =
177177
b.createOrFold<arith::ConstantIndexOp>(loc, numPagesPerBatch);
178178

179+
// Get number of batches from page table shape for bounds checking
180+
auto pageTableType = cast<MemRefType>(pageTable.getType());
181+
int64_t numBatches = pageTableType.getShape()[0];
182+
Value numBatchesVal =
183+
b.createOrFold<arith::ConstantIndexOp>(loc, numBatches);
184+
179185
// Only threads with tid < numPagesForTile participate in loading.
180186
// Each such thread either loads from page table or stores 0 to its LDS
181187
// slot.
@@ -200,10 +206,10 @@ class LoweringBlockwiseLoadTileOp final
200206
arith::RemUIOp::create(outerThenBuilder, outerThenLoc,
201207
globalPageIdx, numPagesPerBatchVal);
202208

203-
// Check that local page index is within bounds
209+
// Check that batch index is within bounds.
204210
Value withinTableBound = arith::CmpIOp::create(
205211
outerThenBuilder, outerThenLoc, arith::CmpIPredicate::ult,
206-
localPageIdx, numPagesPerBatchVal);
212+
batchIdx, numBatchesVal);
207213

208214
// Select the pointer value: load from page table if in bounds, else 0
209215
scf::IfOp ptrIfOp = scf::IfOp::create(

mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -885,7 +885,13 @@ LogicalResult ThreadwiseReadIntoRewritePattern::matchAndRewrite(
885885
Value ldsPageIdx =
886886
arith::SubIOp::create(b, loc, globalPageIdx, firstPageIdx);
887887

888-
// Clamp to [0, numPagesForTile-1] to prevent LDS out-of-bounds
888+
// Clamp to [0, numPagesForTile-1] to prevent LDS out-of-bounds.
889+
// We use signed max/min operations intentionally: if globalPageIdx <
890+
// firstPageIdx, the subtraction underflows and produces a bit pattern
891+
// that represents a negative value in two's complement. Using signed
892+
// comparison correctly detects this underflow and clamps to 0. Unsigned
893+
// comparison would treat the underflowed value as a large positive
894+
// number, failing to clamp it.
889895
MemRefType ldsType = cast<MemRefType>(ldsPagePtrs.getType());
890896
int64_t numPagesForTile = ldsType.getShape()[0];
891897
Value maxValidIdx =

mlir/lib/Dialect/Rock/Transforms/TransformToMemref.cpp

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -207,7 +207,15 @@ struct TransformRewritePattern : public OpRewritePattern<TransformOp> {
207207
}
208208
}
209209
}
210-
// Fall back to the last non-empty group
210+
211+
// Fall back to the last non-empty group. This is semantically correct
212+
// because:
213+
// 1. AddDim always creates dimensions of size 1
214+
// 2. Size-1 dimensions can be grouped with any source dimension without
215+
// changing reshape semantics (product of dimension sizes is
216+
// preserved)
217+
// 3. The subsequent sort ensures contiguity, which is required by
218+
// expand_shape
211219
if (!found) {
212220
for (int srcDim = merges.size() - 1; srcDim >= 0; srcDim--) {
213221
if (!merges[srcDim].empty()) {

mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1079,4 +1079,5 @@ module {
10791079
memref.copy %alloc, %arg5 : memref<1344000xf16> to memref<1344000xf16>
10801080
return
10811081
}
1082-
}
1082+
}
1083+

mlir/test/Dialect/Rock/lowering_global_load_store.mlir

Lines changed: 12 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -553,58 +553,54 @@ func.func @load_4bit_vector_boundary_case(%mem: memref<4294967295xi4>) -> vector
553553
}
554554

555555
// CHECK-LABEL: func.func @load_paged_scalar
556-
// CHECK-SAME: (%[[mem:.*]]: memref<1x64x8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
557-
func.func @load_paged_scalar(%mem: memref<1x64x8192xf16>, %pagePtr: i64, %offset: index) -> f16 attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
556+
// CHECK-SAME: (%[[mem:.*]]: memref<8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
557+
func.func @load_paged_scalar(%mem: memref<8192xf16>, %pagePtr: i64, %offset: index) -> f16 attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
558558
%true = arith.constant true
559-
// Paged load converts page ptr to llvm.ptr, creates buffer resource, and loads
560559
// CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(16384 : i64) : i64
561560
// CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
562561
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
563562
// CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
564563
%ret = rock.global_load %mem[%offset] if %true paged %pagePtr {pageSize = 8192 : i64}
565-
: memref<1x64x8192xf16> -> f16
564+
: memref<8192xf16> -> f16
566565
return %ret : f16
567566
}
568567

569568
// CHECK-LABEL: func.func @load_paged_vector
570-
// CHECK-SAME: (%[[mem:.*]]: memref<1x64x8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
571-
func.func @load_paged_vector(%mem: memref<1x64x8192xf16>, %pagePtr: i64, %offset: index) -> vector<2xf16> attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
569+
// CHECK-SAME: (%[[mem:.*]]: memref<8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
570+
func.func @load_paged_vector(%mem: memref<8192xf16>, %pagePtr: i64, %offset: index) -> vector<2xf16> attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
572571
%true = arith.constant true
573-
// Paged vector load converts page ptr to llvm.ptr, creates buffer resource, and loads
574572
// CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(16384 : i64) : i64
575573
// CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
576574
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
577575
// CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
578576
%ret = rock.global_load %mem[%offset] if %true paged %pagePtr {pageSize = 8192 : i64}
579-
: memref<1x64x8192xf16> -> vector<2xf16>
577+
: memref<8192xf16> -> vector<2xf16>
580578
return %ret : vector<2xf16>
581579
}
582580

583581
// CHECK-LABEL: func.func @load_paged_vector_maybe_oob
584-
// CHECK-SAME: (%[[mem:.*]]: memref<1x64x8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index, %[[valid:.*]]: i1)
585-
func.func @load_paged_vector_maybe_oob(%mem: memref<1x64x8192xf16>, %pagePtr: i64, %offset: index, %valid: i1) -> vector<2xf16> attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
586-
// Paged load with validity check - scf.if guards the buffer load
582+
// CHECK-SAME: (%[[mem:.*]]: memref<8192xf16>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index, %[[valid:.*]]: i1)
583+
func.func @load_paged_vector_maybe_oob(%mem: memref<8192xf16>, %pagePtr: i64, %offset: index, %valid: i1) -> vector<2xf16> attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
587584
// CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(16384 : i64) : i64
588585
// CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
589586
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
590587
// CHECK: scf.if %[[valid]]
591588
// CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
592589
%ret = rock.global_load %mem[%offset] if %valid paged %pagePtr {pageSize = 8192 : i64}
593-
: memref<1x64x8192xf16> -> vector<2xf16>
590+
: memref<8192xf16> -> vector<2xf16>
594591
return %ret : vector<2xf16>
595592
}
596593

597594
// CHECK-LABEL: func.func @load_paged_vector_large_page
598-
// CHECK-SAME: (%[[mem:.*]]: memref<1x64x16384xf32>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
599-
func.func @load_paged_vector_large_page(%mem: memref<1x64x16384xf32>, %pagePtr: i64, %offset: index) -> vector<4xf32> attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
595+
// CHECK-SAME: (%[[mem:.*]]: memref<16384xf32>, %[[pagePtr:.*]]: i64, %[[offset:.*]]: index)
596+
func.func @load_paged_vector_large_page(%mem: memref<16384xf32>, %pagePtr: i64, %offset: index) -> vector<4xf32> attributes {arch = "amdgcn-amd-amdhsa:gfx942"} {
600597
%true = arith.constant true
601-
// Larger page size (16384 elements * 4 bytes = 65536 bytes)
602598
// CHECK-DAG: %[[pageSizeBytes:.*]] = llvm.mlir.constant(65536 : i64) : i64
603599
// CHECK: %[[ptr:.*]] = llvm.inttoptr %[[pagePtr]] : i64 to !llvm.ptr<1>
604600
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %{{.*}}, %[[pageSizeBytes]], %{{.*}} : <1> to <8>
605601
// CHECK: rocdl.raw.ptr.buffer.load %[[rsrc]]
606602
%ret = rock.global_load %mem[%offset] if %true paged %pagePtr {pageSize = 16384 : i64}
607-
: memref<1x64x16384xf32> -> vector<4xf32>
603+
: memref<16384xf32> -> vector<4xf32>
608604
return %ret : vector<4xf32>
609605
}
610606
}

mlir/test/Dialect/Rock/toblockwise_attention_accel_lowering.mlir

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -208,3 +208,27 @@ func.func @gridwise_attn_schedulev2(%arg0: memref<1x384x64xf32>, %arg1: memref<1
208208
} : memref<1x64x384xf32>, memref<1x64x384xf32>, memref<1x384x64xf32>, memref<1x384x64xf32>
209209
return
210210
}
211+
212+
// -----
213+
214+
// CHECK-LABEL: func.func @paged_attention_disables_direct_to_lds
215+
// CHECK-NOT: DirectToLDSDefault
216+
// CHECK: rock.blockwise_load_tile
217+
// CHECK-SAME: loadType = #rock<GemmLoadTileType Default>
218+
func.func @paged_attention_disables_direct_to_lds(%arg0: memref<1x64x384xf16>, %arg1: memref<1x64x384xf16>, %arg2: memref<1x384x64xf16>, %arg3: memref<1x384x64xf16>, %pageTable: memref<1x64x1xi64>) attributes {block_size = 64 : i32, features = #rock<GemmFeatures mfma|dot|atomic_add|direct_to_lds_128b>, grid_size = 24 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx950:sramecc+:xnack-"} {
219+
%0 = rock.transform %arg0 by <affine_map<(d0, d1, d2) -> (d0, d2, d1)> by [<PassThrough ["gemmG"] at [0] -> ["gemmG"] at [0]>, <PassThrough ["gemm0K", "gemm0M"] at [1, 2] -> ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [1, 64, 384] -> [1, 384, 64]> : memref<1x64x384xf16> to memref<1x64x384xf16>
220+
%keyAddrs = rock.deref %pageTable : memref<1x64x1xi64> -> memref<1x64x8192xf16>
221+
%valueAddrs = rock.deref %pageTable : memref<1x64x1xi64> -> memref<1x64x8192xf16>
222+
rock.gridwise_attention_accel(%0, %arg1, %arg2, %keyAddrs, %valueAddrs, %arg3) preSoftmaxOps = {} {
223+
blockSize = 64 : i32,
224+
gridSize = 24 : i32,
225+
params0 = #rock.accel_gemm_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 3, outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true>,
226+
params1 = #rock.accel_gemm_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 3, outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true>,
227+
firstGemmIndices = array<i64: 0>,
228+
splitKV = 1 : i32,
229+
storeMethod = #rock<StoreMethod set>,
230+
operand_segment_sizes = array<i32: 1, 1, 1, 0, 0, 0, 1, 1, 1, 0>
231+
} : memref<1x64x384xf16>, memref<1x64x384xf16>, memref<1x384x64xf16>, memref<1x64x8192xf16>, memref<1x64x8192xf16>, memref<1x384x64xf16>
232+
return
233+
}
234+

0 commit comments

Comments
 (0)