-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy path sycl_tests.cc
More file actions
63 lines (51 loc) · 2.23 KB
/
sycl_tests.cc
File metadata and controls
63 lines (51 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include "test_utils.h"
#include <catch2/catch_template_test_macros.hpp>
#include <catch2/generators/catch_generators_all.hpp>
using namespace celerity;
using namespace celerity::detail;
// If this test fails, celerity can't reliably support reductions on the user's combination of backend and hardware
TEST_CASE_METHOD(test_utils::sycl_queue_fixture, "SYCL has working simple scalar reductions", "[sycl][reductions]") {
	const size_t N = GENERATE(64, 512, 1024, 4096);
	CAPTURE(N);

	// We use an ND-range with an explicit work-group size because DPC++ e330855 (May 7, 2024) on CUDA will run out
	// of registers for the default WG size. Every generated N above is a multiple of this value, as nd_range requires.
	constexpr size_t work_group_size = 64;

	const auto buf = sycl::malloc_host<int>(1, get_sycl_queue());
	// malloc_host signals allocation failure by returning nullptr; bail out of this test case instead of dereferencing it.
	REQUIRE(buf != nullptr);
	*buf = 99; // SYCL reduction must overwrite this, not include it in the reduction result

	get_sycl_queue()
	    .submit([&](sycl::handler& cgh) {
		    cgh.parallel_for(sycl::nd_range<1>{N, work_group_size},
		        sycl::reduction(buf, sycl::plus<int>{}, sycl::property::reduction::initialize_to_identity{}), [](auto, auto& r) { r.combine(1); });
	    })
	    .wait();

	// Each of the N work-items combines 1 into the reduction, so the result is exactly N iff the initial 99 was discarded.
	CHECK(static_cast<size_t>(*buf) == N);
	sycl::free(buf, get_sycl_queue());
}
TEST_CASE("SYCL implements by-value equality-comparison of device information", "[sycl][device-selection][!mayfail]") {
	// Enumerates all SYCL devices and sorts them by vendor id, breaking ties by device name, so that two separate
	// enumerations can be compared element-wise.
	constexpr static auto get_devices = [] {
		auto devices = sycl::device::get_devices();
		const auto by_vendor_then_name = [](const sycl::device& a, const sycl::device& b) {
			const auto a_vendor = a.get_info<sycl::info::device::vendor_id>();
			const auto b_vendor = b.get_info<sycl::info::device::vendor_id>();
			if(a_vendor != b_vendor) { return a_vendor < b_vendor; }
			return a.get_info<sycl::info::device::name>() < b.get_info<sycl::info::device::name>();
		};
		std::sort(devices.begin(), devices.end(), by_vendor_then_name);
		return devices;
	};

	// Maps each device from a fresh get_devices() call to its platform, preserving order.
	// There is one entry per device, so a platform appears once for every device it hosts.
	constexpr static auto get_platforms = [] {
		std::vector<sycl::platform> platforms;
		for(const auto& device : get_devices()) {
			platforms.push_back(device.get_platform());
		}
		return platforms;
	};

	// Two independent queries must yield equal objects if the implementation compares them by value.
	SECTION("for sycl::device") {
		const auto first = get_devices();
		const auto second = get_devices();
		CHECK(first == second);
	}

	SECTION("for sycl::platforms") {
		const auto first = get_platforms();
		const auto second = get_platforms();
		CHECK(first == second);
	}
}