Skip to content

Commit 4720932

Browse files
[mlir][acc] ACCComputeLowering needs to account for device_type par (#201267)
When assigning parallelism for compute constructs or loops, device_type parallelism must be first considered as a group for all available (gang, worker, vector) - if any of these have device_type setting, then those are the only ones that should be considered. Only if the loop has no device_type specific parallelism then default parallelism should be assigned.
1 parent e43adae commit 4720932

5 files changed

Lines changed: 124 additions & 9 deletions

File tree

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,6 +1829,10 @@ def OpenACC_ParallelOp
18291829
mlir::Operation::operand_range
18301830
getNumGangsValues(mlir::acc::DeviceType deviceType);
18311831

1832+
/// Return true if the op has any num_gangs, num_workers, or vector_length
1833+
/// clause for the given device_type.
1834+
bool hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType);
1835+
18321836
/// Return true if the op has the wait attribute for the
18331837
/// mlir::acc::DeviceType::None device_type.
18341838
bool hasWaitOnly();
@@ -2156,6 +2160,10 @@ def OpenACC_KernelsOp
21562160
mlir::Operation::operand_range
21572161
getNumGangsValues(mlir::acc::DeviceType deviceType);
21582162

2163+
/// Return true if the op has any num_gangs, num_workers, or vector_length
2164+
/// clause for the given device_type.
2165+
bool hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType);
2166+
21592167
/// Return true if the op has the wait attribute for the
21602168
/// mlir::acc::DeviceType::None device_type.
21612169
bool hasWaitOnly();
@@ -2808,6 +2816,10 @@ def OpenACC_LoopOp
28082816
// 'default'/None device-type.
28092817
bool hasDefaultGangWorkerVector();
28102818

2819+
// Return whether this LoopOp has a gang, worker, or vector for the given
2820+
// device-type.
2821+
bool hasAnyGangWorkerVector(DeviceType deviceType);
2822+
28112823
// Used to obtain the parallelism mode for the requested device type.
28122824
// This first checks if the mode is set for the device_type requested.
28132825
// And if not, it returns the non-device_type mode.

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,6 +2173,31 @@ ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
21732173
getNumGangsSegments(), deviceType);
21742174
}
21752175

2176+
static bool hasAnyGangWorkerVectorForDeviceType(
2177+
std::optional<mlir::ArrayAttr> numGangsDeviceType,
2178+
mlir::Operation::operand_range numGangs,
2179+
std::optional<llvm::ArrayRef<int32_t>> numGangsSegments,
2180+
std::optional<mlir::ArrayAttr> numWorkersDeviceType,
2181+
mlir::Operation::operand_range numWorkers,
2182+
std::optional<mlir::ArrayAttr> vectorLengthDeviceType,
2183+
mlir::Operation::operand_range vectorLength,
2184+
mlir::acc::DeviceType deviceType) {
2185+
return !getValuesFromSegments(numGangsDeviceType, numGangs, numGangsSegments,
2186+
deviceType)
2187+
.empty() ||
2188+
getValueInDeviceTypeSegment(numWorkersDeviceType, numWorkers,
2189+
deviceType) ||
2190+
getValueInDeviceTypeSegment(vectorLengthDeviceType, vectorLength,
2191+
deviceType);
2192+
}
2193+
2194+
bool acc::ParallelOp::hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType) {
2195+
return hasAnyGangWorkerVectorForDeviceType(
2196+
getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(),
2197+
getNumWorkersDeviceType(), getNumWorkers(), getVectorLengthDeviceType(),
2198+
getVectorLength(), deviceType);
2199+
}
2200+
21762201
bool acc::ParallelOp::hasWaitOnly() {
21772202
return hasWaitOnly(mlir::acc::DeviceType::None);
21782203
}
@@ -3039,6 +3064,13 @@ KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
30393064
getNumGangsSegments(), deviceType);
30403065
}
30413066

3067+
bool acc::KernelsOp::hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType) {
3068+
return hasAnyGangWorkerVectorForDeviceType(
3069+
getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(),
3070+
getNumWorkersDeviceType(), getNumWorkers(), getVectorLengthDeviceType(),
3071+
getVectorLength(), deviceType);
3072+
}
3073+
30423074
bool acc::KernelsOp::hasWaitOnly() {
30433075
return hasWaitOnly(mlir::acc::DeviceType::None);
30443076
}
@@ -3968,9 +4000,15 @@ bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
39684000
}
39694001

39704002
bool acc::LoopOp::hasDefaultGangWorkerVector() {
3971-
return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3972-
hasGang() || getGangValue(GangArgType::Num) ||
3973-
getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
4003+
return hasAnyGangWorkerVector(DeviceType::None);
4004+
}
4005+
4006+
bool acc::LoopOp::hasAnyGangWorkerVector(DeviceType deviceType) {
4007+
return hasVector(deviceType) || getVectorValue(deviceType) ||
4008+
hasWorker(deviceType) || getWorkerValue(deviceType) ||
4009+
hasGang(deviceType) || getGangValue(GangArgType::Num, deviceType) ||
4010+
getGangValue(GangArgType::Dim, deviceType) ||
4011+
getGangValue(GangArgType::Static, deviceType);
39744012
}
39754013

39764014
acc::LoopParMode

mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,32 @@ static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
185185
parDims.insert(lb, parDim);
186186
}
187187

