-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Expand file tree
/
Copy pathonnxruntime_pybind_ortvalue.cc
More file actions
645 lines (587 loc) · 32.6 KB
/
onnxruntime_pybind_ortvalue.cc
File metadata and controls
645 lines (587 loc) · 32.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "python/onnxruntime_pybind_state_common.h"
#define NO_IMPORT_ARRAY
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
#include "python/numpy_helper.h"
#include "core/framework/ort_value.h"
#include "core/framework/tensor.h"
#include "core/framework/sparse_tensor.h"
#include "core/framework/TensorSeq.h"
namespace onnxruntime {
namespace python {
namespace py = pybind11;
namespace {
std::unique_ptr<OrtValue> OrtValueFromShapeAndType(const std::vector<int64_t>& shape,
MLDataType element_type,
const OrtDevice& device) {
AllocatorPtr allocator;
if (strcmp(GetDeviceName(device), CPU) == 0) {
allocator = GetAllocator();
} else {
#if !defined(ORT_MINIMAL_BUILD)
// prefer a shared allocator from the environment.
// these are provided by plugin EPs or custom allocators explicitly registered by the user.
allocator = GetSharedAllocator(device);
#endif
if (!allocator) {
if (strcmp(GetDeviceName(device), CUDA) == 0) {
#ifdef USE_CUDA
if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}
allocator = GetCudaAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (strcmp(GetDeviceName(device), HIP) == 0) {
#if USE_MIGRAPHX
allocator = GetMIGraphXAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the AMD device using this package of OnnxRuntime. "
"Please use the ROCm package of OnnxRuntime to use this feature.");
#endif
} else if (strcmp(GetDeviceName(device), DML) == 0) {
#if USE_DML
allocator = GetDmlAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
"Please use the DirectML package of OnnxRuntime to use this feature.");
#endif
}
}
if (!allocator) {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
}
}
auto ml_value = std::make_unique<OrtValue>();
Tensor::InitOrtValue(element_type, gsl::make_span(shape), std::move(allocator), *ml_value);
return ml_value;
}
// Allocate an OrtValue using the shared allocator matching the given OrtMemoryInfo.
// This allows callers to specify the exact memory type (e.g. HOST_ACCESSIBLE) rather than
// relying on OrtDevice.make() which always uses DEFAULT.
//
// Uses the full OrtMemoryInfo for the lookup (including mem_type) rather than just the OrtDevice,
// because the registered allocator's OrtMemoryInfo has a specific mem_type (e.g. OrtMemTypeCPU
// for HOST_ACCESSIBLE) that must match for FindExistingAllocator to succeed.
std::unique_ptr<OrtValue> OrtValueFromShapeAndTypeWithMemoryInfo(const std::vector<int64_t>& shape,
MLDataType element_type,
const OrtMemoryInfo& memory_info) {
auto& env = GetOrtEnv()->GetEnvironment();
AllocatorPtr allocator = env.GetRegisteredSharedAllocator(memory_info);
if (!allocator) {
throw std::runtime_error("No shared allocator found for: " + memory_info.ToString());
}
auto ml_value = std::make_unique<OrtValue>();
Tensor::InitOrtValue(element_type, gsl::make_span(shape), std::move(allocator), *ml_value);
return ml_value;
}
} // namespace
void addOrtValueMethods(pybind11::module& m) {
py::class_<OrtValue> ortvalue_binding(m, "OrtValue");
ortvalue_binding
// Factory method to create an OrtValue (Tensor) from the given Numpy object
// The Tensor allocates and manages its own memory (on the specified device) and copies data from the Numpy data buffer
//
// Dispatch is by device type first, then (for GPU devices) by vendor. Which vendor branches exist
// depends on the USE_DML / USE_CUDA / USE_MIGRAPHX build flags; the vendor checks form a single
// if / else-if chain stitched together across the preprocessor blocks, ending in the plugin-EP fallback.
// Throws std::runtime_error for non-numeric arrays and for devices this build cannot allocate on.
.def_static("ortvalue_from_numpy", [](const py::object& array_on_cpu, const OrtDevice& device) {
if (!IsNumericNumpyArray(array_on_cpu)) {
throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays");
}
auto ml_value = std::make_unique<OrtValue>();
// The tensor's memory is allocated on the CPU
if (device.Type() == OrtDevice::CPU) {
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
CreateGenericMLValue(nullptr, GetAllocator(), "", array_on_cpu, ml_value.get(), true);
} else if (device.Type() == OrtDevice::GPU) {
#if USE_DML
if (device.Vendor() == OrtDevice::VendorIds::MICROSOFT) {
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors
// in DML
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
} else
#endif
#ifdef USE_CUDA
if (device.Vendor() == OrtDevice::VendorIds::NVIDIA) {
// Reject device ids that IsCudaDeviceIdValid does not accept before touching the CUDA allocator
if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors
// in CUDA
CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
true, false, CpuToCudaMemCpy);
} else
#endif
#if USE_MIGRAPHX
if (device.Vendor() == OrtDevice::VendorIds::AMD) {
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors
// in MIGraphX
CreateGenericMLValue(nullptr, GetMIGraphXAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
true, false, CpuToMIGraphXMemCpy);
} else
#endif
{
// Fallback for any GPU vendor not handled above (or when no vendor branch is compiled in):
// see if we can do the copy with an allocator and IDataTransfer registered by a plugin EP
auto allocator = GetSharedAllocator(device);
// The copy function is only looked up when a shared allocator exists for the device
auto cpu_to_device_copy_fn = allocator ? CreateDataTransferMemCpy(OrtDevice{}, device) : nullptr;
if (cpu_to_device_copy_fn) {
CreateGenericMLValue(nullptr, allocator, "", array_on_cpu, ml_value.get(), true, false,
cpu_to_device_copy_fn);
} else {
throw std::runtime_error(
"Can't allocate memory on the device using this package of OnnxRuntime. "
"Please use the appropriate package of OnnxRuntime for your hardware to use this feature.");
}
}
} else if (device.Type() == OrtDevice::NPU && device.Vendor() == OrtDevice::VendorIds::HUAWEI) {
#ifdef USE_CANN
if (!IsCannDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available NPUs on the machine.");
}
CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
true, false, CpuToCannMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
"Please use the CANN package of OnnxRuntime to use this feature.");
#endif
} else {
// Neither CPU, GPU, nor a Huawei NPU
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
}
return ml_value;
})
// Overwrites the data of an existing Tensor OrtValue with the contents of a numeric numpy array.
// The array must have the same number of elements as the tensor; the copy function is chosen from
// the device the tensor lives on (CPU, GPU by vendor / plugin EP, or DML).
// Throws std::runtime_error for non-numeric arrays, size mismatches and unsupported devices.
.def("update_inplace", [](OrtValue* ml_value, const py::array& py_values) {
if (!IsNumericNumpyArray(py_values)) {
throw std::runtime_error("Inplace update of OrtValues is currently only supported from non-string numpy arrays");
}
// Only the total element count is compared here, not the individual dimensions
if (py_values.size() != ml_value->Get<Tensor>().Shape().Size()) {
throw std::runtime_error("The input size of numpy arrays does not match the size of the OrtValue.");
}
auto values_type = GetNumpyArrayType(py_values);
const auto device = ml_value->Get<Tensor>().Location().device;
if (device.Type() == OrtDevice::CPU) {
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToCpuMemCpy);
} else if (device.Type() == OrtDevice::GPU) {
// The vendor checks below form one if / else-if chain stitched across the preprocessor blocks;
// the trailing braced block is the fallback when no compiled-in vendor branch matches.
#ifdef USE_CUDA
if (device.Vendor() == OrtDevice::VendorIds::NVIDIA) {
if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToCudaMemCpy);
} else
#endif
#if USE_MIGRAPHX
if (device.Vendor() == OrtDevice::VendorIds::AMD) {
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToMIGraphXMemCpy);
} else
#endif
#if USE_DML
if (device.Vendor() == OrtDevice::VendorIds::MICROSOFT) {
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
} else
#endif
{
// see if we can do the copy with an allocator and IDataTransfer registered by a plugin EP
auto allocator = GetSharedAllocator(device);
// The copy function is only looked up when a shared allocator exists for the device
auto cpu_to_device_copy_fn = allocator ? CreateDataTransferMemCpy(OrtDevice{}, device) : nullptr;
if (cpu_to_device_copy_fn) {
onnxruntime::python::CopyDataToTensor(py_values, values_type, *(ml_value->GetMutable<Tensor>()),
cpu_to_device_copy_fn);
} else {
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
}
}
} else if (device.Type() == OrtDevice::DML) {
// Tensors whose device type is DML (as opposed to GPU with the MICROSOFT vendor handled above)
#if USE_DML
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device");
}
})
// Overload taking another OrtValue as the source; delegates to python::UpdateOrtValueInplace.
.def(
    "update_inplace",
    [](OrtValue* destination, const OrtValue& source) {
      python::UpdateOrtValueInplace(*destination, source);
    })
