Skip to content

Commit 332d560

Browse files
authored
[Codegen] Add v0 TileAndFuse constraints for matmul (#24519)
After testing, this constraint set generates the same number of SMT solutions for matmul as the old tuner. See the SMT-LIB string [comparison](nod-ai/amd-shark-ai@5d51750). Issue: #23535
1 parent 2098bbe commit 332d560

5 files changed

Lines changed: 650 additions & 34 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/VerifyPipelineConstraints.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ struct ConstraintEvaluator {
156156
smt::IntDivOp, smt::IntModOp, smt::IntMulOp, smt::IntSubOp,
157157
smt::IteOp, smt::NotOp, smt::OrOp>(
158158
[&](auto op) { return eval(op); })
159+
.Case<smt::DeclareFunOp>([&](smt::DeclareFunOp declOp) {
160+
intValues[declOp.getResult()] = std::nullopt;
161+
return success();
162+
})
159163
.Default([](Operation *unhandled) {
160164
return unhandled->emitError(
161165
"unsupported op in constraint evaluator");

compiler/src/iree/compiler/Codegen/Common/test/insert_smt_constraints.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ hal.executable @matmul_f32_ex {
4848
// CHECK: linalg.fill
4949
// CHECK-NOT: iree_codegen.smt.constraints
5050
// CHECK: linalg.matmul
51+
// CHECK: iree_codegen.smt.constraints target = <set = 0>, pipeline = #iree_gpu.pipeline<TileAndFuse>
52+
// CHECK-NEXT{LITERAL}: knobs = {mma_kind = #iree_codegen.smt.one_of_knob<"mma_idx", [#iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>]>, reduction = [0, 0, #iree_codegen.smt.int_knob<"red_2">], subgroup = [#iree_codegen.smt.int_knob<"sg_0">, #iree_codegen.smt.int_knob<"sg_1">, 0], subgroup_size = #iree_codegen.smt.int_knob<"sg_size">, workgroup = [#iree_codegen.smt.int_knob<"wg_0">, #iree_codegen.smt.int_knob<"wg_1">, 0], workgroup_size = [#iree_codegen.smt.int_knob<"wg_size_x">, #iree_codegen.smt.int_knob<"wg_size_y">, #iree_codegen.smt.int_knob<"wg_size_z">]}
5153
//
5254
// CHECK: iree_codegen.smt.constraints target = <set = 0>, pipeline = #iree_gpu.pipeline<VectorDistribute>,
5355
// CHECK-NEXT{LITERAL}: knobs = {mma_kind = #iree_codegen.smt.one_of_knob<"mma_idx", [#iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>]>, reduction = [0, 0, #iree_codegen.smt.int_knob<"red_2">], subgroup_basis = [[#iree_codegen.smt.int_knob<"sg_m_cnt">, #iree_codegen.smt.int_knob<"sg_n_cnt">, 1], [0, 1, 2]], subgroup_size = #iree_codegen.smt.int_knob<"sg_size">, workgroup = [#iree_codegen.smt.int_knob<"wg_0">, #iree_codegen.smt.int_knob<"wg_1">, 0], workgroup_size = [#iree_codegen.smt.int_knob<"wg_size_x">, #iree_codegen.smt.int_knob<"wg_size_y">, #iree_codegen.smt.int_knob<"wg_size_z">]}

compiler/src/iree/compiler/Codegen/Common/test/verify_smt_constraints_e2e.mlir

Lines changed: 163 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
pipeline = #iree_gpu.pipeline<VectorDistribute>
2323
workgroup_size = [64, 1, 1] subgroup_size = 64>
2424

25-
func.func @matmul_e2e_generated_violation(
25+
func.func @matmul_e2e_generated_violation_vd(
2626
%lhs: tensor<128x64xf32>, %rhs: tensor<64x256xf32>)
2727
-> tensor<128x256xf32>
2828
attributes {hal.executable.target = #exec_target,
@@ -48,6 +48,47 @@ func.func @matmul_e2e_generated_violation(
4848

4949
// -----
5050

51+
#gpu_target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
52+
compute = fp32, storage = b32, subgroup = shuffle,
53+
mma = [<MFMA_F32_16x16x4_F32>],
54+
subgroup_size_choices = [64],
55+
max_load_instruction_bits = 128,
56+
max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
57+
max_workgroup_memory_bytes = 65536,
58+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
59+
>>
60+
#exec_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
61+
{iree_codegen.target_info = #gpu_target}>
62+
#translation = #iree_codegen.translation_info<
63+
pipeline = #iree_gpu.pipeline<TileAndFuse>
64+
workgroup_size = [64, 1, 1] subgroup_size = 64>
65+
66+
func.func @matmul_e2e_generated_violation_tf(
67+
%lhs: tensor<128x64xf32>, %rhs: tensor<64x256xf32>)
68+
-> tensor<128x256xf32>
69+
attributes {hal.executable.target = #exec_target,
70+
translation_info = #translation} {
71+
%cst = arith.constant 0.0 : f32
72+
%init = tensor.empty() : tensor<128x256xf32>
73+
%fill = linalg.fill {root_op = #iree_codegen.root_op<set = 0>}
74+
ins(%cst : f32) outs(%init : tensor<128x256xf32>)
75+
-> tensor<128x256xf32>
76+
// expected-error @below {{pipeline constraints violated}}
77+
// expected-note @below {{dim_0 must be divisible by wg_0 (128 % 48 == 0)}}
78+
%result = linalg.matmul {
79+
lowering_config = #iree_gpu.lowering_config<{
80+
workgroup = [48, 64, 0],
81+
reduction = [0, 0, 16],
82+
mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
83+
subgroup = [1, 1, 0]}>,
84+
root_op = #iree_codegen.root_op<set = 0>}
85+
ins(%lhs, %rhs : tensor<128x64xf32>, tensor<64x256xf32>)
86+
outs(%fill : tensor<128x256xf32>) -> tensor<128x256xf32>
87+
return %result : tensor<128x256xf32>
88+
}
89+
90+
// -----
91+
5192
#gpu_target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
5293
compute = fp32, storage = b32, subgroup = shuffle,
5394
mma = [<MFMA_F32_16x16x4_F32>],
@@ -63,7 +104,7 @@ func.func @matmul_e2e_generated_violation(
63104
pipeline = #iree_gpu.pipeline<VectorDistribute>
64105
workgroup_size = [64, 1, 1] subgroup_size = 64>
65106

66-
func.func @conv_e2e_generated_violation(
107+
func.func @conv_e2e_generated_violation_vd(
67108
%input: tensor<1x18x130x64xf32>, %filter: tensor<3x3x64x128xf32>)
68109
-> tensor<1x16x128x128xf32>
69110
attributes {hal.executable.target = #exec_target,
@@ -93,6 +134,50 @@ func.func @conv_e2e_generated_violation(
93134

94135
// -----
95136

137+
#gpu_target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
138+
compute = fp32, storage = b32, subgroup = shuffle,
139+
mma = [<MFMA_F32_16x16x4_F32>],
140+
subgroup_size_choices = [64],
141+
max_load_instruction_bits = 128,
142+
max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
143+
max_workgroup_memory_bytes = 65536,
144+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
145+
>>
146+
#exec_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
147+
{iree_codegen.target_info = #gpu_target}>
148+
#translation = #iree_codegen.translation_info<
149+
pipeline = #iree_gpu.pipeline<TileAndFuse>
150+
workgroup_size = [64, 1, 1] subgroup_size = 64>
151+
152+
func.func @conv_e2e_generated_violation_tf(
153+
%input: tensor<1x18x130x64xf32>, %filter: tensor<3x3x64x128xf32>)
154+
-> tensor<1x16x128x128xf32>
155+
attributes {hal.executable.target = #exec_target,
156+
translation_info = #translation} {
157+
%cst = arith.constant 0.0 : f32
158+
%init = tensor.empty() : tensor<1x16x128x128xf32>
159+
%fill = linalg.fill {root_op = #iree_codegen.root_op<set = 1>}
160+
ins(%cst : f32) outs(%init : tensor<1x16x128x128xf32>)
161+
-> tensor<1x16x128x128xf32>
162+
// expected-error @below {{pipeline constraints violated}}
163+
// expected-note @below {{dim_2 must be divisible by wg_2 (128 % 48 == 0)}}
164+
%result = linalg.conv_2d_nhwc_hwcf {
165+
dilations = dense<1> : tensor<2xi64>,
166+
lowering_config = #iree_gpu.lowering_config<{
167+
workgroup = [1, 1, 48, 64, 0, 0, 0],
168+
reduction = [0, 0, 0, 0, 1, 1, 16],
169+
mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
170+
subgroup = [1, 1, 1, 1, 0, 0, 0]}>,
171+
root_op = #iree_codegen.root_op<set = 1>,
172+
strides = dense<1> : tensor<2xi64>}
173+
ins(%input, %filter : tensor<1x18x130x64xf32>,
174+
tensor<3x3x64x128xf32>)
175+
outs(%fill : tensor<1x16x128x128xf32>) -> tensor<1x16x128x128xf32>
176+
return %result : tensor<1x16x128x128xf32>
177+
}
178+
179+
// -----
180+
96181
// Test: End-to-end constraint insertion and verification.
97182
// Use the same shapes as above but with divisible workgroup sizes.
98183
// It should pass verification and have constraints erased.
@@ -111,7 +196,7 @@ func.func @conv_e2e_generated_violation(
111196
pipeline = #iree_gpu.pipeline<VectorDistribute>
112197
workgroup_size = [64, 1, 1] subgroup_size = 64>
113198

114-
func.func @matmul_e2e_constraints_erased(
199+
func.func @matmul_e2e_constraints_erased_vd(
115200
%lhs: tensor<128x64xf32>, %rhs: tensor<64x256xf32>)
116201
-> tensor<128x256xf32>
117202
attributes {hal.executable.target = #exec_target,
@@ -133,11 +218,11 @@ func.func @matmul_e2e_constraints_erased(
133218
return %result : tensor<128x256xf32>
134219
}
135220

136-
// CHECK-LABEL: func.func @matmul_e2e_constraints_erased
221+
// CHECK-LABEL: func.func @matmul_e2e_constraints_erased_vd
137222
// CHECK: linalg.matmul
138223
// CHECK-NOT: iree_codegen.smt.constraints
139224

140-
func.func @conv_e2e_constraints_erased(
225+
func.func @conv_e2e_constraints_erased_vd(
141226
%input: tensor<1x18x130x64xf32>, %filter: tensor<3x3x64x128xf32>)
142227
-> tensor<1x16x128x128xf32>
143228
attributes {hal.executable.target = #exec_target,
@@ -163,6 +248,78 @@ func.func @conv_e2e_constraints_erased(
163248
return %result : tensor<1x16x128x128xf32>
164249
}
165250

166-
// CHECK-LABEL: func.func @conv_e2e_constraints_erased
251+
// CHECK-LABEL: func.func @conv_e2e_constraints_erased_vd
252+
// CHECK: linalg.conv_2d_nhwc_hwcf
253+
// CHECK-NOT: iree_codegen.smt.constraints
254+
255+
// -----
256+
257+
#gpu_target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
258+
compute = fp32, storage = b32, subgroup = shuffle,
259+
mma = [<MFMA_F32_16x16x4_F32>],
260+
subgroup_size_choices = [64],
261+
max_load_instruction_bits = 128,
262+
max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024,
263+
max_workgroup_memory_bytes = 65536,
264+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
265+
>>
266+
#exec_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
267+
{iree_codegen.target_info = #gpu_target}>
268+
#translation = #iree_codegen.translation_info<
269+
pipeline = #iree_gpu.pipeline<TileAndFuse>
270+
workgroup_size = [64, 1, 1] subgroup_size = 64>
271+
272+
func.func @matmul_e2e_constraints_erased_tf(
273+
%lhs: tensor<128x64xf32>, %rhs: tensor<64x256xf32>)
274+
-> tensor<128x256xf32>
275+
attributes {hal.executable.target = #exec_target,
276+
translation_info = #translation} {
277+
%cst = arith.constant 0.0 : f32
278+
%init = tensor.empty() : tensor<128x256xf32>
279+
%fill = linalg.fill {root_op = #iree_codegen.root_op<set = 0>}
280+
ins(%cst : f32) outs(%init : tensor<128x256xf32>)
281+
-> tensor<128x256xf32>
282+
%result = linalg.matmul {
283+
lowering_config = #iree_gpu.lowering_config<{
284+
workgroup = [16, 16, 0],
285+
reduction = [0, 0, 16],
286+
mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
287+
subgroup = [1, 1, 0]}>,
288+
root_op = #iree_codegen.root_op<set = 0>}
289+
ins(%lhs, %rhs : tensor<128x64xf32>, tensor<64x256xf32>)
290+
outs(%fill : tensor<128x256xf32>) -> tensor<128x256xf32>
291+
return %result : tensor<128x256xf32>
292+
}
293+
294+
// CHECK-LABEL: func.func @matmul_e2e_constraints_erased_tf
295+
// CHECK: linalg.matmul
296+
// CHECK-NOT: iree_codegen.smt.constraints
297+
298+
func.func @conv_e2e_constraints_erased_tf(
299+
%input: tensor<1x18x130x64xf32>, %filter: tensor<3x3x64x128xf32>)
300+
-> tensor<1x16x128x128xf32>
301+
attributes {hal.executable.target = #exec_target,
302+
translation_info = #translation} {
303+
%cst = arith.constant 0.0 : f32
304+
%init = tensor.empty() : tensor<1x16x128x128xf32>
305+
%fill = linalg.fill {root_op = #iree_codegen.root_op<set = 1>}
306+
ins(%cst : f32) outs(%init : tensor<1x16x128x128xf32>)
307+
-> tensor<1x16x128x128xf32>
308+
%result = linalg.conv_2d_nhwc_hwcf {
309+
dilations = dense<1> : tensor<2xi64>,
310+
lowering_config = #iree_gpu.lowering_config<{
311+
workgroup = [1, 1, 16, 64, 0, 0, 0],
312+
reduction = [0, 0, 0, 0, 1, 1, 16],
313+
mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
314+
subgroup = [0, 1, 0, 1, 1, 0, 0]}>,
315+
root_op = #iree_codegen.root_op<set = 1>,
316+
strides = dense<1> : tensor<2xi64>}
317+
ins(%input, %filter : tensor<1x18x130x64xf32>,
318+
tensor<3x3x64x128xf32>)
319+
outs(%fill : tensor<1x16x128x128xf32>) -> tensor<1x16x128x128xf32>
320+
return %result : tensor<1x16x128x128xf32>
321+
}
322+
323+
// CHECK-LABEL: func.func @conv_e2e_constraints_erased_tf
167324
// CHECK: linalg.conv_2d_nhwc_hwcf
168325
// CHECK-NOT: iree_codegen.smt.constraints

0 commit comments

Comments
 (0)