@@ -77,6 +77,37 @@ class dpnp_capi
7777public:
7878 PyTypeObject *PyUSMArrayType_;
7979
80+ char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
81+ int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
82+ py::ssize_t *(*UsmNDArray_GetShape_)(PyUSMArrayObject *);
83+ py::ssize_t *(*UsmNDArray_GetStrides_)(PyUSMArrayObject *);
84+ int (*UsmNDArray_GetTypenum_)(PyUSMArrayObject *);
85+ int (*UsmNDArray_GetElementSize_)(PyUSMArrayObject *);
86+ int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
87+ DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
88+ py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
89+ PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
90+ void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int );
91+ PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int ,
92+ const py::ssize_t *,
93+ int ,
94+ Py_MemoryObject *,
95+ py::ssize_t ,
96+ char );
97+ PyObject *(*UsmNDArray_MakeSimpleFromPtr_)(size_t ,
98+ int ,
99+ DPCTLSyclUSMRef,
100+ DPCTLSyclQueueRef,
101+ PyObject *);
102+ PyObject *(*UsmNDArray_MakeFromPtr_)(int ,
103+ const py::ssize_t *,
104+ int ,
105+ const py::ssize_t *,
106+ DPCTLSyclUSMRef,
107+ DPCTLSyclQueueRef,
108+ py::ssize_t ,
109+ PyObject *);
110+
80111 int USM_ARRAY_C_CONTIGUOUS_;
81112 int USM_ARRAY_F_CONTIGUOUS_;
82113 int USM_ARRAY_WRITABLE_;
@@ -119,7 +150,15 @@ class dpnp_capi
119150 std::shared_ptr<py::object> default_usm_ndarray_;
120151
121152 dpnp_capi ()
122- : PyUSMArrayType_(nullptr ), USM_ARRAY_C_CONTIGUOUS_(0 ),
153+ : PyUSMArrayType_(nullptr ), UsmNDArray_GetData_(nullptr ),
154+ UsmNDArray_GetNDim_ (nullptr ), UsmNDArray_GetShape_(nullptr ),
155+ UsmNDArray_GetStrides_(nullptr ), UsmNDArray_GetTypenum_(nullptr ),
156+ UsmNDArray_GetElementSize_(nullptr ), UsmNDArray_GetFlags_(nullptr ),
157+ UsmNDArray_GetQueueRef_(nullptr ), UsmNDArray_GetOffset_(nullptr ),
158+ UsmNDArray_GetUSMData_(nullptr ), UsmNDArray_SetWritableFlag_(nullptr ),
159+ UsmNDArray_MakeSimpleFromMemory_(nullptr ),
160+ UsmNDArray_MakeSimpleFromPtr_(nullptr ),
161+ UsmNDArray_MakeFromPtr_(nullptr ), USM_ARRAY_C_CONTIGUOUS_(0 ),
123162 USM_ARRAY_F_CONTIGUOUS_(0 ), USM_ARRAY_WRITABLE_(0 ), UAR_BOOL_(-1 ),
124163 UAR_BYTE_(-1 ), UAR_UBYTE_(-1 ), UAR_SHORT_(-1 ), UAR_USHORT_(-1 ),
125164 UAR_INT_(-1 ), UAR_UINT_(-1 ), UAR_LONG_(-1 ), UAR_ULONG_(-1 ),
@@ -135,6 +174,23 @@ class dpnp_capi
135174
136175 this ->PyUSMArrayType_ = &PyUSMArrayType;
137176
177+ // dpnp.tensor.usm_ndarray API
178+ this ->UsmNDArray_GetData_ = UsmNDArray_GetData;
179+ this ->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
180+ this ->UsmNDArray_GetShape_ = UsmNDArray_GetShape;
181+ this ->UsmNDArray_GetStrides_ = UsmNDArray_GetStrides;
182+ this ->UsmNDArray_GetTypenum_ = UsmNDArray_GetTypenum;
183+ this ->UsmNDArray_GetElementSize_ = UsmNDArray_GetElementSize;
184+ this ->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
185+ this ->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
186+ this ->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
187+ this ->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
188+ this ->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
189+ this ->UsmNDArray_MakeSimpleFromMemory_ =
190+ UsmNDArray_MakeSimpleFromMemory;
191+ this ->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr;
192+ this ->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr;
193+
138194 // constants
139195 this ->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS;
140196 this ->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS;
@@ -269,7 +325,9 @@ class usm_ndarray : public py::object
269325 char *get_data () const
270326 {
271327 PyUSMArrayObject *raw_ar = usm_array_ptr ();
272- return raw_ar->data_ ;
328+
329+ auto const &api = detail::dpnp_capi::get ();
330+ return api.UsmNDArray_GetData_ (raw_ar);
273331 }
274332
275333 template <typename T>
@@ -281,13 +339,17 @@ class usm_ndarray : public py::object
281339 int get_ndim () const
282340 {
283341 PyUSMArrayObject *raw_ar = usm_array_ptr ();
284- return raw_ar->nd_ ;
342+
343+ auto const &api = detail::dpnp_capi::get ();
344+ return api.UsmNDArray_GetNDim_ (raw_ar);
285345 }
286346
287347 const py::ssize_t *get_shape_raw () const
288348 {
289349 PyUSMArrayObject *raw_ar = usm_array_ptr ();
290- return raw_ar->shape_ ;
350+
351+ auto const &api = detail::dpnp_capi::get ();
352+ return api.UsmNDArray_GetShape_ (raw_ar);
291353 }
292354
293355 std::vector<py::ssize_t > get_shape_vector () const
@@ -308,7 +370,9 @@ class usm_ndarray : public py::object
308370 const py::ssize_t *get_strides_raw () const
309371 {
310372 PyUSMArrayObject *raw_ar = usm_array_ptr ();
311- return raw_ar->strides_ ;
373+
374+ auto const &api = detail::dpnp_capi::get ();
375+ return api.UsmNDArray_GetStrides_ (raw_ar);
312376 }
313377
314378 std::vector<py::ssize_t > get_strides_vector () const
@@ -343,8 +407,9 @@ class usm_ndarray : public py::object
343407 {
344408 PyUSMArrayObject *raw_ar = usm_array_ptr ();
345409
346- int ndim = raw_ar->nd_ ;
347- const py::ssize_t *shape = raw_ar->shape_ ;
410+ auto const &api = detail::dpnp_capi::get ();
411+ int ndim = api.UsmNDArray_GetNDim_ (raw_ar);
412+ const py::ssize_t *shape = api.UsmNDArray_GetShape_ (raw_ar);
348413
349414 py::ssize_t nelems = 1 ;
350415 for (int i = 0 ; i < ndim; ++i) {
@@ -359,9 +424,10 @@ class usm_ndarray : public py::object
359424 {
360425 PyUSMArrayObject *raw_ar = usm_array_ptr ();
361426
362- int nd = raw_ar->nd_ ;
363- const py::ssize_t *shape = raw_ar->shape_ ;
364- const py::ssize_t *strides = raw_ar->strides_ ;
427+ auto const &api = detail::dpnp_capi::get ();
428+ int nd = api.UsmNDArray_GetNDim_ (raw_ar);
429+ const py::ssize_t *shape = api.UsmNDArray_GetShape_ (raw_ar);
430+ const py::ssize_t *strides = api.UsmNDArray_GetStrides_ (raw_ar);
365431
366432 py::ssize_t offset_min = 0 ;
367433 py::ssize_t offset_max = 0 ;
@@ -389,77 +455,43 @@ class usm_ndarray : public py::object
389455 sycl::queue get_queue () const
390456 {
391457 PyUSMArrayObject *raw_ar = usm_array_ptr ();
392- Py_MemoryObject *mem_obj =
393- reinterpret_cast <Py_MemoryObject *>(raw_ar->base_ );
394458
395- auto const &dpctl_api = :: dpctl:: detail::dpctl_capi ::get ();
396- DPCTLSyclQueueRef QRef = dpctl_api. Memory_GetQueueRef_ (mem_obj );
459+ auto const &api = detail::dpnp_capi ::get ();
460+ DPCTLSyclQueueRef QRef = api. UsmNDArray_GetQueueRef_ (raw_ar );
397461 return *(reinterpret_cast <sycl::queue *>(QRef));
398462 }
399463
400464 sycl::device get_device () const
401465 {
402466 PyUSMArrayObject *raw_ar = usm_array_ptr ();
403- Py_MemoryObject *mem_obj =
404- reinterpret_cast <Py_MemoryObject *>(raw_ar->base_ );
405467
406- auto const &dpctl_api = :: dpctl:: detail::dpctl_capi ::get ();
407- DPCTLSyclQueueRef QRef = dpctl_api. Memory_GetQueueRef_ (mem_obj );
468+ auto const &api = detail::dpnp_capi ::get ();
469+ DPCTLSyclQueueRef QRef = api. UsmNDArray_GetQueueRef_ (raw_ar );
408470 return reinterpret_cast <sycl::queue *>(QRef)->get_device ();
409471 }
410472
411473 int get_typenum () const
412474 {
413475 PyUSMArrayObject *raw_ar = usm_array_ptr ();
414- return raw_ar->typenum_ ;
476+
477+ auto const &api = detail::dpnp_capi::get ();
478+ return api.UsmNDArray_GetTypenum_ (raw_ar);
415479 }
416480
417481 int get_flags () const
418482 {
419483 PyUSMArrayObject *raw_ar = usm_array_ptr ();
420- return raw_ar->flags_ ;
484+
485+ auto const &api = detail::dpnp_capi::get ();
486+ return api.UsmNDArray_GetFlags_ (raw_ar);
421487 }
422488
423489 int get_elemsize () const
424490 {
425- int typenum = get_typenum ();
426- auto const &api = detail::dpnp_capi::get ();
491+ PyUSMArrayObject *raw_ar = usm_array_ptr ();
427492
428- // Lookup table for element sizes based on typenum
429- if (typenum == api.UAR_BOOL_ )
430- return 1 ;
431- if (typenum == api.UAR_BYTE_ )
432- return 1 ;
433- if (typenum == api.UAR_UBYTE_ )
434- return 1 ;
435- if (typenum == api.UAR_SHORT_ )
436- return 2 ;
437- if (typenum == api.UAR_USHORT_ )
438- return 2 ;
439- if (typenum == api.UAR_INT_ )
440- return 4 ;
441- if (typenum == api.UAR_UINT_ )
442- return 4 ;
443- if (typenum == api.UAR_LONG_ )
444- return sizeof (long );
445- if (typenum == api.UAR_ULONG_ )
446- return sizeof (unsigned long );
447- if (typenum == api.UAR_LONGLONG_ )
448- return 8 ;
449- if (typenum == api.UAR_ULONGLONG_ )
450- return 8 ;
451- if (typenum == api.UAR_FLOAT_ )
452- return 4 ;
453- if (typenum == api.UAR_DOUBLE_ )
454- return 8 ;
455- if (typenum == api.UAR_CFLOAT_ )
456- return 8 ;
457- if (typenum == api.UAR_CDOUBLE_ )
458- return 16 ;
459- if (typenum == api.UAR_HALF_ )
460- return 2 ;
461-
462- return 0 ; // Unknown type
493+ auto const &api = detail::dpnp_capi::get ();
494+ return api.UsmNDArray_GetElementSize_ (raw_ar);
463495 }
464496
465497 bool is_c_contiguous () const
@@ -487,9 +519,10 @@ class usm_ndarray : public py::object
487519 py::object get_usm_data () const
488520 {
489521 PyUSMArrayObject *raw_ar = usm_array_ptr ();
522+
523+ auto const &api = detail::dpnp_capi::get ();
490524 // base_ is the Memory object - return new reference
491- PyObject *usm_data = raw_ar->base_ ;
492- Py_XINCREF (usm_data);
525+ PyObject *usm_data = api.UsmNDArray_GetUSMData_ (raw_ar);
493526
494527 // pass reference ownership to py::object
495528 return py::reinterpret_steal<py::object>(usm_data);
@@ -498,28 +531,34 @@ class usm_ndarray : public py::object
498531 bool is_managed_by_smart_ptr () const
499532 {
500533 PyUSMArrayObject *raw_ar = usm_array_ptr ();
501- PyObject *usm_data = raw_ar->base_ ;
534+
535+ auto const &api = detail::dpnp_capi::get ();
536+ PyObject *usm_data = api.UsmNDArray_GetUSMData_ (raw_ar);
502537
503538 auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get ();
504539 if (!PyObject_TypeCheck (usm_data, dpctl_api.Py_MemoryType_ )) {
540+ Py_DECREF (usm_data);
505541 return false ;
506542 }
507543
508544 Py_MemoryObject *mem_obj =
509545 reinterpret_cast <Py_MemoryObject *>(usm_data);
510546 const void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_ (mem_obj);
511547
548+ Py_DECREF (usm_data);
512549 return bool (opaque_ptr);
513550 }
514551
515552 const std::shared_ptr<void > &get_smart_ptr_owner () const
516553 {
517554 PyUSMArrayObject *raw_ar = usm_array_ptr ();
518- PyObject *usm_data = raw_ar->base_ ;
519555
520- auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get ();
556+ auto const &api = detail::dpnp_capi::get ();
557+ PyObject *usm_data = api.UsmNDArray_GetUSMData_ (raw_ar);
521558
559+ auto const &dpctl_api = ::dpctl::detail::dpctl_capi::get ();
522560 if (!PyObject_TypeCheck (usm_data, dpctl_api.Py_MemoryType_ )) {
561+ Py_DECREF (usm_data);
523562 throw std::runtime_error (
524563 " usm_ndarray object does not have Memory object "
525564 " managing lifetime of USM allocation" );
@@ -528,6 +567,7 @@ class usm_ndarray : public py::object
528567 Py_MemoryObject *mem_obj =
529568 reinterpret_cast <Py_MemoryObject *>(usm_data);
530569 void *opaque_ptr = dpctl_api.Memory_GetOpaquePointer_ (mem_obj);
570+ Py_DECREF (usm_data);
531571
532572 if (opaque_ptr) {
533573 auto shptr_ptr =
0 commit comments