// Create an OrtValue on top of the numpy array's data pointer, but interpret the data
// as a different ONNX element type with the same element size.
//
// Throws when `onnx_element_type` is not a valid ONNX tensor data type, when it is STRING
// (string tensors hold string objects and cannot be laid over raw numpy bytes; the other
// factories in this file reject STRING as well), or when the element size of the requested
// type differs from the array's itemsize.
// NOTE(review): the tensor is initialized from `data.data()` rather than from an allocator, so it
// presumably borrows the numpy buffer and the caller must keep `data` alive — confirm.
.def_static("ortvalue_from_numpy_with_onnx_type", [](py::array& data, int32_t onnx_element_type) -> std::unique_ptr<OrtValue> {
  if (!ONNX_NAMESPACE::TensorProto_DataType_IsValid(onnx_element_type)) {
    ORT_THROW("Not a valid ONNX Tensor data type: ", onnx_element_type);
  }
  if (onnx_element_type == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
    ORT_THROW("Creation of OrtValues on top of numpy arrays is not supported for string tensors");
  }

  const auto element_type = OnnxTypeToOnnxRuntimeTensorType(onnx_element_type);
  const auto element_size = element_type->Size();
  if (narrow<size_t>(data.itemsize()) != element_size) {
    ORT_THROW("Item size of the incoming array: ", data.itemsize(),
              " does not match the element size specified by onnxtype: ", element_size);
  }

  // Only the CPU allocator's memory info is used, to describe where the numpy buffer lives.
  auto cpu_allocator = GetAllocator();
  auto ort_value = std::make_unique<OrtValue>();
  Tensor::InitOrtValue(element_type, GetShape(data),
                       const_cast<void*>(data.data()), cpu_allocator->Info(), *ort_value);
  return ort_value;
})
// Factory: builds an OrtValue with the given shape on `device`; the element type is derived from
// a numpy dtype-like Python object. The tensor memory is left uninitialized.
// Throws std::runtime_error when the object is not a numpy type or the type is not numeric.
.def_static("ortvalue_from_shape_and_type", [](const std::vector<int64_t>& shape, py::object& numpy_element_type, const OrtDevice& device) -> std::unique_ptr<OrtValue> {
  PyArray_Descr* descr = nullptr;
  const bool converted = PyArray_DescrConverter(numpy_element_type.ptr(), &descr) != 0;
  if (!converted) {
    throw std::runtime_error("Not a valid numpy type");
  }

  // Only the type number is needed, so the descriptor reference is dropped right away.
  const int numpy_type_num = descr->type_num;
  Py_DECREF(descr);

  if (!IsNumericNumpyType(numpy_type_num)) {
    throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays");
  }
  return OrtValueFromShapeAndType(shape, NumpyTypeToOnnxRuntimeTensorType(numpy_type_num), device);
})
// Factory: builds an OrtValue with the given shape on `device`; the element type is given as an
// ONNX TensorProto data type integer. The tensor memory is left uninitialized.
// STRING is rejected with std::runtime_error.
.def_static("ortvalue_from_shape_and_onnx_type", [](const std::vector<int64_t>& shape, int32_t onnx_element_type, const OrtDevice& device) -> std::unique_ptr<OrtValue> {
  const bool is_string_type =
      onnx_element_type == onnx::TensorProto_DataType::TensorProto_DataType_STRING;
  if (is_string_type) {
    throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays");
  }
  return OrtValueFromShapeAndType(shape, OnnxTypeToOnnxRuntimeTensorType(onnx_element_type), device);
})
// Factory methods that pick the allocator through an OrtMemoryInfo instead of an OrtDevice.
// This makes it possible to allocate with a specific memory type (e.g. HOST_ACCESSIBLE) from plugin EPs.
// This variant takes the element type as a numpy dtype-like Python object.
.def_static("ortvalue_from_shape_and_type_for_memory_info", [](const std::vector<int64_t>& shape, py::object& numpy_element_type, const OrtMemoryInfo& memory_info) -> std::unique_ptr<OrtValue> {
  PyArray_Descr* descr = nullptr;
  const bool converted = PyArray_DescrConverter(numpy_element_type.ptr(), &descr) != 0;
  if (!converted) {
    throw std::runtime_error("Not a valid numpy type");
  }

  // Only the type number is needed, so the descriptor reference is dropped right away.
  const int numpy_type_num = descr->type_num;
  Py_DECREF(descr);

  if (!IsNumericNumpyType(numpy_type_num)) {
    throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays");
  }
  return OrtValueFromShapeAndTypeWithMemoryInfo(shape, NumpyTypeToOnnxRuntimeTensorType(numpy_type_num),
                                                memory_info);
})
// Same as above, but the element type is given as an ONNX TensorProto data type integer.
// STRING is rejected with std::runtime_error.
.def_static("ortvalue_from_shape_and_onnx_type_for_memory_info", [](const std::vector<int64_t>& shape, int32_t onnx_element_type, const OrtMemoryInfo& memory_info) -> std::unique_ptr<OrtValue> {
  const bool is_string_type =
      onnx_element_type == onnx::TensorProto_DataType::TensorProto_DataType_STRING;
  if (is_string_type) {
    throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays");
  }
  return OrtValueFromShapeAndTypeWithMemoryInfo(shape, OnnxTypeToOnnxRuntimeTensorType(onnx_element_type),
                                                memory_info);
})
#if !defined(DISABLE_SPARSE_TENSORS)
// Factory: returns the OrtValue produced by PySparseTensor::AsOrtValue().
// pybind11 converts a Python None into a null pointer for pointer parameters by default, so the
// argument is checked before it is dereferenced.
.def_static("ort_value_from_sparse_tensor", [](const PySparseTensor* py_sparse_tensor) -> std::unique_ptr<OrtValue> {
  if (py_sparse_tensor == nullptr) {
    ORT_THROW("ort_value_from_sparse_tensor expects a SparseTensor instance but received None");
  }
  return py_sparse_tensor->AsOrtValue();
})
// Wraps a copy of the OrtValue (cheap) in a separate PySparseTensor object.
// Throws when the OrtValue does not hold a SparseTensor.
.def("as_sparse_tensor", [](const OrtValue* ort_value) -> std::unique_ptr<PySparseTensor> {
  const bool holds_sparse_tensor = ort_value->IsSparseTensor();
  if (!holds_sparse_tensor) {
    ORT_THROW("This OrtValue does not contain SparseTensor. Check data_type() value.");
  }
  auto sparse_wrapper = std::make_unique<PySparseTensor>(*ort_value);
  return sparse_wrapper;
})
#endif
// Returns the address of the Tensor's data buffer as an integer, or 0 for a tensor with no elements.
.def("data_ptr", [](OrtValue* ml_value) -> uintptr_t {
  // TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
  ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");

  Tensor* tensor = ml_value->GetMutable<Tensor>();
  const bool has_no_elements = tensor->Shape().Size() == 0;
  // uintptr_t is wide enough for a pointer on both x86 and x64 platforms.
  // The buffer is only queried when the tensor actually has elements.
  return has_no_elements ? uintptr_t{0} : reinterpret_cast<uintptr_t>(tensor->MutableDataRaw());
})
// Returns the device name (as given by GetDeviceName) of the Tensor or SparseTensor held by the
// OrtValue. Any other kind of OrtValue raises; SparseTensor support depends on the build.
.def("device_name", [](const OrtValue* ort_value) -> std::string {
  if (ort_value->IsTensor()) {
    const auto& tensor_device = ort_value->Get<Tensor>().Location().device;
    return std::string(GetDeviceName(tensor_device));
  }
#if !defined(DISABLE_SPARSE_TENSORS)
  if (ort_value->IsSparseTensor()) {
    const auto& sparse_device = ort_value->Get<SparseTensor>().Location().device;
    return std::string(GetDeviceName(sparse_device));
  }
  ORT_THROW("Only OrtValues that are Tensors/SparseTensors are currently supported");
#else
  ORT_THROW("Only OrtValues that are Tensors are supported in this build");
#endif
})
// Returns the dimensions of a Tensor, or the dense shape of a SparseTensor, as a Python list.
// Any other kind of OrtValue fails the ORT_ENFORCE check; SparseTensor support depends on the build.
.def("shape", [](const OrtValue* ort_value) -> py::list {
#if !defined(DISABLE_SPARSE_TENSORS)
  // TODO: only Tensor/SparseTensor are handled here, make this generic to handle non-Tensors
  ORT_ENFORCE(ort_value->IsTensor() || ort_value->IsSparseTensor(),
              "Only OrtValues that are Tensors/SparseTensors are currently supported");
  const auto dims = (ort_value->IsTensor())
                        ? ort_value->Get<Tensor>().Shape().GetDims()
                        : ort_value->Get<SparseTensor>().DenseShape().GetDims();
#else
  ORT_ENFORCE(ort_value->IsTensor(), "Only OrtValues that are Tensors are supported in this build");
  const auto dims = ort_value->Get<Tensor>().Shape().GetDims();
#endif
  py::list shape_arr;
  for (auto dim : dims) {
    // For sequence tensors - we would append a list of dims to the outermost list
    // For now only tensors are supported in OrtValue
    shape_arr.append(dim);
  }
  return shape_arr;
})
// Returns the type of the OrtValue as a string, obtained by passing the value's TypeProto
// through ONNX DataTypeUtils::ToType.
.def("data_type", [](const OrtValue* ort_value) -> std::string {
const ONNX_NAMESPACE::TypeProto* type_proto;
// Handle Tensor / SparseTensor / TensorSeq first: for these the TypeProto is looked up
// from the element type of the contained value
if (ort_value->IsTensor()) {
auto elem_type = ort_value->Get<Tensor>().GetElementType();
type_proto = DataTypeImpl::TensorTypeFromONNXEnum(elem_type)->GetTypeProto();
#if !defined(DISABLE_SPARSE_TENSORS)
} else if (ort_value->IsSparseTensor()) {
auto elem_type = ort_value->Get<SparseTensor>().GetElementType();
type_proto = DataTypeImpl::SparseTensorTypeFromONNXEnum(elem_type)->GetTypeProto();
#endif
} else if (ort_value->IsTensorSequence()) {
auto elem_type = ort_value->Get<TensorSeq>().DataType()->AsPrimitiveDataType()->GetDataType();
type_proto = DataTypeImpl::SequenceTensorTypeFromONNXEnum(elem_type)->GetTypeProto();
} else {
// Plain sequences and maps probably have their specific type
type_proto = ort_value->Type()->GetTypeProto();
}
// Fail with a message instead of dereferencing a null TypeProto below
ORT_ENFORCE(type_proto != nullptr, "Unknown type of OrtValue: ", ort_value->Type());
return *ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(*type_proto);
})
// Exposes GetTensorProtoType() for this OrtValue.
// The docstring is built from adjacent string literals, so each piece carries its own trailing space.
.def("element_type", [](const OrtValue* ort_value) -> int32_t { return GetTensorProtoType(*ort_value); },
     "Returns an integer equal to the ONNX tensor proto type of the tensor or sequence. "
     "This integer is one type defined by ONNX TensorProto_DataType "
     "(such as onnx.TensorProto.FLOAT). "
     "Raises an exception in any other case.")