188+
/// Return the device type from which gang/worker/vector clauses should be read.
189+
/// If the requested device type has any such clauses, use that exclusively;
190+
/// otherwise fall back to the default (DeviceType::None).
191+
static DeviceType getGangWorkerVectorDeviceType(LoopOp loopOp,
192+
DeviceType deviceType) {
193+
if (deviceType != DeviceType::None &&
194+
loopOp.hasAnyGangWorkerVector(deviceType))
195+
return deviceType;
196+
return DeviceType::None;
197+
}
198+
199+
template <typename ComputeConstructT>
200+
static DeviceType getParDimsDeviceType(ComputeConstructT computeOp,
201+
DeviceType deviceType) {
202+
if (deviceType != DeviceType::None &&
203+
computeOp.hasAnyGangWorkerVector(deviceType))
204+
return deviceType;
205+
return DeviceType::None;
206+
}
207+
188208
/// Map loop parallelism clauses (gang/worker/vector) to GPU parallel
189209
/// dimensions using the given mapping policy.
190210
static SmallVector<GPUParallelDimAttr>
191211
getParallelDimensions(LoopOp loopOp, const ACCToGPUMappingPolicy &policy,
192212
DeviceType deviceType) {
213+
deviceType = getGangWorkerVectorDeviceType(loopOp, deviceType);
193214
SmallVector<GPUParallelDimAttr> parDims;
194215
auto *ctx = loopOp->getContext();
195216

@@ -229,12 +250,12 @@ assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
229250
if (isEffectivelySerial(computeOp))
230251
return {ParWidthOp::create(rewriter, loc, Value(), policy.seqDim(ctx))};
231252

253+
deviceType = getParDimsDeviceType(computeOp, deviceType);
254+
232255
SmallVector<Value> values;
233256
auto indexTy = rewriter.getIndexType();
234257

235258
auto numGangs = computeOp.getNumGangsValues(deviceType);
236-
if (numGangs.empty())
237-
numGangs = computeOp.getNumGangsValues();
238259
for (auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
239260
auto gangLevel = getGangParLevel(gangDimIdx + 1);
240261
values.push_back(ParWidthOp::create(
@@ -245,8 +266,6 @@ assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
245266
}
246267

247268
Value numWorkers = computeOp.getNumWorkersValue(deviceType);
248-
if (!numWorkers)
249-
numWorkers = computeOp.getNumWorkersValue();
250269
if (numWorkers) {
251270
values.push_back(ParWidthOp::create(
252271
rewriter, loc,
@@ -256,8 +275,6 @@ assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
256275
}
257276

258277
Value vectorLength = computeOp.getVectorLengthValue(deviceType);
259-
if (!vectorLength)
260-
vectorLength = computeOp.getVectorLengthValue();
261278
if (vectorLength) {
262279
values.push_back(ParWidthOp::create(
263280
rewriter, loc,
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt %s -acc-compute-lowering=device-type=nvidia | FileCheck %s
2+
3+
// Default num_gangs, nvidia vector_length: with device-type=nvidia only vector applies.
4+
// CHECK-LABEL: func.func @parallel_default_gangs_nvidia_vector_length
5+
func.func @parallel_default_gangs_nvidia_vector_length(%buf: memref<4xi32>) {
6+
%c0 = arith.constant 0 : index
7+
%c1 = arith.constant 1 : index
8+
%c4 = arith.constant 4 : index
9+
%c4_i32 = arith.constant 4 : i32
10+
%c32_i32 = arith.constant 32 : i32
11+
12+
%dev = acc.copyin varPtr(%buf : memref<4xi32>) -> memref<4xi32>
13+
// CHECK-NOT: acc.par_width {{.*}} {par_dim = #acc.par_dim<block_x>}
14+
// CHECK: acc.par_width {{.*}} {par_dim = #acc.par_dim<thread_x>}
15+
acc.parallel num_gangs({%c4_i32 : i32}) vector_length(%c32_i32 : i32 [#acc.device_type<nvidia>]) dataOperands(%dev : memref<4xi32>) {
16+
acc.loop control(%i : index) = (%c0 : index) to (%c4 : index) step (%c1 : index) {
17+
%vi = arith.index_cast %i : index to i32
18+
memref.store %vi, %dev[%i] : memref<4xi32>
19+
acc.yield
20+
} attributes {independent = [#acc.device_type<none>]}
21+
acc.yield
22+
}
23+
acc.copyout accPtr(%dev : memref<4xi32>) to varPtr(%buf : memref<4xi32>)
24+
return
25+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s -acc-compute-lowering=device-type=nvidia | FileCheck %s
2+
3+
// Gang on default, vector on nvidia: with device-type=nvidia only vector applies.
4+
// CHECK-LABEL: func.func @parallel_loop_gang_default_vector_nvidia
5+
func.func @parallel_loop_gang_default_vector_nvidia(%buf: memref<1xi32>) {
6+
%c0 = arith.constant 0 : index
7+
%c1_i32 = arith.constant 1 : i32
8+
%c10_i32 = arith.constant 10 : i32
9+
%c100_i32 = arith.constant 100 : i32
10+
11+
%dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
12+
// CHECK-NOT: acc.par_dims = #acc<par_dims[block_x]>
13+
// CHECK: acc.par_dims = #acc<par_dims[thread_x]>
14+
acc.parallel num_gangs({%c10_i32 : i32}) dataOperands(%dev : memref<1xi32>) {
15+
acc.loop gang control(%arg0 : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) {
16+
memref.store %arg0, %dev[%c0] : memref<1xi32>
17+
acc.yield
18+
} attributes {auto_ = [#acc.device_type<none>], gang = [#acc.device_type<none>], vector = [#acc.device_type<nvidia>]}
19+
acc.yield
20+
}
21+
acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
22+
return
23+
}

0 commit comments

Comments
 (0)