Skip to content

Commit 29d6c02

Browse files
Add device_support_queries to enable default device types
1 parent 7949c17 commit 29d6c02

4 files changed

Lines changed: 271 additions & 29 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ set(_tensor_impl_sources
5656
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/zeros_ctor.cpp
5757
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
5858
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
59-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
59+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
6060
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
6161
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
6262
)
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===--------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
33+
//===--------------------------------------------------------------------===//
34+
35+
#include <string>
36+
37+
#include "dpnp4pybind11.hpp"
38+
#include <pybind11/pybind11.h>
39+
#include <pybind11/stl.h>
40+
#include <sycl/sycl.hpp>
41+
42+
namespace dpctl
43+
{
44+
namespace tensor
45+
{
46+
namespace py_internal
47+
{
48+
49+
namespace
50+
{
51+
52+
std::string _default_device_fp_type(const sycl::device &d)
53+
{
54+
if (d.has(sycl::aspect::fp64)) {
55+
return "f8";
56+
}
57+
else {
58+
return "f4";
59+
}
60+
}
61+
62+
int get_numpy_major_version()
63+
{
64+
namespace py = pybind11;
65+
66+
py::module_ numpy = py::module_::import("numpy");
67+
py::str version_string = numpy.attr("__version__");
68+
py::module_ numpy_lib = py::module_::import("numpy.lib");
69+
70+
py::object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
71+
int major_version = numpy_version.attr("major").cast<int>();
72+
73+
return major_version;
74+
}
75+
76+
std::string _default_device_int_type(const sycl::device &)
77+
{
78+
const int np_ver = get_numpy_major_version();
79+
80+
if (np_ver >= 2) {
81+
return "i8";
82+
}
83+
else {
84+
// code for numpy.dtype('long') to be consistent
85+
// with NumPy's default integer type across
86+
// platforms.
87+
return "l";
88+
}
89+
}
90+
91+
std::string _default_device_uint_type(const sycl::device &)
92+
{
93+
const int np_ver = get_numpy_major_version();
94+
95+
if (np_ver >= 2) {
96+
return "u8";
97+
}
98+
else {
99+
// code for numpy.dtype('long') to be consistent
100+
// with NumPy's default integer type across
101+
// platforms.
102+
return "L";
103+
}
104+
}
105+
106+
std::string _default_device_complex_type(const sycl::device &d)
107+
{
108+
if (d.has(sycl::aspect::fp64)) {
109+
return "c16";
110+
}
111+
else {
112+
return "c8";
113+
}
114+
}
115+
116+
std::string _default_device_bool_type(const sycl::device &)
117+
{
118+
return "b1";
119+
}
120+
121+
std::string _default_device_index_type(const sycl::device &)
122+
{
123+
return "i8";
124+
}
125+
126+
sycl::device _extract_device(const py::object &arg)
127+
{
128+
auto const &api = dpctl::detail::dpctl_capi::get();
129+
130+
PyObject *source = arg.ptr();
131+
if (api.PySyclQueue_Check_(source)) {
132+
const sycl::queue &q = py::cast<sycl::queue>(arg);
133+
return q.get_device();
134+
}
135+
else if (api.PySyclDevice_Check_(source)) {
136+
return py::cast<sycl::device>(arg);
137+
}
138+
else {
139+
throw py::type_error(
140+
"Expected type `dpctl.SyclQueue` or `dpctl.SyclDevice`.");
141+
}
142+
}
143+
144+
} // namespace
145+
146+
std::string default_device_fp_type(const py::object &arg)
147+
{
148+
const sycl::device &d = _extract_device(arg);
149+
return _default_device_fp_type(d);
150+
}
151+
152+
std::string default_device_int_type(const py::object &arg)
153+
{
154+
const sycl::device &d = _extract_device(arg);
155+
return _default_device_int_type(d);
156+
}
157+
158+
std::string default_device_uint_type(const py::object &arg)
159+
{
160+
const sycl::device &d = _extract_device(arg);
161+
return _default_device_uint_type(d);
162+
}
163+
164+
std::string default_device_bool_type(const py::object &arg)
165+
{
166+
const sycl::device &d = _extract_device(arg);
167+
return _default_device_bool_type(d);
168+
}
169+
170+
std::string default_device_complex_type(const py::object &arg)
171+
{
172+
const sycl::device &d = _extract_device(arg);
173+
return _default_device_complex_type(d);
174+
}
175+
176+
std::string default_device_index_type(const py::object &arg)
177+
{
178+
const sycl::device &d = _extract_device(arg);
179+
return _default_device_index_type(d);
180+
}
181+
182+
} // namespace py_internal
183+
} // namespace tensor
184+
} // namespace dpctl
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//===--------------------------------------------------------------------===//
29+
///
30+
/// \file
31+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
32+
//===--------------------------------------------------------------------===//
33+
34+
#pragma once
35+
#include <string>
36+
37+
#include "dpnp4pybind11.hpp"
38+
#include <pybind11/pybind11.h>
39+
#include <pybind11/stl.h>
40+
#include <sycl/sycl.hpp>
41+
42+
namespace dpctl
43+
{
44+
namespace tensor
45+
{
46+
namespace py_internal
47+
{
48+
49+
extern std::string default_device_fp_type(const py::object &);
50+
extern std::string default_device_int_type(const py::object &);
51+
extern std::string default_device_uint_type(const py::object &);
52+
extern std::string default_device_bool_type(const py::object &);
53+
extern std::string default_device_complex_type(const py::object &);
54+
extern std::string default_device_index_type(const py::object &);
55+
56+
} // namespace py_internal
57+
} // namespace tensor
58+
} // namespace dpctl