// Size in bytes of the Tensor held by the OrtValue; non-Tensor values fail the ORT_ENFORCE check.
.def(
    "tensor_size_in_bytes",
    [](const OrtValue* ort_value) -> size_t {
      ORT_ENFORCE(ort_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");
      const Tensor& tensor = ort_value->Get<Tensor>();
      return tensor.SizeInBytes();
    },
    "Returns tensor size in bytes.")
// Simple predicates forwarding to the corresponding OrtValue queries.
.def("has_value", [](const OrtValue* self) -> bool {
  return self->IsAllocated();
})
.def("is_tensor", [](const OrtValue* self) -> bool {
  return self->IsTensor();
})
.def("is_sparse_tensor", [](const OrtValue* self) -> bool {
  return self->IsSparseTensor();
})
.def("is_tensor_sequence", [](const OrtValue* self) -> bool {
  return self->IsTensorSequence();
})
// Converts Tensor into a numpy array
// Only Tensor OrtValues are accepted. The device-to-host copy function handed to
// GetPyObjFromTensor is selected from the vendor of the device the tensor lives on; the set of
// vendor cases compiled in depends on the USE_CUDA / USE_CANN / USE_DML / USE_MIGRAPHX build flags.
.def("numpy", [](const OrtValue* ml_value) -> py::object {
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects");
[[maybe_unused]] const auto& device = ml_value->Get<Tensor>().Location().device;
#ifdef _MSC_VER
// The switch statement may only contain the 'default' label. In such a case, the MSVC compiler
// will warn about it, and since the warnings are treated as errors, the compilation will break.
// Below pragmas turn off warning generation for this switch only.
#pragma warning(push)
#pragma warning(disable : 4065)
#endif
switch (device.Vendor()) {
#ifdef USE_CUDA
case OrtDevice::VendorIds::NVIDIA:
return GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction(device),
/*zero_copy_non_owning=*/true);
#endif
#ifdef USE_CANN
case OrtDevice::VendorIds::HUAWEI:
return GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction(),
/*zero_copy_non_owning=*/true);
#endif
#ifdef USE_DML
case OrtDevice::VendorIds::MICROSOFT:
return GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction(device),
/*zero_copy_non_owning=*/true);
#endif
#ifdef USE_MIGRAPHX
case OrtDevice::VendorIds::AMD:
return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device),
/*zero_copy_non_owning=*/true);
#endif
default:
// No copy function is passed (nullptr) for every vendor without a compiled-in case above.
// OrtValue.numpy() is called by the user who explicitly holds the OrtValue
// Python object, so the backing memory lifetime is managed externally.
// zero_copy_non_owning=true is safe here (and required to preserve the
// zero-copy semantics that OrtValue.numpy() / __array__ rely on).
return GetPyObjFromTensor(*ml_value, nullptr, nullptr,
/*zero_copy_non_owning=*/true);
}
#ifdef _MSC_VER
#pragma warning(pop)
#endif
})
#if defined(ENABLE_DLPACK)
// Wraps the PyObject* returned by ToDlpack() without adding a reference (reinterpret_steal).
// The docstring wording matches the one used for __dlpack__.
.def("to_dlpack", [](OrtValue* ort_value) -> py::object { return py::reinterpret_steal<py::object>(ToDlpack(*ort_value)); },
     "Returns a DLPack representing the tensor. This method does not copy the pointer shape, "
     "instead, it copies the pointer value. The OrtValue must persist until the dlpack structure "
     "is consumed.")
