Skip to content

Commit c36166f

Browse files
committed
add C API back to _usmarray.pyx
1 parent 206d85d commit c36166f

2 files changed

Lines changed: 334 additions & 62 deletions

File tree

dpnp/backend/include/dpnp4pybind11.hpp

Lines changed: 102 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,37 @@ class dpnp_capi
7777
public:
7878
PyTypeObject *PyUSMArrayType_;
7979

80+
char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
81+
int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
82+
py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
83+
py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
84+
int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
85+
int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
86+
int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
87+
DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
88+
py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
89+
PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
90+
void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
91+
PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
92+
const py::ssize_t *,
93+
int,
94+
Py_MemoryObject *,
95+
py::ssize_t,
96+
char);
97+
PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t,
98+
int,
99+
DPCTLSyclUSMRef,
100+
DPCTLSyclQueueRef,
101+
PyObject *);
102+
PyObject *(*UsmNDArray_MakeFromPtr_)(int,
103+
const py::ssize_t *,
104+
int,
105+
const py::ssize_t *,
106+
DPCTLSyclUSMRef,
107+
DPCTLSyclQueueRef,
108+
py::ssize_t,
109+
PyObject *);
110+
80111
int USM_ARRAY_C_CONTIGUOUS_;
81112
int USM_ARRAY_F_CONTIGUOUS_;
82113
int USM_ARRAY_WRITABLE_;
@@ -119,7 +150,15 @@ class dpnp_capi
119150
std::shared_ptr<py::object> default_usm_ndarray_;
120151

121152
dpnp_capi()
122-
: PyUSMArrayType_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
153+
: PyUSMArrayType_(nullptr), UsmNDArray_GetData_(nullptr),
154+
UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
155+
UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
156+
UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
157+
UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
158+
UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
159+
UsmNDArray_MakeSimpleFromMemory_(nullptr),
160+
UsmNDArray_MakeSimpleFromPtr_(nullptr),
161+
UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
123162
USM_ARRAY_F_CONTIGUOUS_(0), USM_ARRAY_WRITABLE_(0), UAR_BOOL_(-1),
124163
UAR_BYTE_(-1), UAR_UBYTE_(-1), UAR_SHORT_(-1), UAR_USHORT_(-1),
125164
UAR_INT_(-1), UAR_UINT_(-1), UAR_LONG_(-1), UAR_ULONG_(-1),
@@ -135,6 +174,23 @@ class dpnp_capi
135174

136175
this->PyUSMArrayType_ = &PyUSMArrayType;
137176

177+
// dpnp.tensor.usm_ndarray API
178+
this->UsmNDArray_GetData_ = UsmNDArray_GetData;
179+
this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
180+
this->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
181+
this->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
182+
this->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
183+
this->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
184+
this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
185+
this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
186+
this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
187+
this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
188+
this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
189+
this->UsmNDArray_MakeSimpleFromMemory_ =
190+
UsmNDArray_MakeSimpleFromMemory;
191+
this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
192+
this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
193+
138194
// constants
139195
this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
140196
this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
@@ -269,7 +325,9 @@ class usm_ndarray : public py::object
269325
char *get_data() const
270326
{
271327
PyUSMArrayObject *raw_ar = usm_array_ptr();
272-
return raw_ar->data_;
328+
329+
auto const &api = detail::dpnp_capi::get();
330+
return api.UsmNDArray_GetData_(raw_ar);
273331
}
274332

275333
template <typename T>
@@ -281,13 +339,17 @@ class usm_ndarray : public py::object
281339
int get_ndim() const
282340
{
283341
PyUSMArrayObject *raw_ar = usm_array_ptr();
284-
return raw_ar->nd_;
342+
343+
auto const &api = detail::dpnp_capi::get();
344+
return api.UsmNDArray_GetNDim_(raw_ar);
285345
}
286346

