Skip to content

Commit 72f9c60

Browse files
committed
Attend to copilot comments and fix LIT tests
1 parent 07583ae commit 72f9c60

14 files changed

Lines changed: 47 additions & 37 deletions

mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def MIGraphX_DerefOp : MIGraphX_Op<"deref">,
233233

234234
Example:
235235
```mlir
236-
%result = migraphx.deref %addrs {target_type = 1 : i64}
236+
%result = migraphx.deref %addrs
237237
: <1x64x8192xui64, 524288x8192x1> to <1x64x8192xf16, 524288x8192x1>
238238
```
239239
}];

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3428,7 +3428,19 @@ static FailureOr<Value> matchDerefInputPattern(Value derefInput) {
34283428
Value lhsSource = getPreBroadcastSource(lhs);
34293429
Value rhsSource = getPreBroadcastSource(rhs);
34303430

3431-
// Helper to trace back through view ops to find the original 3D tensor
3431+
auto lhsType = cast<ShapedType>(lhsSource.getType());
3432+
auto rhsType = cast<ShapedType>(rhsSource.getType());
3433+
3434+
// Check which one has last dimension = 1 (pointers)
3435+
// The pointers tensor should have shape [batch, blocks, 1]
3436+
if (lhsType.getRank() == 3 && lhsType.getShape()[2] == 1)
3437+
return lhsSource;
3438+
if (rhsType.getRank() == 3 && rhsType.getShape()[2] == 1)
3439+
return rhsSource;
3440+
3441+
// If the direct check didn't find the pointers, trace back through view ops
3442+
// to find the original 3D tensor. This handles cases where the pointer tensor
3443+
// goes through reshape/slice operations.
34323444
auto traceBackThroughViewOps = [](Value v) -> Value {
34333445
while (Operation *defOp = v.getDefiningOp()) {
34343446
if (!viewOps.contains(defOp->getName().getStringRef()))
@@ -3446,8 +3458,6 @@ static FailureOr<Value> matchDerefInputPattern(Value derefInput) {
34463458
auto lhsOriginalType = cast<ShapedType>(lhsOriginal.getType());
34473459
auto rhsOriginalType = cast<ShapedType>(rhsOriginal.getType());
34483460

3449-
// Check which one has last dimension = 1 (pointers)
3450-
// The pointers tensor should have shape [batch, blocks, 1]
34513461
if (lhsOriginalType.getRank() == 3 && lhsOriginalType.getShape()[2] == 1)
34523462
return lhsOriginal;
34533463
if (rhsOriginalType.getRank() == 3 && rhsOriginalType.getShape()[2] == 1)

mlir/lib/Dialect/Rock/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,9 @@ struct DerefOpInterface
333333
auto derefOp = mlir::cast<DerefOp>(op);
334334

335335
// Get buffer for pointers operand
336-
FailureOr<Value> PointersBuffer =
336+
FailureOr<Value> pointersBuffer =
337337
getBuffer(rewriter, derefOp.getPointers(), options, state);
338-
if (failed(PointersBuffer))
338+
if (failed(pointersBuffer))
339339
return failure();
340340

341341
// Determine the result memref type from the tensor type
@@ -353,7 +353,7 @@ struct DerefOpInterface
353353

354354
// Create new op with memref types
355355
replaceOpWithNewBufferizedOp<DerefOp>(rewriter, op, resultMemRefType,
356-
*PointersBuffer);
356+
*pointersBuffer);
357357
return success();
358358
}
359359
};

mlir/test/Conversion/TosaToRock/tosa-to-rock-paged-attention.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ func.func @test_paged_attention(
3939
%7 = tosa.mul %4, %4, %2 : (tensor<1x64x8192xi64>, tensor<1x64x8192xi64>, tensor<1xi8>) -> tensor<1x64x8192xi64>
4040
%8 = tosa.add %6, %7 : (tensor<1x64x8192xi64>, tensor<1x64x8192xi64>) -> tensor<1x64x8192xi64>
4141

42-
// CHECK: %[[KEY_DEREF:.*]] = rock.deref %{{.*}} : tensor<1x64x1xi64> -> tensor<1x64x8192xf16>
42+
// CHECK: %[[VAL_DEREF:.*]] = rock.deref %{{.*}} : tensor<1x64x1xi64> -> tensor<1x64x8192xf16>
4343
%9 = tosa.custom %8 {domain_name = "rocmlir", implementation_attrs = "", operator_name = "deref"} : (tensor<1x64x8192xi64>) -> tensor<1x64x8192xf16>
4444
%extracted_slice_3 = tensor.extract_slice %expanded_0[0, 0, 0] [1, 1, 64] [1, 1, 1] : tensor<1x16x64xi64> to tensor<1x1x64xi64>
4545
%collapsed_4 = tensor.collapse_shape %extracted_slice_3 [[0, 1], [2]] : tensor<1x1x64xi64> into tensor<1x64xi64>
4646
%expanded_5 = tensor.expand_shape %collapsed_4 [[0], [1, 2]] output_shape [1, 64, 1] : tensor<1x64xi64> into tensor<1x64x1xi64>
4747
%10 = tosa.mul %expanded_5, %4, %2 : (tensor<1x64x1xi64>, tensor<1x64x8192xi64>, tensor<1xi8>) -> tensor<1x64x8192xi64>
4848
%11 = tosa.add %10, %7 : (tensor<1x64x8192xi64>, tensor<1x64x8192xi64>) -> tensor<1x64x8192xi64>
4949

50-
// CHECK: %[[VAL_DEREF:.*]] = rock.deref %{{.*}} : tensor<1x64x1xi64> -> tensor<1x64x8192xf16>
50+
// CHECK: %[[KEY_DEREF:.*]] = rock.deref %{{.*}} : tensor<1x64x1xi64> -> tensor<1x64x8192xf16>
5151
%12 = tosa.custom %11 {domain_name = "rocmlir", implementation_attrs = "", operator_name = "deref"} : (tensor<1x64x8192xi64>) -> tensor<1x64x8192xf16>
5252
%extracted_slice_6 = tensor.extract_slice %5[0, 0, 0, 0] [1, 14, 1500, 64] [1, 1, 1, 1] : tensor<1x18x1500x64xf16> to tensor<1x14x1500x64xf16>
5353
%collapsed_7 = tensor.collapse_shape %9 [[0], [1, 2]] : tensor<1x64x8192xf16> into tensor<1x524288xf16>
@@ -71,8 +71,8 @@ func.func @test_paged_attention(
7171
%22 = tosa.mul %20, %21, %2 : (tensor<1x1x1x1xi32>, tensor<1x14x1500x4096xi32>, tensor<1xi8>) -> tensor<1x14x1500x4096xi32>
7272

7373
// CHECK: rock.attention
74-
// CHECK: keyAddresses = (%[[VAL_DEREF]] : tensor<1x64x8192xf16>)
75-
// CHECK: valueAddresses = (%[[KEY_DEREF]] : tensor<1x64x8192xf16>)
74+
// CHECK: keyAddresses = (%[[KEY_DEREF]] : tensor<1x64x8192xf16>)
75+
// CHECK: valueAddresses = (%[[VAL_DEREF]] : tensor<1x64x8192xf16>)
7676

7777
%23 = "tosa.const"() <{values = dense<0> : tensor<1x14x1500x4096xi32>}> : () -> tensor<1x14x1500x4096xi32>
7878
%24 = tosa.greater %23, %22 : (tensor<1x14x1500x4096xi32>, tensor<1x14x1500x4096xi32>) -> tensor<1x14x1500x4096xi1>

mlir/test/Dialect/Rock/effects.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ func.func @rock_gridwise_attn(%arg0: memref<1x384x64xf32>,
464464
params0 = #rock.accel_gemm_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true>,
465465
params1 = #rock.accel_gemm_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 1, outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true>,
466466
firstGemmIndices = array<i64: 0>,
467-
operand_segment_sizes = array<i32: 1, 1, 1, 0, 0, 0, 1, 0>,
467+
operand_segment_sizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0, 1, 0>,
468468
splitKV = 1 : i32,
469469
storeMethod = #rock<StoreMethod set>
470470
} : memref<1x64x384xf32>, memref<1x64x384xf32>, memref<1x384x64xf32>, memref<1x384x64xf32>

mlir/test/Dialect/Rock/gridwise-attention-prefix-causal.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ module {
110110
}
111111
memref.copy %alloc_0, %arg6 : memref<1x14x4x16xf16> to memref<1x14x4x16xf16>
112112
rock.yield
113-
} {blockSize = 64 : i32, causal, firstGemmIndices = array<i64: 0>, gridSize = 14 : i32, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 1, 1, 0>, params0 = #accel_gemm_params, params1 = #accel_gemm_params, prePadG0M = 16 : index, prePadG0N = 4 : index, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock<StoreMethod set>} : memref<14x64x32xf16>, memref<14x64x32xf16>, memref<14x32x64xf16>, memref<14xi32>, memref<14x32x64xf16>
113+
} {blockSize = 64 : i32, causal, firstGemmIndices = array<i64: 0>, gridSize = 14 : i32, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 1, 0, 0, 1, 0>, params0 = #accel_gemm_params, params1 = #accel_gemm_params, prePadG0M = 16 : index, prePadG0N = 4 : index, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock<StoreMethod set>} : memref<14x64x32xf16>, memref<14x64x32xf16>, memref<14x32x64xf16>, memref<14xi32>, memref<14x32x64xf16>
114114
memref.copy %alloc, %arg4 : memref<3584xf16> to memref<3584xf16>
115115
return
116116
}

mlir/test/Dialect/Rock/gridwise-gemm-input-fusion-type-change.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ module {
102102
}
103103
memref.copy %alloc_1, %arg7 : memref<1x16x1500x1500xf32> to memref<1x16x1500x1500xf32>
104104
rock.yield
105-
} {blockSize = 64 : i32, firstGemmIndices = array<i64: 0>, gridSize = 752 : i32, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0, 1, 0>, params0 = #accel_gemm_params, params1 = #accel_gemm_params, prePadG0M = 1500 : index, prePadG0N = 1500 : index, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock<StoreMethod set>} : memref<16x64x1504xf16>, memref<16x64x1504xf16>, memref<16x1504x64xf32>, memref<2250000xf32>, memref<16x1504x64xf32>
105+
} {blockSize = 64 : i32, firstGemmIndices = array<i64: 0>, gridSize = 752 : i32, operandSegmentSizes = array<i32: 1, 1, 1, 1, 0, 0, 0, 0, 1, 0>, params0 = #accel_gemm_params, params1 = #accel_gemm_params, prePadG0M = 1500 : index, prePadG0N = 1500 : index, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock<StoreMethod set>} : memref<16x64x1504xf16>, memref<16x64x1504xf16>, memref<16x1504x64xf32>, memref<2250000xf32>, memref<16x1504x64xf32>
106106
memref.copy %alloc_0, %arg4 : memref<1536000xf32> to memref<1536000xf32>
107107
return
108108
}

0 commit comments

Comments
 (0)