-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy path_dlpack.pyx
More file actions
1232 lines (1095 loc) · 43.7 KB
/
_dlpack.pyx
File metadata and controls
1232 lines (1095 loc) · 43.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# distutils: language = c++
# cython: language_level=3
# cython: linetrace=True
cdef extern from "numpy/npy_no_deprecated_api.h":
pass
cimport cpython
from libc cimport stdlib
from libc.stdint cimport int64_t, uint8_t, uint16_t, uint32_t, uint64_t
from numpy cimport ndarray
cimport dpctl as c_dpctl
cimport dpctl.memory as c_dpmem
from dpctl._sycl_queue_manager cimport get_device_cached_queue
from .._backend cimport (
DPCTLDevice_Delete,
DPCTLDevice_GetParentDevice,
DPCTLSyclDeviceRef,
DPCTLSyclUSMRef,
)
from ._usmarray cimport (
USM_ARRAY_C_CONTIGUOUS,
USM_ARRAY_F_CONTIGUOUS,
USM_ARRAY_WRITABLE,
usm_ndarray,
)
import ctypes
import numpy as np
import dpctl
import dpctl.memory as dpmem
from ._device import Device
# Declarations of the DLPack C ABI used by this module. Cython only needs the
# names here; the actual values of the enumerators, flag bitmasks and version
# macros come from "dlpack/dlpack.h" when the extension is compiled.
cdef extern from "dlpack/dlpack.h" nogil:
    # DLPack version the extension is compiled against
    cdef int DLPACK_MAJOR_VERSION
    cdef int DLPACK_MINOR_VERSION
    # bits of DLManagedTensorVersioned.flags
    cdef int DLPACK_FLAG_BITMASK_READ_ONLY
    cdef int DLPACK_FLAG_BITMASK_IS_COPIED

    ctypedef struct DLPackVersion:
        uint32_t major
        uint32_t minor

    # Kind of device a DLTensor's memory lives on; this module exports
    # kDLOneAPI (usm_ndarray) and kDLCPU (numpy.ndarray) tensors and
    # imports the same two kinds.
    cdef enum DLDeviceType:
        kDLCPU
        kDLCUDA
        kDLCUDAHost
        kDLCUDAManaged
        kDLROCM
        kDLROCMHost
        kDLOpenCL
        kDLVulkan
        kDLMetal
        kDLVPI
        kDLOneAPI
        kDLWebGPU
        kDLHexagon
        kDLMAIA
        kDLTrn

    ctypedef struct DLDevice:
        DLDeviceType device_type
        int device_id

    # Element type category; together with DLDataType.bits it determines the
    # element type. Only Int/UInt/Float/Complex/Bool are handled here.
    cdef enum DLDataTypeCode:
        kDLInt
        kDLUInt
        kDLFloat
        kDLBfloat
        kDLComplex
        kDLBool
        kDLFloat8_e3m4
        kDLFloat8_e4m3
        kDLFloat8_e4m3b11fnuz
        kDLFloat8_e4m3fn
        kDLFloat8_e4m3fnuz
        kDLFloat8_e5m2
        kDLFloat8_e5m2fnuz
        kDLFloat8_e8m0fnu
        kDLFloat6_e2m3fn
        kDLFloat6_e3m2fn
        kDLFloat4_e2m1fn

    ctypedef struct DLDataType:
        uint8_t code
        uint8_t bits
        uint16_t lanes

    ctypedef struct DLTensor:
        void *data
        DLDevice device
        int ndim
        DLDataType dtype
        int64_t *shape
        # strides are counted in elements, not bytes
        int64_t *strides
        uint64_t byte_offset

    # Unversioned managed tensor, carried in "dltensor" capsules
    ctypedef struct DLManagedTensor:
        DLTensor dl_tensor
        void *manager_ctx
        void (*deleter)(DLManagedTensor *)  # noqa: E211

    # Versioned managed tensor, carried in "dltensor_versioned" capsules
    ctypedef struct DLManagedTensorVersioned:
        DLPackVersion version
        void *manager_ctx
        void (*deleter)(DLManagedTensorVersioned *)  # noqa: E211
        uint64_t flags
        DLTensor dl_tensor
def get_build_dlpack_version():
    """
    Report the DLPack ``(major, minor)`` version that :mod:`dpctl.tensor`
    was compiled against.

    The returned tuple may be passed as the ``max_version`` argument of
    ``__dlpack__`` to guarantee that :mod:`dpctl.tensor` is able to
    consume the resulting capsule.

    Returns:
        Tuple[int, int]
            The ``major`` and ``minor`` DLPack version numbers used when
            building :mod:`dpctl.tensor`.
    """
    major = DLPACK_MAJOR_VERSION
    minor = DLPACK_MINOR_VERSION
    return (major, minor)
cdef void _pycapsule_deleter(object dlt_capsule) noexcept:
    """PyCapsule destructor for ``"dltensor"`` capsules made by this module.

    A capsule that has been consumed was renamed (``from_dlpack_capsule``
    renames it to ``"used_dltensor"``), so it fails the validity check and
    is left alone; otherwise the wrapped ``DLManagedTensor`` is released
    through its own deleter.
    """
    cdef DLManagedTensor *managed = NULL
    if not cpython.PyCapsule_IsValid(dlt_capsule, "dltensor"):
        return
    managed = <DLManagedTensor*>cpython.PyCapsule_GetPointer(
        dlt_capsule, "dltensor"
    )
    managed.deleter(managed)