// Static factory forwarding to FromDlpack(); `is_bool_tensor` defaults to false.
.def_static(
    "from_dlpack",
    [](py::object data, bool is_bool_tensor) {
      return FromDlpack(data.ptr(), is_bool_tensor);
    },
    py::arg("data"), py::arg("is_bool_tensor") = false,
    "Converts a tensor from a external library into an OrtValue by means of the __dlpack__ protocol.")
// __dlpack__ protocol entry point. The `stream` argument is accepted (default None) but not used.
.def(
    "__dlpack__",
    [](OrtValue* ort_value, py::object /* stream */) -> py::object {
      PyObject* dlpack_capsule = ToDlpack(*ort_value);
      // reinterpret_steal: take over the reference instead of adding a new one.
      return py::reinterpret_steal<py::object>(dlpack_capsule);
    },
    py::arg("stream") = py::none(),
    "Returns a DLPack representing the tensor (part of __dlpack__ protocol). "
    "This method does not copy the pointer shape, instead, it copies the pointer value. "
    "The OrtValue must persist until the dlpack structure is consumed.")
// __dlpack_device__ protocol entry point: (device type, device id) of the Tensor's DLPack device.
.def(
    "__dlpack_device__",
    [](const OrtValue* ort_value) -> py::tuple {
      ORT_ENFORCE(ort_value->IsTensor(), "Only tensor type OrtValues are supported");
      const auto ort_device_id = ort_value->Get<Tensor>().Location().device.Id();
      DLDevice dl_device = onnxruntime::dlpack::GetDlpackDevice(*ort_value, ort_device_id);
      const int device_type = static_cast<int>(dl_device.device_type);
      return py::make_tuple(device_type, dl_device.device_id);
    },
    "Returns a tuple of integers, (device, device index) (part of __dlpack__ protocol).")
