Skip to content

Commit 4ba3ffa

Browse files
committed
[SYCLomatic] Refine migration for thrust::max and thrust::min and add them to api-query-mapping
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent d6824ae commit 4ba3ffa

6 files changed

Lines changed: 146 additions & 3 deletions

File tree

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include <thrust/extrema.h>
2+
3+
void max_test() {
4+
// clang-format off
5+
// Start
6+
struct key_value {
7+
int key;
8+
int value;
9+
};
10+
struct compare_key_value {
11+
__host__ __device__ bool operator()(key_value lhs, key_value rhs) {
12+
return lhs.key < rhs.key;
13+
}
14+
};
15+
key_value a = {13, 0};
16+
key_value b = {7, 1};
17+
key_value smaller = thrust::max(a, b, compare_key_value());
18+
int value = thrust::max(1, 2);
19+
// End
20+
// clang-format on
21+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include <thrust/extrema.h>
2+
3+
void min_test() {
4+
// clang-format off
5+
// Start
6+
struct key_value {
7+
int key;
8+
int value;
9+
};
10+
struct compare_key_value {
11+
__host__ __device__ bool operator()(key_value lhs, key_value rhs) {
12+
return lhs.key < rhs.key;
13+
}
14+
};
15+
key_value a = {13, 0};
16+
key_value b = {7, 1};
17+
key_value smaller = thrust::min(a, b, compare_key_value());
18+
int value = thrust::min(1, 2);
19+
// End
20+
// clang-format on
21+
}

clang/lib/DPCT/RulesLangLib/APINamesThrust.inc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,10 +1172,18 @@ thrustFactory("thrust::detail::vector_equal",
11721172
{{3,PolicyState::NoPolicy,3,"oneapi::dpl::equal", HelperFeatureEnum::none}}),
11731173

11741174
// thrust::max
1175-
CALL_FACTORY_ENTRY("thrust::max", CALL("std::max", ARG(0), ARG(1)))
1175+
CONDITIONAL_FACTORY_ENTRY(
1176+
CheckArgCount(2),
1177+
CALL_FACTORY_ENTRY("thrust::max", CALL("std::max", ARG(0), ARG(1))),
1178+
CALL_FACTORY_ENTRY("thrust::max", CALL("std::max", ARG(0), ARG(1), ARG(2)))
1179+
)
11761180

11771181
// thrust::min
1178-
CALL_FACTORY_ENTRY("thrust::min", CALL("std::min",ARG(0), ARG(1)))
1182+
CONDITIONAL_FACTORY_ENTRY(
1183+
CheckArgCount(2),
1184+
CALL_FACTORY_ENTRY("thrust::min", CALL("std::min", ARG(0), ARG(1))),
1185+
CALL_FACTORY_ENTRY("thrust::min", CALL("std::min", ARG(0), ARG(1), ARG(2)))
1186+
)
11791187

11801188
// thrust::tie
11811189
CALL_FACTORY_ENTRY("thrust::tie", CALL("std::tie",ARG(0), ARG(1)))

clang/test/dpct/query_api_mapping/Thrust/thrust_api_test_p3.cu

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,63 @@
267267
// thrust_make_zip_iterator-NEXT: typedef dpct::device_vector<float>::iterator float_iterator;
268268
// thrust_make_zip_iterator-NEXT: typedef std::tuple<int_iterator, float_iterator> iterator_tuple;
269269
// thrust_make_zip_iterator-NEXT: dpct::zip_iterator<iterator_tuple> ret = oneapi::dpl::make_zip_iterator(std::make_tuple(int_in.begin(), float_in.begin()));
270+
271+
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=thrust::max --extra-arg="-std=c++14"| FileCheck %s -check-prefix=thrust_max
272+
// thrust_max: CUDA API:
273+
// thrust_max-NEXT: struct key_value {
274+
// thrust_max-NEXT: int key;
275+
// thrust_max-NEXT: int value;
276+
// thrust_max-NEXT: };
277+
// thrust_max-NEXT: struct compare_key_value {
278+
// thrust_max-NEXT: __host__ __device__ bool operator()(key_value lhs, key_value rhs) {
279+
// thrust_max-NEXT: return lhs.key < rhs.key;
280+
// thrust_max-NEXT: }
281+
// thrust_max-NEXT: };
282+
// thrust_max-NEXT: key_value a = {13, 0};
283+
// thrust_max-NEXT: key_value b = {7, 1};
284+
// thrust_max-NEXT: key_value smaller = thrust::max(a, b, compare_key_value());
285+
// thrust_max-NEXT: int value = thrust::max(1, 2);
286+
// thrust_max-NEXT: Is migrated to:
287+
// thrust_max-NEXT: struct key_value {
288+
// thrust_max-NEXT: int key;
289+
// thrust_max-NEXT: int value;
290+
// thrust_max-NEXT: };
291+
// thrust_max-NEXT: struct compare_key_value {
292+
// thrust_max-NEXT: bool operator()(key_value lhs, key_value rhs) {
293+
// thrust_max-NEXT: return lhs.key < rhs.key;
294+
// thrust_max-NEXT: }
295+
// thrust_max-NEXT: };
296+
// thrust_max-NEXT: key_value a = {13, 0};
297+
// thrust_max-NEXT: key_value b = {7, 1};
298+
// thrust_max-NEXT: key_value smaller = std::max(a, b, compare_key_value());
299+
// thrust_max-NEXT: int value = std::max(1, 2);
300+
301+
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=thrust::min --extra-arg="-std=c++14"| FileCheck %s -check-prefix=thrust_min
302+
// thrust_min: CUDA API:
303+
// thrust_min-NEXT: struct key_value {
304+
// thrust_min-NEXT: int key;
305+
// thrust_min-NEXT: int value;
306+
// thrust_min-NEXT: };
307+
// thrust_min-NEXT: struct compare_key_value {
308+
// thrust_min-NEXT: __host__ __device__ bool operator()(key_value lhs, key_value rhs) {
309+
// thrust_min-NEXT: return lhs.key < rhs.key;
310+
// thrust_min-NEXT: }
311+
// thrust_min-NEXT: };
312+
// thrust_min-NEXT: key_value a = {13, 0};
313+
// thrust_min-NEXT: key_value b = {7, 1};
314+
// thrust_min-NEXT: key_value smaller = thrust::min(a, b, compare_key_value());
315+
// thrust_min-NEXT: int value = thrust::min(1, 2);
316+
// thrust_min-NEXT: Is migrated to:
317+
// thrust_min-NEXT: struct key_value {
318+
// thrust_min-NEXT: int key;
319+
// thrust_min-NEXT: int value;
320+
// thrust_min-NEXT: };
321+
// thrust_min-NEXT: struct compare_key_value {
322+
// thrust_min-NEXT: bool operator()(key_value lhs, key_value rhs) {
323+
// thrust_min-NEXT: return lhs.key < rhs.key;
324+
// thrust_min-NEXT: }
325+
// thrust_min-NEXT: };
326+
// thrust_min-NEXT: key_value a = {13, 0};
327+
// thrust_min-NEXT: key_value b = {7, 1};
328+
// thrust_min-NEXT: key_value smaller = std::min(a, b, compare_key_value());
329+
// thrust_min-NEXT: int value = std::min(1, 2);

clang/test/dpct/query_api_mapping/test_all.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2469,9 +2469,11 @@
24692469
// CHECK-NEXT: thrust::make_tuple
24702470
// CHECK-NEXT: thrust::make_zip_iterator
24712471
// CHECK-NEXT: thrust::malloc
2472+
// CHECK-NEXT: thrust::max
24722473
// CHECK-NEXT: thrust::max_element
24732474
// CHECK-NEXT: thrust::merge
24742475
// CHECK-NEXT: thrust::merge_by_key
2476+
// CHECK-NEXT: thrust::min
24752477
// CHECK-NEXT: thrust::min_element
24762478
// CHECK-NEXT: thrust::minmax_element
24772479
// CHECK-NEXT: thrust::mismatch

clang/test/dpct/thrust-for-h2o4gpu.cu

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <thrust/execution_policy.h>
1515
#include <thrust/random.h>
1616
#include <thrust/reduce.h>
17-
#include <algorithm>
17+
#include <thrust/extrema.h>
1818
#include <thrust/inner_product.h>
1919
#include <thrust/extrema.h>
2020
#include <thrust/host_vector.h>
@@ -651,3 +651,34 @@ template <class T> auto foo2(T t) {
651651
// CHECK: return dpct::make_constant_iterator(std::make_tuple(t, false));
652652
return thrust::make_constant_iterator(thrust::make_tuple<T, bool>(t, false));
653653
}
654+
655+
struct key_value {
656+
int key;
657+
int value;
658+
};
659+
660+
// CHECK: struct compare_key_value {
661+
// CHECK-NEXT: bool operator()(key_value lhs, key_value rhs) {
662+
// CHECK-NEXT: return lhs.key < rhs.key;
663+
// CHECK-NEXT: }
664+
// CHECK-NEXT: };
665+
struct compare_key_value {
666+
__host__ __device__ bool operator()(key_value lhs, key_value rhs) {
667+
return lhs.key < rhs.key;
668+
}
669+
};
670+
671+
void thrust_max_min() {
672+
key_value a = {13, 0};
673+
key_value b = {7, 1};
674+
675+
// CHECK: key_value smaller = std::min(a, b, compare_key_value());
676+
// CHECK-NEXT: key_value maxer = std::max(a, b, compare_key_value());
677+
key_value smaller = thrust::min(a, b, compare_key_value());
678+
key_value maxer = thrust::max(a, b, compare_key_value());
679+
680+
// CHECK: int min = std::min(1, 2);
681+
// CHECK-NEXT: int max = std::max(1, 2);
682+
int min = thrust::min(1, 2);
683+
int max = thrust::max(1, 2);
684+
}

0 commit comments

Comments
 (0)