cdef void _managed_tensor_deleter(
    DLManagedTensor *dlm_tensor
) noexcept with gil:
    """Deleter installed on every ``DLManagedTensor`` this module exports.

    Frees the host buffer holding shape and strides, drops the reference
    held on the Python object stored in ``manager_ctx``, then frees the
    struct itself.
    """
    if dlm_tensor is NULL:
        return
    # shape and strides share one allocation that starts at ``shape``, so a
    # single free() releases both (``shape`` is NULL for 0-d tensors, and
    # free(NULL) is a no-op)
    stdlib.free(dlm_tensor.dl_tensor.shape)
    cpython.Py_DECREF(<object>dlm_tensor.manager_ctx)
    dlm_tensor.manager_ctx = NULL
    stdlib.free(dlm_tensor)
cdef void _pycapsule_versioned_deleter(object dlt_capsule) noexcept:
    """PyCapsule destructor for ``"dltensor_versioned"`` capsules.

    Consumed capsules have been renamed and are skipped; otherwise the
    wrapped ``DLManagedTensorVersioned`` is released via its own deleter.
    """
    cdef DLManagedTensorVersioned *managed = NULL
    if not cpython.PyCapsule_IsValid(dlt_capsule, "dltensor_versioned"):
        return
    managed = <DLManagedTensorVersioned*>cpython.PyCapsule_GetPointer(
        dlt_capsule, "dltensor_versioned"
    )
    managed.deleter(managed)
cdef void _managed_tensor_versioned_deleter(
    DLManagedTensorVersioned *dlmv_tensor
) noexcept with gil:
    """Deleter installed on every ``DLManagedTensorVersioned`` exported here.

    Frees the shared shape/strides buffer, releases the reference held on
    the Python object in ``manager_ctx`` and frees the struct.
    """
    if dlmv_tensor is NULL:
        return
    # one allocation backs both shape and strides; it starts at ``shape``
    stdlib.free(dlmv_tensor.dl_tensor.shape)
    cpython.Py_DECREF(<object>dlmv_tensor.manager_ctx)
    dlmv_tensor.manager_ctx = NULL
    stdlib.free(dlmv_tensor)
cdef object _get_default_context(c_dpctl.SyclDevice dev):
    """Return the default SYCL context of ``dev``'s platform, or ``None``
    when querying it raises ``RuntimeError``."""
    try:
        return dev.sycl_platform.default_context
    except RuntimeError:
        # RT does not support default_context
        return None
cdef int get_array_dlpack_device_id(
    usm_ndarray usm_ary
) except -1:
    """Return the DLPack ``device_id`` of the device ``usm_ary`` lives on.

    The id is the value of ``SyclDevice.get_device_id()`` for the array's
    queue device. Export is refused with ``DLPackCreationError`` when the
    array's context differs from the platform default context, or, on
    platforms whose runtime cannot provide a default context, when the
    device has a parent (i.e. is a sub-device).
    """
    cdef c_dpctl.SyclQueue exec_q
    cdef c_dpctl.SyclDevice exec_dev
    cdef DPCTLSyclDeviceRef parent_ref = NULL
    cdef int dev_id = -1

    exec_q = usm_ary.get_sycl_queue()
    exec_dev = exec_q.get_sycl_device()
    dflt_ctx = _get_default_context(exec_dev)
    if dflt_ctx is not None:
        if not usm_ary.sycl_context == dflt_ctx:
            raise DLPackCreationError(
                "to_dlpack_capsule: DLPack can only export arrays based on USM "
                "allocations bound to a default platform SYCL context"
            )
    else:
        # no default context to compare against: only root devices qualify
        parent_ref = DPCTLDevice_GetParentDevice(exec_dev.get_device_ref())
        if parent_ref is not NULL:
            DPCTLDevice_Delete(parent_ref)
            raise DLPackCreationError(
                "to_dlpack_capsule: DLPack can only export arrays allocated "
                "on non-partitioned SYCL devices on platforms where "
                "default_context oneAPI extension is not supported."
            )

    dev_id = exec_dev.get_device_id()
    if dev_id < 0:
        raise DLPackCreationError(
            "get_array_dlpack_device_id: failed to determine device_id"
        )
    return dev_id