#endif
;
py::class_<std::vector<OrtValue>>(m, "OrtValueVector")
.def(py::init<>())
// Appends a copy of the given OrtValue to the vector.
.def(
    "push_back",
    [](std::vector<OrtValue>* values, const OrtValue& new_value) {
      values->push_back(new_value);
    })
#if defined(ENABLE_DLPACK)
// Overload: appends the OrtValue that FromDlpack() builds from a DLPack object.
.def(
    "push_back",
    [](std::vector<OrtValue>* v, py::object dlpack_tensor, const bool is_bool_tensor) {
      v->push_back(FromDlpack(dlpack_tensor.ptr(), is_bool_tensor));
    },
    "Add a new OrtValue after being ownership was transferred from the DLPack structure.",
    py::arg("dlpack_tensor"), py::arg("is_bool_tensor") = false)
// Appends one OrtValue per entry of `torch_tensors`. Each OrtValue is initialized directly from the
// raw data pointer, numpy element type, shape and device found at the same index of the parallel
// input vectors.
// NOTE(review): the tensors are built from a raw pointer, so they presumably do not own the memory
// and the source tensors must outlive them — confirm.
.def(
    "push_back_batch",
    [](std::vector<OrtValue>* v, std::vector<py::object>& torch_tensors, std::vector<int64_t>& data_ptrs, std::vector<py::object>& element_types, const std::vector<std::vector<int64_t>>& shapes, const std::vector<OrtDevice>& devices) {
      const size_t batch_size = torch_tensors.size();
      for (size_t index = 0; index < batch_size; ++index) {
        // .at() is used throughout so that a shorter parallel vector raises instead of reading out of bounds.
        py::object& numpy_type_obj = element_types.at(index);
        const std::vector<int64_t>& dims = shapes.at(index);
        const int64_t raw_address = data_ptrs.at(index);
        ORT_ENFORCE(raw_address, "Pointer to data memory is not valid");

        PyArray_Descr* descr;
        if (!PyArray_DescrConverter(numpy_type_obj.ptr(), &descr)) {
          throw std::runtime_error("Not a valid numpy type");
        }
        const int numpy_type_num = descr->type_num;
        Py_DECREF(descr);
        auto tensor_type = NumpyTypeToOnnxRuntimeTensorType(numpy_type_num);

        auto target_device = devices.at(index);
        OrtMemoryInfo memory_info(GetDeviceName(target_device), OrtDeviceAllocator, target_device);

        OrtValue wrapped;
        Tensor::InitOrtValue(tensor_type, gsl::make_span(dims), reinterpret_cast<void*>(raw_address),
                             memory_info, wrapped);
        v->push_back(wrapped);
      }
    },
    "Add a batch of OrtValue's by wrapping PyTorch tensors.")