287347
const py::ssize_t *get_shape_raw() const
288348
{
289349
PyUSMArrayObject *raw_ar = usm_array_ptr();
290-
return raw_ar->shape_;
350+
351+
auto const &api = detail::dpnp_capi::get();
352+
return api.UsmNDArray_GetShape_(raw_ar);
291353
}
292354

293355
std::vector<py::ssize_t> get_shape_vector() const
@@ -308,7 +370,9 @@ class usm_ndarray : public py::object
308370
const py::ssize_t *get_strides_raw() const
309371
{
310372
PyUSMArrayObject *raw_ar = usm_array_ptr();
311-
return raw_ar->strides_;
373+
374+
auto const &api = detail::dpnp_capi::get();
375+
return api.UsmNDArray_GetStrides_(raw_ar);
312376
}
313377

314378
std::vector<py::ssize_t> get_strides_vector() const
@@ -343,8 +407,9 @@ class usm_ndarray : public py::object
343407
{
344408
PyUSMArrayObject *raw_ar = usm_array_ptr();
345409

346-
int ndim = raw_ar->nd_;
347-
const py::ssize_t *shape = raw_ar->shape_;
410+
auto const &api = detail::dpnp_capi::get();
411+
int ndim = api.UsmNDArray_GetNDim_(raw_ar);
412+
const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
348413

349414
py::ssize_t nelems = 1;
350415
for (int i = 0; i < ndim; ++i) {
@@ -359,9 +424,10 @@ class usm_ndarray : public py::object
359424
{
360425
PyUSMArrayObject *raw_ar = usm_array_ptr();
361426

362-
int nd = raw_ar->nd_;
363-
const py::ssize_t *shape = raw_ar->shape_;
364-
const py::ssize_t *strides = raw_ar->strides_;
427+
auto const &api = detail::dpnp_capi::get();
428+
int nd = api.UsmNDArray_GetNDim_(raw_ar);
429+
const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
430+
const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
365431

366432
py::ssize_t offset_min = 0;
367433
py::ssize_t offset_max = 0;
@@ -389,77 +455,43 @@ class usm_ndarray : public py::object
389455
sycl::queue get_queue() const
390456
{
391457
PyUSMArrayObject *raw_ar = usm_array_ptr();
392-
Py_MemoryObject *mem_obj =
393-
reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
394458

395-
auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
396-
DPCTLSyclQueueRef QRef = dpctl_api.Memory_GetQueueRef_(mem_obj);
459+
auto const &api = detail::dpnp_capi::get();
460+
DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
397461
return *(reinterpret_cast<sycl::queue *>(QRef));
398462
}
399463

400464
sycl::device get_device() const
401465
{
402466
PyUSMArrayObject *raw_ar = usm_array_ptr();
403-
Py_MemoryObject *mem_obj =
404-
reinterpret_cast<Py_MemoryObject *>(raw_ar->base_);
405467

406-
auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
407-
DPCTLSyclQueueRef QRef = dpctl_api.Memory_GetQueueRef_(mem_obj);
468+
auto const &api = detail::dpnp_capi::get();
469+
DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
408470
return reinterpret_cast<sycl::queue *>(QRef)->get_device();
409471
}
410472

411473
int get_typenum() const
412474
{
413475
PyUSMArrayObject *raw_ar = usm_array_ptr();
414-
return raw_ar->typenum_;
476+
477+
auto const &api = detail::dpnp_capi::get();
478+
return api.UsmNDArray_GetTypenum_(raw_ar);
415479
}
416480

417481
int get_flags() const
418482
{
419483
PyUSMArrayObject *raw_ar = usm_array_ptr();
420-
return raw_ar->flags_;
484+
485+
auto const &api = detail::dpnp_capi::get();
486+
return api.UsmNDArray_GetFlags_(raw_ar);
421487
}
422488

423489
int get_elemsize() const
424490
{
425-
int typenum = get_typenum();
426-
auto const &api = detail::dpnp_capi::get();
491+
PyUSMArrayObject *raw_ar = usm_array_ptr();
427492

428-
// Lookup table for element sizes based on typenum
429-
if (typenum == api.UAR_BOOL_)
430-
return 1;
431-
if (typenum == api.UAR_BYTE_)
432-
return 1;
433-
if (typenum == api.UAR_UBYTE_)
434-
return 1;
435-
if (typenum == api.UAR_SHORT_)
436-
return 2;
437-
if (typenum == api.UAR_USHORT_)
438-
return 2;
439-
if (typenum == api.UAR_INT_)
440-
return 4;
441-
if (typenum == api.UAR_UINT_)
442-
return 4;
443-
if (typenum == api.UAR_LONG_)
444-
return sizeof(long);
445-
if (typenum == api.UAR_ULONG_)
446-
return sizeof(unsigned long);
447-
if (typenum == api.UAR_LONGLONG_)
448-
return 8;
449-
if (typenum == api.UAR_ULONGLONG_)
450-
return 8;
451-
if (typenum == api.UAR_FLOAT_)
452-
return 4;
453-
if (typenum == api.UAR_DOUBLE_)
454-
return 8;
455-
if (typenum == api.UAR_CFLOAT_)
456-
return 8;
457-
if (typenum == api.UAR_CDOUBLE_)
458-
return 16;
459-
if (typenum == api.UAR_HALF_)
460-
return 2;
461-
462-
return 0; // Unknown type
493+
auto const &api = detail::dpnp_capi::get();
494+
return api.UsmNDArray_GetElementSize_(raw_ar);
463495
}
464496

465497
bool is_c_contiguous() const
@@ -487,9 +519,10 @@ class usm_ndarray : public py::object
487519
py::object get_usm_data() const
488520
{
489521
PyUSMArrayObject *raw_ar = usm_array_ptr();
522+
523+
auto const &api = detail::dpnp_capi::get();
490524
// base_ is the Memory object - return new reference
491-
PyObject *usm_data = raw_ar->base_;
492-
Py_XINCREF(usm_data);
525+
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
493526

494527
// pass reference ownership to py::object
495528
return py::reinterpret_steal<py::object>(usm_data);
@@ -498,28 +531,34 @@ class usm_ndarray : public py::object
498531
bool is_managed_by_smart_ptr() const
499532
{
500533
PyUSMArrayObject *raw_ar = usm_array_ptr();
501-
PyObject *usm_data = raw_ar->base_;
534+
535+
auto const &api = detail::dpnp_capi::get();
536+
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
502537

503538
auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
504539
if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
540+
Py_DECREF(usm_data);
505541
return false;
506542
}
507543

508544
Py_MemoryObject *mem_obj =
509545
reinterpret_cast<Py_MemoryObject *>(usm_data);
510546
const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
511547

548+
Py_DECREF(usm_data);
512549
return bool(opaque_ptr);
513550
}
514551

515552
const std::shared_ptr<void> &get_smart_ptr_owner() const
516553
{
517554
PyUSMArrayObject *raw_ar = usm_array_ptr();
518-
PyObject *usm_data = raw_ar->base_;
519555

520-
auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
556+
auto const &api = detail::dpnp_capi::get();
557+
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
521558

559+
auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get();
522560
if (!PyObject_TypeCheck(usm_data, dpctl_api.Py_MemoryType_)) {
561+
Py_DECREF(usm_data);
523562
throw std::runtime_error(
524563
"usm_ndarray object does not have Memory object "
525564
"managing lifetime of USM allocation");
@@ -528,6 +567,7 @@ class usm_ndarray : public py::object
528567
Py_MemoryObject *mem_obj =
529568
reinterpret_cast<Py_MemoryObject *>(usm_data);
530569
void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_(mem_obj);
570+
Py_DECREF(usm_data);
531571

532572
if (opaque_ptr) {
533573
auto shptr_ptr =

0 commit comments

Comments
 (0)