cpdef to_dlpack_capsule(usm_ndarray usm_ary):
    """
    to_dlpack_capsule(usm_ary)

    Constructs named Python capsule object referencing
    instance of ``DLManagedTensor`` from
    :class:`dpctl.tensor.usm_ndarray` instance.

    Args:
        usm_ary: An instance of :class:`dpctl.tensor.usm_ndarray`
    Returns:
        A new capsule with name ``"dltensor"`` that contains
        a pointer to ``DLManagedTensor`` struct.
    Raises:
        DLPackCreationError: when array cannot be represented as
            DLPack tensor. This may happen when array was allocated
            on a partitioned sycl device, or its USM allocation is
            not bound to the platform default SYCL context.
        MemoryError: when host allocation needed for ``DLManagedTensor``
            did not succeed.
        BufferError: when the array carries no explicit strides and is
            neither C- nor F-contiguous, so its strides can not be
            reconstructed.
        ValueError: when array elements data type could not be represented
            in ``DLManagedTensor``.
    """
    cdef DLManagedTensor *dlm_tensor = NULL
    cdef DLTensor *dl_tensor = NULL
    cdef int nd = usm_ary.get_ndim()
    cdef char *data_ptr = usm_ary.get_data()
    cdef Py_ssize_t *shape_ptr = NULL
    cdef Py_ssize_t *strides_ptr = NULL
    cdef int64_t *shape_strides_ptr = NULL
    cdef int i = 0
    cdef int device_id = -1
    cdef int flags = usm_ary.flags_
    cdef Py_ssize_t element_offset = 0
    cdef Py_ssize_t byte_offset = 0
    cdef Py_ssize_t si = 1

    ary_base = usm_ary.get_base()

    device_id = get_array_dlpack_device_id(usm_ary)

    dlm_tensor = <DLManagedTensor *> stdlib.malloc(
        sizeof(DLManagedTensor))
    if dlm_tensor is NULL:
        raise MemoryError(
            "to_dlpack_capsule: Could not allocate memory for DLManagedTensor"
        )
    if nd > 0:
        # a single allocation holds the shape (first nd entries) followed by
        # the strides (next nd entries); the deleter frees it via ``shape``
        shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
        if shape_strides_ptr is NULL:
            stdlib.free(dlm_tensor)
            raise MemoryError(
                "to_dlpack_capsule: Could not allocate memory for shape/strides"
            )
        shape_ptr = usm_ary.get_shape()
        for i in range(nd):
            shape_strides_ptr[i] = shape_ptr[i]
        strides_ptr = usm_ary.get_strides()
        if strides_ptr:
            for i in range(nd):
                shape_strides_ptr[nd + i] = strides_ptr[i]
        else:
            # no explicit strides: rebuild them from the contiguity flags
            if flags & USM_ARRAY_C_CONTIGUOUS:
                si = 1
                for i in range(nd - 1, -1, -1):
                    shape_strides_ptr[nd + i] = si
                    si = si * shape_ptr[i]
            elif flags & USM_ARRAY_F_CONTIGUOUS:
                si = 1
                for i in range(0, nd):
                    shape_strides_ptr[nd + i] = si
                    si = si * shape_ptr[i]
            else:
                stdlib.free(shape_strides_ptr)
                stdlib.free(dlm_tensor)
                raise BufferError(
                    "to_dlpack_capsule: Could not reconstruct strides "
                    "for non-contiguous tensor"
                )

    ary_dt = usm_ary.dtype
    ary_dtk = ary_dt.kind
    element_offset = usm_ary.get_offset()
    byte_offset = element_offset * (<Py_ssize_t>ary_dt.itemsize)

    dl_tensor = &dlm_tensor.dl_tensor
    # ``get_data()`` is taken to point at the first element; stepping back by
    # ``byte_offset`` yields the buffer start, and the first element's
    # position is conveyed through ``byte_offset`` instead
    dl_tensor.data = <void*>(data_ptr - byte_offset)
    dl_tensor.ndim = nd
    dl_tensor.byte_offset = <uint64_t>byte_offset
    dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
    dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
    dl_tensor.device.device_type = kDLOneAPI
    dl_tensor.device.device_id = device_id
    dl_tensor.dtype.lanes = <uint16_t>1
    dl_tensor.dtype.bits = <uint8_t>(ary_dt.itemsize * 8)
    if (ary_dtk == "b"):
        dl_tensor.dtype.code = <uint8_t>kDLBool
    elif (ary_dtk == "u"):
        dl_tensor.dtype.code = <uint8_t>kDLUInt
    elif (ary_dtk == "i"):
        dl_tensor.dtype.code = <uint8_t>kDLInt
    elif (ary_dtk == "f"):
        dl_tensor.dtype.code = <uint8_t>kDLFloat
    elif (ary_dtk == "c"):
        dl_tensor.dtype.code = <uint8_t>kDLComplex
    else:
        stdlib.free(shape_strides_ptr)
        stdlib.free(dlm_tensor)
        raise ValueError("Unrecognized array data type")

    # the capsule keeps the array's base object alive until the deleter runs
    dlm_tensor.manager_ctx = <void*>ary_base
    cpython.Py_INCREF(ary_base)
    dlm_tensor.deleter = _managed_tensor_deleter

    return cpython.PyCapsule_New(dlm_tensor, "dltensor", _pycapsule_deleter)
cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
    """
    to_dlpack_versioned_capsule(usm_ary, copied)

    Constructs named Python capsule object referencing
    instance of ``DLManagedTensorVersioned`` from
    :class:`dpctl.tensor.usm_ndarray` instance.

    Args:
        usm_ary: An instance of :class:`dpctl.tensor.usm_ndarray`
        copied: A bint representing whether the data was previously
            copied in order to set the flags with the is-copied
            bitmask.
    Returns:
        A new capsule with name ``"dltensor_versioned"`` that
        contains a pointer to ``DLManagedTensorVersioned`` struct.
    Raises:
        DLPackCreationError: when array cannot be represented as
            DLPack tensor. This may happen when array was allocated
            on a partitioned sycl device, or its USM allocation is
            not bound to the platform default SYCL context.
        MemoryError: when host allocation needed for
            ``DLManagedTensorVersioned`` did not succeed.
        BufferError: when the array carries no explicit strides and is
            neither C- nor F-contiguous, so its strides can not be
            reconstructed.
        ValueError: when array elements data type could not be represented
            in ``DLManagedTensorVersioned``.
    """
    cdef DLManagedTensorVersioned *dlmv_tensor = NULL
    cdef DLTensor *dl_tensor = NULL
    cdef uint64_t dlmv_flags = 0
    cdef int nd = usm_ary.get_ndim()
    cdef char *data_ptr = usm_ary.get_data()
    cdef Py_ssize_t *shape_ptr = NULL
    cdef Py_ssize_t *strides_ptr = NULL
    cdef int64_t *shape_strides_ptr = NULL
    cdef int i = 0
    cdef int device_id = -1
    # Read the array flags unconditionally: they also drive the READ_ONLY
    # bit below, which must be correct for 0-d arrays too.
    cdef int flags = usm_ary.flags_
    cdef Py_ssize_t element_offset = 0
    cdef Py_ssize_t byte_offset = 0
    cdef Py_ssize_t si = 1

    ary_base = usm_ary.get_base()

    # Find ordinal number of the device (must be a root device)
    device_id = get_array_dlpack_device_id(usm_ary)

    dlmv_tensor = <DLManagedTensorVersioned *> stdlib.malloc(
        sizeof(DLManagedTensorVersioned))
    if dlmv_tensor is NULL:
        raise MemoryError(
            "to_dlpack_versioned_capsule: Could not allocate memory "
            "for DLManagedTensorVersioned"
        )
    if nd > 0:
        # a single allocation holds the shape (first nd entries) followed by
        # the strides (next nd entries); the deleter frees it via ``shape``
        shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
        if shape_strides_ptr is NULL:
            stdlib.free(dlmv_tensor)
            raise MemoryError(
                "to_dlpack_versioned_capsule: Could not allocate memory "
                "for shape/strides"
            )
        shape_ptr = usm_ary.get_shape()
        for i in range(nd):
            shape_strides_ptr[i] = shape_ptr[i]
        strides_ptr = usm_ary.get_strides()
        if strides_ptr:
            for i in range(nd):
                shape_strides_ptr[nd + i] = strides_ptr[i]
        else:
            # no explicit strides: rebuild them from the contiguity flags
            if flags & USM_ARRAY_C_CONTIGUOUS:
                si = 1
                for i in range(nd - 1, -1, -1):
                    shape_strides_ptr[nd + i] = si
                    si = si * shape_ptr[i]
            elif flags & USM_ARRAY_F_CONTIGUOUS:
                si = 1
                for i in range(0, nd):
                    shape_strides_ptr[nd + i] = si
                    si = si * shape_ptr[i]
            else:
                stdlib.free(shape_strides_ptr)
                stdlib.free(dlmv_tensor)
                raise BufferError(
                    "to_dlpack_versioned_capsule: Could not reconstruct "
                    "strides for non-contiguous tensor"
                )

    ary_dt = usm_ary.dtype
    ary_dtk = ary_dt.kind
    element_offset = usm_ary.get_offset()
    byte_offset = element_offset * (<Py_ssize_t>ary_dt.itemsize)

    dl_tensor = &dlmv_tensor.dl_tensor
    # ``get_data()`` is taken to point at the first element; stepping back by
    # ``byte_offset`` yields the buffer start
    dl_tensor.data = <void*>(data_ptr - byte_offset)
    dl_tensor.ndim = nd
    dl_tensor.byte_offset = <uint64_t>byte_offset
    dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
    dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
    dl_tensor.device.device_type = kDLOneAPI
    dl_tensor.device.device_id = device_id
    dl_tensor.dtype.lanes = <uint16_t>1
    dl_tensor.dtype.bits = <uint8_t>(ary_dt.itemsize * 8)
    if (ary_dtk == "b"):
        dl_tensor.dtype.code = <uint8_t>kDLBool
    elif (ary_dtk == "u"):
        dl_tensor.dtype.code = <uint8_t>kDLUInt
    elif (ary_dtk == "i"):
        dl_tensor.dtype.code = <uint8_t>kDLInt
    elif (ary_dtk == "f"):
        dl_tensor.dtype.code = <uint8_t>kDLFloat
    elif (ary_dtk == "c"):
        dl_tensor.dtype.code = <uint8_t>kDLComplex
    else:
        stdlib.free(shape_strides_ptr)
        stdlib.free(dlmv_tensor)
        raise ValueError("Unrecognized array data type")

    if copied:
        dlmv_flags |= DLPACK_FLAG_BITMASK_IS_COPIED
    if not (flags & USM_ARRAY_WRITABLE):
        dlmv_flags |= DLPACK_FLAG_BITMASK_READ_ONLY
    dlmv_tensor.flags = dlmv_flags

    dlmv_tensor.version.major = DLPACK_MAJOR_VERSION
    dlmv_tensor.version.minor = DLPACK_MINOR_VERSION

    # the capsule keeps the array's base object alive until the deleter runs
    dlmv_tensor.manager_ctx = <void*>ary_base
    cpython.Py_INCREF(ary_base)
    dlmv_tensor.deleter = _managed_tensor_versioned_deleter

    return cpython.PyCapsule_New(
        dlmv_tensor, "dltensor_versioned", _pycapsule_versioned_deleter
    )
cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
    """
    numpy_to_dlpack_versioned_capsule(npy_ary, copied)

    Constructs named Python capsule object referencing
    instance of ``DLManagedTensorVersioned`` from
    :class:`numpy.ndarray` instance.

    Args:
        npy_ary: An instance of :class:`numpy.ndarray`
        copied: A bint representing whether the data was previously
            copied in order to set the flags with the is-copied
            bitmask.
    Returns:
        A new capsule with name ``"dltensor_versioned"`` that
        contains a pointer to ``DLManagedTensorVersioned`` struct.
    Raises:
        ValueError: when array elements data type could not be represented
            in ``DLManagedTensorVersioned``.
        BufferError: when the array's byte strides are not a multiple of
            its itemsize and so can not be expressed in elements.
        MemoryError: when host allocation needed for
            ``DLManagedTensorVersioned`` did not succeed.
    """
    cdef DLManagedTensorVersioned *dlmv_tensor = NULL
    cdef DLTensor *dl_tensor = NULL
    cdef uint64_t dlmv_flags = 0
    cdef int nd = npy_ary.ndim
    cdef int64_t *shape_strides_ptr = NULL
    cdef int i = 0
    cdef Py_ssize_t byte_offset = 0
    cdef int itemsize = npy_ary.itemsize
    cdef uint8_t type_code = 0

    ary_dt = npy_ary.dtype
    ary_dtk = ary_dt.kind
    # Validate the data type before any host memory is allocated. Every
    # accepted type has itemsize <= 16, so ``itemsize * 8`` fits the uint8_t
    # ``bits`` field below; wider or non-numeric types (e.g. long double
    # complex, strings) are rejected here rather than overflowing that
    # conversion after memory has been allocated.
    if ary_dtk == "b":
        type_code = <uint8_t>kDLBool
    elif ary_dtk == "u":
        type_code = <uint8_t>kDLUInt
    elif ary_dtk == "i":
        type_code = <uint8_t>kDLInt
    elif ary_dtk == "f" and ary_dt.itemsize <= 8:
        type_code = <uint8_t>kDLFloat
    elif ary_dtk == "c" and ary_dt.itemsize <= 16:
        type_code = <uint8_t>kDLComplex
    else:
        raise ValueError("Unrecognized array data type")

    writable_flag = npy_ary.flags["W"]
    shape = npy_ary.ctypes.shape_as(ctypes.c_int64)
    strides = npy_ary.ctypes.strides_as(ctypes.c_int64)
    if nd > 0 and npy_ary.size != 1:
        # DLPack strides count elements, so every byte stride that is
        # actually stepped over must be a whole number of items
        for i in range(nd):
            if shape[i] != 1 and strides[i] % itemsize != 0:
                raise BufferError(
                    "numpy_to_dlpack_versioned_capsule: DLPack cannot "
                    "encode an array if strides are not a multiple of "
                    "itemsize"
                )

    dlmv_tensor = <DLManagedTensorVersioned *> stdlib.malloc(
        sizeof(DLManagedTensorVersioned))
    if dlmv_tensor is NULL:
        raise MemoryError(
            "numpy_to_dlpack_versioned_capsule: Could not allocate memory "
            "for DLManagedTensorVersioned"
        )
    if nd > 0:
        # a single allocation holds shape followed by strides; the deleter
        # frees it via ``shape``
        shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
        if shape_strides_ptr is NULL:
            stdlib.free(dlmv_tensor)
            raise MemoryError(
                "numpy_to_dlpack_versioned_capsule: Could not allocate memory "
                "for shape/strides"
            )
        for i in range(nd):
            shape_strides_ptr[i] = shape[i]
            shape_strides_ptr[nd + i] = strides[i] // itemsize

    dl_tensor = &dlmv_tensor.dl_tensor
    dl_tensor.data = <void *> npy_ary.data
    dl_tensor.ndim = nd
    dl_tensor.byte_offset = <uint64_t>byte_offset
    dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
    dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
    dl_tensor.device.device_type = kDLCPU
    dl_tensor.device.device_id = 0
    dl_tensor.dtype.lanes = <uint16_t>1
    dl_tensor.dtype.bits = <uint8_t>(ary_dt.itemsize * 8)
    dl_tensor.dtype.code = type_code

    if copied:
        dlmv_flags |= DLPACK_FLAG_BITMASK_IS_COPIED
    if not writable_flag:
        dlmv_flags |= DLPACK_FLAG_BITMASK_READ_ONLY
    dlmv_tensor.flags = dlmv_flags

    dlmv_tensor.version.major = DLPACK_MAJOR_VERSION
    dlmv_tensor.version.minor = DLPACK_MINOR_VERSION

    # the capsule keeps the NumPy array alive until the deleter runs
    dlmv_tensor.manager_ctx = <void*>npy_ary
    cpython.Py_INCREF(npy_ary)
    dlmv_tensor.deleter = _managed_tensor_versioned_deleter

    return cpython.PyCapsule_New(
        dlmv_tensor, "dltensor_versioned", _pycapsule_versioned_deleter
    )
cdef class _DLManagedTensorOwner:
    """
    Helper class managing the lifetime of the DLManagedTensor struct
    transferred from a ``"dltensor"`` capsule.

    When the owner is deallocated, the tensor's ``deleter`` (if any) is
    invoked to release the producer's resources.
    """
    cdef DLManagedTensor * dlm_tensor

    def __cinit__(self):
        self.dlm_tensor = NULL

    def __dealloc__(self):
        if self.dlm_tensor:
            # The DLPack specification allows producers to leave ``deleter``
            # NULL when they cannot provide one; calling it unconditionally
            # would dereference a null function pointer.
            if self.dlm_tensor.deleter is not NULL:
                self.dlm_tensor.deleter(self.dlm_tensor)
            self.dlm_tensor = NULL

    @staticmethod
    cdef _DLManagedTensorOwner _create(DLManagedTensor *dlm_tensor_src):
        """Wrap ``dlm_tensor_src``, taking over responsibility for deleting it."""
        cdef _DLManagedTensorOwner res
        res = _DLManagedTensorOwner.__new__(_DLManagedTensorOwner)
        res.dlm_tensor = dlm_tensor_src
        return res
cdef class _DLManagedTensorVersionedOwner:
    """
    Helper class managing the lifetime of the DLManagedTensorVersioned
    struct transferred from a ``"dltensor_versioned"`` capsule.

    When the owner is deallocated, the tensor's ``deleter`` (if any) is
    invoked to release the producer's resources.
    """
    cdef DLManagedTensorVersioned * dlmv_tensor

    def __cinit__(self):
        self.dlmv_tensor = NULL

    def __dealloc__(self):
        if self.dlmv_tensor:
            # ``deleter`` may legitimately be NULL per the DLPack
            # specification; only call it when the producer supplied one.
            if self.dlmv_tensor.deleter is not NULL:
                self.dlmv_tensor.deleter(self.dlmv_tensor)
            self.dlmv_tensor = NULL

    @staticmethod
    cdef _DLManagedTensorVersionedOwner _create(
        DLManagedTensorVersioned *dlmv_tensor_src
    ):
        """Wrap ``dlmv_tensor_src``, taking over responsibility for deleting it."""
        cdef _DLManagedTensorVersionedOwner res
        res = _DLManagedTensorVersionedOwner.__new__(
            _DLManagedTensorVersionedOwner
        )
        res.dlmv_tensor = dlmv_tensor_src
        return res