#endif
// Thin wrappers over the std::vector<OrtValue> container interface.
.def("reserve", [](std::vector<OrtValue>* values, const size_t len) {
  values->reserve(len);
})
.def("shrink_to_fit", [](std::vector<OrtValue>* values) {
  values->shrink_to_fit();
})
.def("__len__", [](const std::vector<OrtValue>& values) {
  return values.size();
})
// keep_alive<0, 1>: the vector stays alive for as long as the returned iterator does.
.def(
    "__iter__",
    [](const std::vector<OrtValue>& values) {
      return py::make_iterator(values.cbegin(), values.cend());
    },
    py::keep_alive<0, 1>())
// at() performs the bounds check for out-of-range indices.
.def("__getitem__", [](const std::vector<OrtValue>& values, const size_t idx) {
  return values.at(idx);
})
// Collects the positions of all OrtValues whose tensor proto type is BOOL.
.def(
    "bool_tensor_indices",
    [](std::vector<OrtValue>* v) -> std::vector<int64_t> {
      std::vector<int64_t> bool_positions;
      int64_t position = 0;
      for (const OrtValue& value : *v) {
        if (GetTensorProtoType(value) == ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
          bool_positions.push_back(position);
        }
        ++position;
      }
      return bool_positions;
    },
    "Returns the indices of every boolean tensor in this vector of OrtValue. "
    "In case of a boolean tensor, method to_dlpacks returns a uint8 tensor instead of a boolean tensor. "
    "If torch consumes the dlpack structure, `.to(torch.bool)` must be applied to the torch tensor "
    "to get a boolean tensor.")
