Skip to content

Commit d376a7e

Browse files
authored
[mlir] TileUsingInterface bugfix for dominance error (llvm#178190)
In this PR I move the insertion point in `yieldReplacementForFusedProducer` because I ran into an issue where a `tensor.extract_slice` tried to use a result of `affine.apply` that was inserted at the end of the block instead of the start of it. This is the full error of the test I added before this change: ```mlir third-party/llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir:83:11: error: operand #1 does not dominate this use %pack = linalg.pack %gen#1 ^ third-party/llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir:83:11: note: see current operation: %24 = "tensor.extract_slice"(%23, %36, %8) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32> third-party/llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir:71:12: note: operand defined here (op in the same block) %gen:2 = linalg.generic { ^ // -----// IR Dump After InterpreterPass Failed (transform-interpreter) //----- // #map = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0) -> (d0 * 16)> #map2 = affine_map<(d0) -> (d0 * -16 + 32)> #map3 = affine_map<(d0) -> (16, d0 * -16 + 32)> #map4 = affine_map<(d0) -> (d0 - 1)> "builtin.module"() ({ "func.func"() <{function_type = (tensor<32x1024xf32>) -> (tensor<32x1024xf32>, tensor<2x512x16x2xi8>), sym_name = "fuse_pack_consumer_into_multi_output_generic"}> ({ ^bb0(%arg1: tensor<32x1024xf32>): %2 = "arith.constant"() <{value = 0 : i8}> : () -> i8 %3 = "tensor.empty"() : () -> tensor<32x1024xf32> %4 = "tensor.empty"() : () -> tensor<32x1024xi8> %5 = "tensor.empty"() : () -> tensor<2x512x16x2xi8> %6:2 = "linalg.generic"(%arg1, %3, %4) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type<parallel>, 
#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 2>}> ({ ^bb0(%arg9: f32, %arg10: f32, %arg11: i8): %41 = "arith.fptoui"(%arg9) : (f32) -> i8 "linalg.yield"(%arg9, %41) : (f32, i8) -> () }) : (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xi8>) -> (tensor<32x1024xf32>, tensor<32x1024xi8>) %7:3 = "scf.forall"(%5, %3, %4) <{operandSegmentSizes = array<i32: 0, 0, 0, 3>, staticLowerBound = array<i64: 0>, staticStep = array<i64: 1>, staticUpperBound = array<i64: 2>}> ({ ^bb0(%arg2: index, %arg3: tensor<2x512x16x2xi8>, %arg4: tensor<32x1024xf32>, %arg5: tensor<32x1024xi8>): %8 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index %9 = "affine.apply"(%arg2) <{map = #map2}> : (index) -> index %10 = "affine.min"(%arg2) <{map = #map3}> : (index) -> index %11 = "affine.apply"(%10) <{map = #map4}> : (index) -> index %12 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index %13 = "affine.apply"(%10) <{map = #map4}> : (index) -> index %14 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index %15 = "affine.apply"(%10) <{map = #map4}> : (index) -> index %16 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index %17 = "affine.apply"(%10) <{map = #map4}> : (index) -> index %18 = "tensor.extract_slice"(%arg1, %12, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32> %19 = "tensor.empty"() : () -> tensor<32x1024xf32> %20 = "tensor.extract_slice"(%19, %14, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32> %21 = "tensor.extract_slice"(%3, %14, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: 
-9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32> %22 = "tensor.empty"() : () -> tensor<32x1024xi8> %23 = "tensor.extract_slice"(%22, %16, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8> %24 = "tensor.extract_slice"(%4, %16, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8> %25 = "tensor.empty"() : () -> tensor<32x1024xf32> %26 = "tensor.extract_slice"(%25, %38, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32> %27 = "tensor.extract_slice"(%arg4, %38, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32> %28 = "tensor.empty"() : () -> tensor<32x1024xi8> %29 = "tensor.extract_slice"(%28, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8> %30 = "tensor.extract_slice"(%arg5, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 
-9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8> %31:2 = "linalg.generic"(%18, %27, %30) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 2>}> ({ ^bb0(%arg6: f32, %arg7: f32, %arg8: i8): %40 = "arith.fptoui"(%arg6) : (f32) -> i8 "linalg.yield"(%arg6, %40) : (f32, i8) -> () }) : (tensor<?x1024xf32>, tensor<?x1024xf32>, tensor<?x1024xi8>) -> (tensor<?x1024xf32>, tensor<?x1024xi8>) %32 = "tensor.extract_slice"(%6#1, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8> %33 = "tensor.empty"() : () -> tensor<2x512x16x2xi8> %34 = "tensor.extract_slice"(%33, %arg2) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, 0>, static_sizes = array<i64: 1, 512, 16, 2>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<2x512x16x2xi8>, index) -> tensor<1x512x16x2xi8> %35 = "tensor.extract_slice"(%arg3, %arg2) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, 0>, static_sizes = array<i64: 1, 512, 16, 2>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<2x512x16x2xi8>, index) -> tensor<1x512x16x2xi8> %36 = "linalg.pack"(%31#1, %35, %2) <{inner_dims_pos = array<i64: 0, 1>, operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_inner_tiles = array<i64: 16, 2>}> : (tensor<?x1024xi8>, tensor<1x512x16x2xi8>, i8) -> tensor<1x512x16x2xi8> %37 = "affine.apply"(%10) <{map = #map4}> : (index) -> index %38 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index %39 = "affine.apply"(%10) <{map = #map4}> : (index) -> index "scf.forall.in_parallel"() ({ "tensor.parallel_insert_slice"(%36, %arg3, %arg2) 
<{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, 0>, static_sizes = array<i64: 1, 512, 16, 2>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x512x16x2xi8>, tensor<2x512x16x2xi8>, index) -> () "tensor.parallel_insert_slice"(%31#0, %arg4, %38, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<?x1024xf32>, tensor<32x1024xf32>, index, index) -> () "tensor.parallel_insert_slice"(%31#1, %arg5, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<?x1024xi8>, tensor<32x1024xi8>, index, index) -> () }) : () -> () }) : (tensor<2x512x16x2xi8>, tensor<32x1024xf32>, tensor<32x1024xi8>) -> (tensor<2x512x16x2xi8>, tensor<32x1024xf32>, tensor<32x1024xi8>) "func.return"(%7#1, %7#0) : (tensor<32x1024xf32>, tensor<2x512x16x2xi8>) -> () }) : () -> () "builtin.module"() ({ "transform.named_sequence"() <{arg_attrs = [{transform.readonly}], function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({ ^bb0(%arg0: !transform.any_op): %0 = "transform.structured.match"(%arg0) <{ops = ["linalg.pack"]}> : (!transform.any_op) -> !transform.any_op %1:2 = "transform.test.fuse_and_yield"(%0) <{tile_interchange = [], tile_sizes = [1], use_forall = true}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op) "transform.yield"() : () -> () }) : () -> () }) {transform.with_named_sequence} : () -> () }) : () -> () ``` I also noticed that Interface tests are missing from the bazel overlay so i also added this.
1 parent 44aebb6 commit d376a7e

File tree

3 files changed

+97
-2
lines changed

3 files changed

+97
-2
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,14 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
15091509
auto tilableOp = cast<TilingInterface>(originalOwner);
15101510
// b. get iterDomain Offset and Sizes based on sliceOp tile
15111511
SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1512+
// Set insertion point before any operations that might create new SSA
1513+
// values used in offset/size computations. This ensures all values created
1514+
// by getIterationDomainTileFromResultTile and getResultTilePosition
1515+
// dominate the extract_slice operations created later.
1516+
if (auto tiledDestStyleOp =
1517+
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1518+
rewriter.setInsertionPoint(tiledDestStyleOp);
1519+
}
15121520
// skip tensor.pack/unpack/pad, which expects single opResult
15131521
if (tilableOp->getNumResults() > 1 &&
15141522
failed(tilableOp.getIterationDomainTileFromResultTile(
@@ -1550,7 +1558,6 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
15501558
// necessary
15511559
if (auto tiledDestStyleOp =
15521560
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1553-
rewriter.setInsertionPoint(tiledDestStyleOp);
15541561
for (const auto &&[index, newRegionArg] :
15551562
llvm::enumerate(newRegionIterArgs)) {
15561563
auto destSlice = tensor::ExtractSliceOp::create(

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -transform-interpreter -cse -canonicalize -split-input-file %s | FileCheck %s
22

33
func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
44
%init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
@@ -58,3 +58,71 @@ module attributes {transform.with_named_sequence} {
5858
// CHECK: tensor.parallel_insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
5959
// CHECK: tensor.parallel_insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
6060
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
61+
62+
// -----
63+
64+
func.func @fuse_pack_consumer_into_multi_output_generic(
65+
%input: tensor<32x1024xf32>) -> (tensor<32x1024xf32>, tensor<2x512x16x2xi8>) {
66+
%c0_i8 = arith.constant 0 : i8
67+
%output_f32 = tensor.empty() : tensor<32x1024xf32>
68+
%output_i8 = tensor.empty() : tensor<32x1024xi8>
69+
%pack_dest = tensor.empty() : tensor<2x512x16x2xi8>
70+
71+
%gen:2 = linalg.generic {
72+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
73+
affine_map<(d0, d1) -> (d0, d1)>,
74+
affine_map<(d0, d1) -> (d0, d1)>],
75+
iterator_types = ["parallel", "parallel"]
76+
} ins(%input : tensor<32x1024xf32>)
77+
outs(%output_f32, %output_i8 : tensor<32x1024xf32>, tensor<32x1024xi8>) {
78+
^bb0(%in: f32, %out_f: f32, %out_i: i8):
79+
%q = arith.fptoui %in : f32 to i8
80+
linalg.yield %in, %q : f32, i8
81+
} -> (tensor<32x1024xf32>, tensor<32x1024xi8>)
82+
83+
%pack = linalg.pack %gen#1
84+
padding_value(%c0_i8 : i8)
85+
inner_dims_pos = [0, 1]
86+
inner_tiles = [16, 2]
87+
into %pack_dest : tensor<32x1024xi8> -> tensor<2x512x16x2xi8>
88+
89+
return %gen#0, %pack : tensor<32x1024xf32>, tensor<2x512x16x2xi8>
90+
}
91+
92+
module attributes {transform.with_named_sequence} {
93+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
94+
%pack = transform.structured.match ops{["linalg.pack"]} in %arg0
95+
: (!transform.any_op) -> !transform.any_op
96+
%a, %b = transform.test.fuse_and_yield %pack [1] use_forall true
97+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
98+
transform.yield
99+
}
100+
}
101+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 16)>
102+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * -16 + 32, 16)>
103+
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
104+
// CHECK: func.func @fuse_pack_consumer_into_multi_output_generic(
105+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<32x1024xf32>)
106+
// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
107+
// CHECK-DAG: %[[OUTPUT_F32:.+]] = tensor.empty() : tensor<32x1024xf32>
108+
// CHECK-DAG: %[[OUTPUT_I8:.+]] = tensor.empty() : tensor<32x1024xi8>
109+
// CHECK-DAG: %[[PACK_DEST:.+]] = tensor.empty() : tensor<2x512x16x2xi8>
110+
// CHECK: %[[RESULT:.+]]:2 = scf.forall (%[[IV:.+]]) in (2)
111+
// CHECK-SAME: shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[PACK_DEST]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[OUTPUT_F32]])
112+
// CHECK: %[[OFFSET:.+]] = affine.apply #[[$MAP0]](%[[IV]])
113+
// CHECK: %[[SIZE:.+]] = affine.min #[[$MAP1]](%[[IV]])
114+
// CHECK-DAG: %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]][%[[OFFSET]], 0] [%[[SIZE]], 1024]
115+
// CHECK-DAG: %[[F32_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[OFFSET]], 0] [%[[SIZE]], 1024]
116+
// CHECK-DAG: %[[I8_TILE:.+]] = tensor.extract_slice %[[OUTPUT_I8]][%[[OFFSET]], 0] [%[[SIZE]], 1024]
117+
// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic
118+
// CHECK-SAME: ins(%[[INPUT_TILE]] :
119+
// CHECK-SAME: outs(%[[F32_TILE]], %[[I8_TILE]] :
120+
// CHECK-DAG: %[[PACK_DEST_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0, 0, 0] [1, 512, 16, 2]
121+
// CHECK: %[[PACK_TILE:.+]] = linalg.pack %[[GENERIC_TILE]]#1
122+
// CHECK-SAME: padding_value(%[[C0_I8]] : i8)
123+
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [16, 2]
124+
// CHECK-SAME: into %[[PACK_DEST_TILE]]
125+
// CHECK: scf.forall.in_parallel {
126+
// CHECK: tensor.parallel_insert_slice %[[PACK_TILE]] into %[[ITERARG0]][%[[IV]], 0, 0, 0] [1, 512, 16, 2]
127+
// CHECK: tensor.parallel_insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[OFFSET]], 0] [%[[SIZE]], 1024]
128+
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
load("//llvm:lit_test.bzl", "lit_test")
2+
3+
licenses(["notice"])
4+
5+
package(default_visibility = ["//visibility:public"])
6+
7+
[
8+
lit_test(
9+
name = "%s.test" % src,
10+
srcs = [src],
11+
data = [
12+
"//llvm:llvm-symbolizer",
13+
"//mlir:mlir-opt",
14+
"//mlir/test:lit_data",
15+
],
16+
)
17+
for src in glob(
18+
include = ["**/*.mlir"],
19+
)
20+
]

0 commit comments

Comments
 (0)