cdef dict _numpy_array_interface_from_dl_tensor(DLTensor *dlt, bint ro_flag):
    """Constructs a NumPy ``__array_interface__`` dictionary from a DLTensor.

    Args:
        dlt: tensor descriptor to expose; its memory is not copied.
        ro_flag: when true, the interface marks the data as read-only.

    Returns:
        dict usable as ``__array_interface__`` (protocol version 3).

    Raises:
        BufferError: if the tensor has ``lanes != 1``, an element bit width
            that is not a multiple of 8, or an unsupported type code.
    """
    cdef int itemsize = 0
    cdef size_t data_addr = 0
    if dlt.dtype.lanes != 1:
        raise BufferError(
            "Can not import DLPack tensor with lanes != 1"
        )
    if dlt.dtype.bits % 8:
        raise BufferError(
            "Can not import DLPack tensor whose element's "
            "bitsize is not a multiple of 8"
        )
    itemsize = dlt.dtype.bits // 8
    shape = list()
    if (dlt.strides is NULL):
        # compact row-major tensor: let NumPy assume C order
        strides = None
        for dim in range(dlt.ndim):
            shape.append(dlt.shape[dim])
    else:
        strides = list()
        for dim in range(dlt.ndim):
            shape.append(dlt.shape[dim])
            # DLPack strides count elements; NumPy expects bytes
            strides.append(dlt.strides[dim] * itemsize)
        strides = tuple(strides)
    shape = tuple(shape)
    if (dlt.dtype.code == kDLUInt):
        ary_dt = "u" + str(itemsize)
    elif (dlt.dtype.code == kDLInt):
        ary_dt = "i" + str(itemsize)
    elif (dlt.dtype.code == kDLFloat):
        ary_dt = "f" + str(itemsize)
    elif (dlt.dtype.code == kDLComplex):
        ary_dt = "c" + str(itemsize)
    elif (dlt.dtype.code == kDLBool):
        ary_dt = "b" + str(itemsize)
    else:
        raise BufferError(
            "Can not import DLPack tensor with type code {}.".format(
                <object>dlt.dtype.code
            )
        )
    typestr = "|" + ary_dt
    # NumPy applies an interface's ``offset`` entry only when ``data`` is a
    # buffer object, not when it is a (pointer, read-only) tuple, so the
    # byte offset has to be folded into the pointer itself.
    data_addr = <size_t>dlt.data + <size_t>dlt.byte_offset
    return dict(
        version=3,
        shape=shape,
        strides=strides,
        data=(data_addr, True if ro_flag else False),
        typestr=typestr,
    )
class _numpy_array_interface_wrapper:
    """
    Wraps an ``__array_interface__`` dictionary together with the owner of
    the memory it describes, for consumption by NumPy.

    Implementation taken from
    https://github.com/dmlc/dlpack/blob/main/apps/numpy_dlpack/dlpack/to_numpy.py

    Args:
        array_interface:
            A dictionary describing the underlying memory. Formatted
            to match `numpy.ndarray.__array_interface__`.
        memory_owner:
            Object responsible for the described memory (for example the
            holder of a consumed DLPack tensor). It is kept referenced by
            the wrapper so the memory lives at least as long as the
            wrapper does.
    """

    def __init__(self, array_interface, memory_owner) -> None:
        self.__array_interface__ = array_interface
        self._memory_owner = memory_owner
cdef bint _is_kdlcpu_device(DLDevice *dev):
    "Return True only when the DLDevice is exactly (kDLCPU, 0)"
    if dev[0].device_type != kDLCPU:
        return False
    return dev[0].device_id == 0