#if defined(ENABLE_DLPACK)
// DLPack object for the OrtValue at `idx`; at() performs the bounds check.
.def("dlpack_at", [](std::vector<OrtValue>* v, const size_t idx) {
  PyObject* dlpack_capsule = ToDlpack(v->at(idx));
  // reinterpret_steal: take over the reference instead of adding a new one.
  return py::reinterpret_steal<py::object>(dlpack_capsule);
})
#endif
// Exposes GetTensorProtoType() for the OrtValue at `idx`; at() performs the bounds check.
// The docstring is built from adjacent string literals, so each piece carries its own trailing space.
.def("element_type_at", [](std::vector<OrtValue>* v, const size_t idx) -> int32_t { return GetTensorProtoType(v->at(idx)); },
     "Returns an integer equal to the ONNX proto type of the tensor at position idx. "
     "This integer is one type defined by ONNX TensorProto_DataType "
     "(such as onnx.TensorProto.FLOAT). "
     "Raises an exception in any other case.",
     py::arg("idx"))
#if defined(ENABLE_DLPACK)
// Converts every OrtValue of the vector through DLPack.
//  - to_tensor is None: returns a list of "dltensor" capsules, each destroyed by DlpackCapsuleDestructor.
//  - otherwise: calls to_tensor(capsule) for every OrtValue and returns the list of results. A single
//    capsule object is reused for all calls and re-armed (name and pointer) before each call.
// Throws std::runtime_error when a capsule cannot be created or to_tensor fails; on those paths the
// reused capsule and any DLPack structure that nobody took ownership of are released first.
.def("to_dlpacks", [](const std::vector<OrtValue>& v, py::object to_tensor) -> py::list {
  if (v.size() == 0)
    return py::list();

  py::list list_dlpacks;
  py::gil_scoped_acquire acquire;

  // Frees a DLPack structure through its own deleter; tolerates a null structure or deleter.
  const auto free_dlpack = [](DLManagedTensor* tensor) {
    if (tensor != nullptr && tensor->deleter != nullptr) {
      tensor->deleter(tensor);
    }
  };

  if (to_tensor.is_none()) {
    // NOTE(review): the loop iterates by value as the original code did; presumably
    // OrtValueToDlpack needs a mutable OrtValue while `v` is const — confirm before changing.
    for (auto it : v) {
      DLManagedTensor* dlmanaged_tensor = dlpack::OrtValueToDlpack(it);
      py::capsule capsule(dlmanaged_tensor, "dltensor", DlpackCapsuleDestructor);
      list_dlpacks.append(capsule);
    }
  } else {
    PyObject* capsule = NULL;
    PyObject* handle = to_tensor.ptr();
    for (auto it : v) {
      // A new instance of dlpack needs to be created. The object which consumes it
      // is responsible for its deletion.
      DLManagedTensor* dlmanaged_tensor = dlpack::OrtValueToDlpack(it);
      if (capsule == NULL) {
        capsule = PyCapsule_New(dlmanaged_tensor, "dltensor", NULL);
        if (capsule == NULL) {
          // No capsule exists, hence no consumer: the structure is still ours to free.
          free_dlpack(dlmanaged_tensor);
          throw std::runtime_error("Unexpected error: empty capsule returned.");
        }
      } else {
        // The same capsule is reused but FromDLPack rename the capsule into used_dltensor.
        PyCapsule_SetName(capsule, "dltensor");
        PyCapsule_SetPointer(capsule, dlmanaged_tensor);
      }

      PyObject* obj = PyObject_CallFunctionObjArgs(handle, capsule, NULL);
      if (obj == NULL) {
        // A capsule still named "dltensor" was not consumed, so the structure has no owner.
        if (PyCapsule_IsValid(capsule, "dltensor")) {
          free_dlpack(dlmanaged_tensor);
        }
        // The capsule has no destructor; dropping the reference only frees the capsule object.
        Py_DECREF(capsule);
        throw std::runtime_error("to_tensor returned a null pointer. This may be caused by the data conversion.");
      }
      list_dlpacks.append(obj);
      Py_DECREF(obj);
    }
    if (capsule != NULL) {
      // This test is never wrong because v is not empty if the execution goes through that path.
      // If not present, Guardian detects a potential failure.
      Py_DECREF(capsule);
    }
  }
  return list_dlpacks; },
     R"pbdoc(Converts all OrtValue into tensors through DLPack protocol, the method creates
a DLPack structure for every tensors, then calls python function `to_tensor` to a new object
consuming the DLPack structure or return a list of capsule if this function is None.
:param to_tensor: this function takes a capsule holding a pointer onto a DLPack structure and returns
a new tensor which becomes the new owner of the data. This function takes one python object and
returns a new python object. It fits the same signature as `torch.utils.from_dlpack`,
if None, the method returns a capsule for every new DLPack structure.
:return: a list containing the new tensors or the new capsules if *to_tensor* is None
This method is used to replace `tuple(torch._C._from_dlpack(ov.to_dlpack()) for ov in ort_values)`
by a faster instruction `tuple(ort_values.to_dlpacks(torch._C._from_dlpack))`. This loop
is difficult to parallelize as it goes through the GIL many times.
It creates many tensors acquiring ownership of existing OrtValue.
This method saves one object creation and an C++ allocation
for every transferred tensor.
)pbdoc",
     py::arg("to_tensor"))
#endif
;
#if defined(ENABLE_DLPACK)
m.def(
    "is_dlpack_uint8_tensor",
    [](py::capsule cap) -> bool {
      // The check is on the DLPack dtype only: unsigned integer code with 8 bits.
      // For reference, BOOL tensors are exported as:
      //   case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
      //     dtype.code = DLDataTypeCode::kDLUInt;
      //     dtype.bits = sizeof(bool);
      const auto* managed_tensor = static_cast<const DLManagedTensor*>(cap.get_pointer());
      const auto& dtype = managed_tensor->dl_tensor.dtype;
      const bool is_unsigned_int = dtype.code == DLDataTypeCode::kDLUInt;
      return is_unsigned_int && dtype.bits == 8;
    },
    "Tells if a DLPack structure is a uint8 tensor.\n"
    ".. note::\n"
    "    Boolean tensors are also uint8 tensor once converted with DLPack protocol.");
#endif
}
} // namespace python
} // namespace onnxruntime