dpctl_ext/tensor/libtensor/source/tensor_ctors.cpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
// #include "copy_for_reshape.hpp"
5353
// #include "copy_for_roll.hpp"
5454
// #include "copy_numpy_ndarray_into_usm_ndarray.hpp"
55-
// #include "device_support_queries.hpp"
55+
#include "device_support_queries.hpp"
5656
// #include "eye_ctor.hpp"
5757
// #include "full_ctor.hpp"
5858
// #include "integer_advanced_indexing.hpp"
@@ -360,33 +360,33 @@ PYBIND11_MODULE(_tensor_impl, m)
360360
// py::arg("k"), py::arg("dst"), py::arg("sycl_queue"),
361361
// py::arg("depends") = py::list());
362362

363-
// m.def("default_device_fp_type",
364-
// dpctl::tensor::py_internal::default_device_fp_type,
365-
// "Gives default floating point type supported by device.",
366-
// py::arg("dev"));
367-
368-
// m.def("default_device_int_type",
369-
// dpctl::tensor::py_internal::default_device_int_type,
370-
// "Gives default signed integer type supported by device.",
371-
// py::arg("dev"));
372-
373-
// m.def("default_device_uint_type",
374-
// dpctl::tensor::py_internal::default_device_uint_type,
375-
// "Gives default unsigned integer type supported by device.",
376-
// py::arg("dev"));
377-
378-
// m.def("default_device_bool_type",
379-
// dpctl::tensor::py_internal::default_device_bool_type,
380-
// "Gives default boolean type supported by device.", py::arg("dev"));
381-
382-
// m.def("default_device_complex_type",
383-
// dpctl::tensor::py_internal::default_device_complex_type,
384-
// "Gives default complex floating point type supported by device.",
385-
// py::arg("dev"));
386-
387-
// m.def("default_device_index_type",
388-
// dpctl::tensor::py_internal::default_device_index_type,
389-
// "Gives default index type supported by device.", py::arg("dev"));
363+
m.def("default_device_fp_type",
364+
dpctl::tensor::py_internal::default_device_fp_type,
365+
"Gives default floating point type supported by device.",
366+
py::arg("dev"));
367+
368+
m.def("default_device_int_type",
369+
dpctl::tensor::py_internal::default_device_int_type,
370+
"Gives default signed integer type supported by device.",
371+
py::arg("dev"));
372+
373+
m.def("default_device_uint_type",
374+
dpctl::tensor::py_internal::default_device_uint_type,
375+
"Gives default unsigned integer type supported by device.",
376+
py::arg("dev"));
377+
378+
m.def("default_device_bool_type",
379+
dpctl::tensor::py_internal::default_device_bool_type,
380+
"Gives default boolean type supported by device.", py::arg("dev"));
381+
382+
m.def("default_device_complex_type",
383+
dpctl::tensor::py_internal::default_device_complex_type,
384+
"Gives default complex floating point type supported by device.",
385+
py::arg("dev"));
386+
387+
m.def("default_device_index_type",
388+
dpctl::tensor::py_internal::default_device_index_type,
389+
"Gives default index type supported by device.", py::arg("dev"));
390390

391391
// auto tril_fn = [](const dpctl::tensor::usm_ndarray &src,
392392
// const dpctl::tensor::usm_ndarray &dst, py::ssize_t k,

0 commit comments

Comments
 (0)