cpdef object from_dlpack_capsule(object py_caps):
    """
    from_dlpack_capsule(py_caps)

    Reconstructs instance of :class:`dpctl.tensor.usm_ndarray` from
    named Python capsule object referencing instance of ``DLManagedTensor``
    without copy. The instance forms a view in the memory of the tensor.

    Args:
        py_caps:
            Python capsule with name ``"dltensor"`` (or
            ``"dltensor_versioned"``) expected to reference an instance of
            ``DLManagedTensor`` (or ``DLManagedTensorVersioned``) struct.
    Returns:
        Instance of :class:`dpctl.tensor.usm_ndarray` with a view into
        memory of the tensor for ``kDLOneAPI`` tensors, or a
        :class:`numpy.ndarray` view for ``(kDLCPU, 0)`` tensors. Capsule
        is renamed to ``"used_dltensor"`` (respectively
        ``"used_dltensor_versioned"``) upon success.
    Raises:
        TypeError:
            if argument is not a ``"dltensor"`` capsule.
        ValueError:
            if argument is ``"used_dltensor"`` capsule
        BufferError:
            if the USM pointer is not bound to the reconstructed
            sycl context, or the DLPack's device_type is not supported
            by :mod:`dpctl`, or the tensor's DLPack major version,
            element type or number of lanes is not supported.
    """
    cdef DLManagedTensorVersioned *dlmv_tensor = NULL
    cdef DLManagedTensor *dlm_tensor = NULL
    cdef DLTensor *dl_tensor = NULL
    cdef int versioned = 0
    cdef int readonly = 0
    cdef bytes usm_type
    cdef size_t sz = 1
    cdef size_t alloc_sz = 1
    cdef int i
    cdef int device_id = -1
    cdef int element_bytesize = 0
    cdef Py_ssize_t offset_min = 0
    cdef Py_ssize_t offset_max = 0
    cdef char *mem_ptr = NULL
    cdef Py_ssize_t mem_ptr_delta = 0
    cdef Py_ssize_t element_offset = 0
    cdef int64_t stride_i = -1
    cdef int64_t shape_i = -1
    cdef int64_t contig_stride = 1

    if cpython.PyCapsule_IsValid(py_caps, "dltensor"):
        dlm_tensor = <DLManagedTensor*>cpython.PyCapsule_GetPointer(
            py_caps, "dltensor")
        dl_tensor = &dlm_tensor.dl_tensor
    elif cpython.PyCapsule_IsValid(py_caps, "dltensor_versioned"):
        dlmv_tensor = <DLManagedTensorVersioned*>cpython.PyCapsule_GetPointer(
            py_caps, "dltensor_versioned")
        if dlmv_tensor.version.major > DLPACK_MAJOR_VERSION:
            raise BufferError(
                "Can not import DLPack tensor with major version "
                f"greater than {DLPACK_MAJOR_VERSION}"
            )
        versioned = 1
        readonly = (dlmv_tensor.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0
        dl_tensor = &dlmv_tensor.dl_tensor
    elif (
        cpython.PyCapsule_IsValid(py_caps, "used_dltensor")
        or cpython.PyCapsule_IsValid(py_caps, "used_dltensor_versioned")
    ):
        raise ValueError(
            "A DLPack tensor object can not be consumed multiple times"
        )
    else:
        raise TypeError(
            "`from_dlpack_capsule` expects a Python 'dltensor' capsule"
        )

    # Verify that we can work with this device
    if dl_tensor.device.device_type == kDLOneAPI:
        device_id = dl_tensor.device.device_id
        root_device = dpctl.SyclDevice(str(<int>device_id))
        try:
            default_context = root_device.sycl_platform.default_context
        except RuntimeError:
            default_context = get_device_cached_queue(root_device).sycl_context
        if dl_tensor.data is NULL:
            usm_type = b"device"
            q = get_device_cached_queue((default_context, root_device,))
        else:
            usm_type = c_dpmem._Memory.get_pointer_type(
                <DPCTLSyclUSMRef> dl_tensor.data,
                <c_dpctl.SyclContext>default_context)
            if usm_type == b"unknown":
                raise BufferError(
                    "Data pointer in DLPack is not bound to default sycl "
                    f"context of device '{device_id}', translated to "
                    f"{root_device.filter_string}"
                )
            alloc_device = c_dpmem._Memory.get_pointer_device(
                <DPCTLSyclUSMRef> dl_tensor.data,
                <c_dpctl.SyclContext>default_context
            )
            q = get_device_cached_queue((default_context, alloc_device,))
        if dl_tensor.dtype.bits % 8:
            raise BufferError(
                "Can not import DLPack tensor whose element's "
                "bitsize is not a multiple of 8"
            )
        if dl_tensor.dtype.lanes != 1:
            raise BufferError(
                "Can not import DLPack tensor with lanes != 1"
            )
        element_bytesize = (dl_tensor.dtype.bits // 8)

        # Resolve the element type while the capsule is still unconsumed, so
        # an unsupported type code leaves the producer's capsule intact.
        if (dl_tensor.dtype.code == kDLUInt):
            ary_dt = np.dtype("u" + str(element_bytesize))
        elif (dl_tensor.dtype.code == kDLInt):
            ary_dt = np.dtype("i" + str(element_bytesize))
        elif (dl_tensor.dtype.code == kDLFloat):
            ary_dt = np.dtype("f" + str(element_bytesize))
        elif (dl_tensor.dtype.code == kDLComplex):
            ary_dt = np.dtype("c" + str(element_bytesize))
        elif (dl_tensor.dtype.code == kDLBool):
            ary_dt = np.dtype("?")
        else:
            raise BufferError(
                "Can not import DLPack tensor with type code {}.".format(
                    <object>dl_tensor.dtype.code
                )
            )

        if dl_tensor.ndim > 0:
            # Extent of the tensor, in elements, relative to its first
            # element. Producers may pass NULL strides to denote a compact
            # row-major tensor; synthesize those strides (innermost
            # dimension first) instead of dereferencing the NULL pointer.
            offset_min = 0
            offset_max = 0
            contig_stride = 1
            for i in range(dl_tensor.ndim - 1, -1, -1):
                shape_i = dl_tensor.shape[i]
                if dl_tensor.strides is NULL:
                    stride_i = contig_stride
                    contig_stride = contig_stride * shape_i
                else:
                    stride_i = dl_tensor.strides[i]
                if shape_i > 1:
                    shape_i -= 1
                    if stride_i > 0:
                        offset_max = offset_max + stride_i * shape_i
                    else:
                        offset_min = offset_min + stride_i * shape_i
            sz = offset_max - offset_min + 1
        sz = sz * element_bytesize
        element_offset = dl_tensor.byte_offset // element_bytesize

        # transfer ownership
        if not versioned:
            dlm_holder = _DLManagedTensorOwner._create(dlm_tensor)
            cpython.PyCapsule_SetName(py_caps, "used_dltensor")
        else:
            dlmv_holder = _DLManagedTensorVersionedOwner._create(dlmv_tensor)
            cpython.PyCapsule_SetName(py_caps, "used_dltensor_versioned")

        if dl_tensor.data is NULL:
            usm_mem = dpmem.MemoryUSMDevice(sz, q)
        else:
            mem_ptr_delta = dl_tensor.byte_offset - (
                element_offset * element_bytesize
            )
            mem_ptr = <char *>dl_tensor.data
            alloc_sz = dl_tensor.byte_offset + <uint64_t>(
                (offset_max + 1) * element_bytesize)
            tmp = c_dpmem._Memory.create_from_usm_pointer_size_qref(
                <DPCTLSyclUSMRef> mem_ptr,
                max(alloc_sz, <uint64_t>element_bytesize),
                (<c_dpctl.SyclQueue>q).get_queue_ref(),
                memory_owner=dlmv_holder if versioned else dlm_holder
            )
            if mem_ptr_delta == 0:
                usm_mem = tmp
            else:
                # byte_offset is not a multiple of the element size: start
                # the view ``mem_ptr_delta`` bytes into the data so that
                # ``element_offset`` whole elements later lands exactly on
                # ``data + byte_offset``; the view must reach the end of the
                # last element.
                alloc_sz = <size_t>(
                    (element_offset + offset_max + 1) * element_bytesize
                )
                usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
                    <DPCTLSyclUSMRef> (mem_ptr + mem_ptr_delta),
                    max(alloc_sz, <uint64_t>element_bytesize),
                    (<c_dpctl.SyclQueue>q).get_queue_ref(),
                    memory_owner=tmp
                )

        py_shape = list()
        if (dl_tensor.shape is not NULL):
            for i in range(dl_tensor.ndim):
                py_shape.append(dl_tensor.shape[i])
        if (dl_tensor.strides is not NULL):
            py_strides = list()
            for i in range(dl_tensor.ndim):
                py_strides.append(dl_tensor.strides[i])
        else:
            # NULL strides: usm_ndarray defaults to C-contiguous layout
            py_strides = None
        res_ary = usm_ndarray(
            py_shape,
            dtype=ary_dt,
            buffer=usm_mem,
            strides=py_strides,
            offset=element_offset
        )
        if readonly:
            res_ary.flags_ = (res_ary.flags_ & ~USM_ARRAY_WRITABLE)
        return res_ary
    elif _is_kdlcpu_device(&dl_tensor.device):
        ary_iface = _numpy_array_interface_from_dl_tensor(dl_tensor, readonly)
        if not versioned:
            dlm_holder = _DLManagedTensorOwner._create(dlm_tensor)
            cpython.PyCapsule_SetName(py_caps, "used_dltensor")
            return np.ctypeslib.as_array(
                _numpy_array_interface_wrapper(ary_iface, dlm_holder)
            )
        else:
            dlmv_holder = _DLManagedTensorVersionedOwner._create(dlmv_tensor)
            cpython.PyCapsule_SetName(py_caps, "used_dltensor_versioned")
            return np.ctypeslib.as_array(
                _numpy_array_interface_wrapper(ary_iface, dlmv_holder)
            )
    else:
        raise BufferError(
            "The DLPack tensor resides on unsupported device."
        )
cdef usm_ndarray _to_usm_ary_from_host_blob(object host_blob, dev : Device):
    """Copy host data into a new C-contiguous ``usm_ndarray`` on ``dev``.

    On devices without fp64 support, double precision real/complex data is
    stored as ``float32``/``complex64``. The values are converted on the host
    before the copy; copying the raw float64 bytes into a float32 array would
    reinterpret their bit patterns instead of converting the values.
    """
    q = dev.sycl_queue
    np_ary = np.asarray(host_blob)
    dt = np_ary.dtype
    if dt.char in "dD" and q.sycl_device.has_aspect_fp64 is False:
        Xusm_dtype = (
            "float32" if dt.char == "d" else "complex64"
        )
    else:
        Xusm_dtype = dt
    # C-ordered host copy in the target data type (no copy when np_ary already
    # qualifies). ascontiguousarray returns at least 1-d, so the byte view
    # below also works for 0-d input, where NumPy refuses to change itemsize.
    host_c = np.ascontiguousarray(np_ary, dtype=Xusm_dtype)
    usm_mem = dpmem.MemoryUSMDevice(host_c.nbytes, queue=q)
    usm_ary = usm_ndarray(np_ary.shape, dtype=Xusm_dtype, buffer=usm_mem)
    usm_mem.copy_from_host(host_c.reshape(-1).view(dtype="u1"))
    return usm_ary
# only cdef to make it private
cdef object _create_device(object device, object dl_device):
    """Return a :class:`Device` for ``device``.

    ``Device`` instances are returned unchanged and ``dpctl.SyclDevice``
    instances are wrapped; for anything else the root SYCL device whose
    ordinal is ``dl_device[1]`` is used.
    """
    if isinstance(device, Device):
        return device
    if isinstance(device, dpctl.SyclDevice):
        return Device.create_device(device)
    dev_ordinal = <int>dl_device[1]
    return Device.create_device(dpctl.SyclDevice(str(dev_ordinal)))
def from_dlpack(x, /, *, device=None, copy=None):
"""from_dlpack(x, /, *, device=None, copy=None)
Constructs :class:`dpctl.tensor.usm_ndarray` or :class:`numpy.ndarray`
instance from a Python object ``x`` that implements ``__dlpack__`` protocol.
Args:
x (object):
A Python object representing an array that supports
``__dlpack__`` protocol.
device (
Optional[str, :class:`dpctl.SyclDevice`,
:class:`dpctl.SyclQueue`,