Skip to content

Commit b73b3dd

Browse files
authored
Add __array__ and DLPack protocols to OrtValue (#27980)
## Summary

- Add `__array__`, `__dlpack__`, `__dlpack_device__`, and `from_dlpack` to the public `OrtValue` class
- Enable standard Python interoperability protocols (numpy array protocol + DLPack) on `OrtValue`
- Auto-detect boolean dtype from source objects in `from_dlpack` to avoid the uint8/bool ambiguity in older DLPack versions

## Motivation

Fixes #24071

The C-level `C.OrtValue` already supports `__dlpack__`, `__dlpack_device__`, and `from_dlpack`, but the public Python wrapper `OrtValue` class does not expose them. Users currently have to access the private `_ortvalue` attribute (e.g. `ortvalue._ortvalue.__dlpack__()`) for DLPack interop. Similarly, `np.asarray(ortvalue)` doesn't work because `__array__` is not implemented.

This makes `OrtValue` a well-behaved tensor type that works out of the box with:

- `np.asarray(ortvalue)` / `np.array(ortvalue)` via `__array__`
- `torch.from_dlpack(ortvalue)` via `__dlpack__` / `__dlpack_device__`
- `OrtValue.from_dlpack(torch_tensor)` via the `from_dlpack` classmethod

## Changes

**`onnxruntime/python/onnxruntime_inference_collection.py`**:

- `__array__(dtype, copy)`: Delegates to `self.numpy()` with optional dtype conversion. Supports numpy 2.0 `copy` semantics while remaining compatible with older numpy versions.
- `__dlpack__(*, stream)`: Thin wrapper over the C-level `__dlpack__`.
- `__dlpack_device__()`: Thin wrapper over the C-level `__dlpack_device__`.
- `from_dlpack(data)`: Classmethod that accepts any `__dlpack__`-compatible object or raw DLPack capsule. Detects boolean dtype from the source object's `dtype` attribute or `data_type()` method, avoiding the uint8/bool false-positive that `is_dlpack_uint8_tensor` would produce on genuine uint8 data.

**`onnxruntime/test/python/onnxruntime_test_python.py`**:

- `test_ort_value_array_protocol`: Tests `np.asarray`/`np.array` with float32, int64, bool dtypes, and dtype conversion.
- `test_ort_value_dlpack_protocol`: Tests `__dlpack__` and `__dlpack_device__` on the public class.
- `test_ort_value_from_dlpack_protocol_object`: Tests `from_dlpack` with numpy arrays and OrtValue-to-OrtValue round-trip, verifying zero-copy (shared memory).
- `test_ort_value_from_dlpack_bool`: Tests bool round-trip and verifies uint8 is not falsely detected as bool.

## Test Plan

- [x] `ruff check` passes on both modified files
- [x] `ruff format --check` passes on both modified files
- [x] `lintrunner` reports no issues
- [x] Existing `test_ort_value_dlpack` test continues to pass
- [x] All logic paths verified against C-level bindings (bool detection, dtype conversion, shared memory)
- [ ] CI: new tests pass against a full build with DLPack enabled
1 parent cd48875 commit b73b3dd

2 files changed

Lines changed: 194 additions & 0 deletions

File tree

onnxruntime/python/onnxruntime_inference_collection.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,109 @@ def numpy(self) -> np.ndarray:
11991199
"""
12001200
return self._ortvalue.numpy()
12011201

1202+
def __array__(self, dtype=None, copy=None) -> np.ndarray:
    """
    Supports ``numpy.asarray(ortvalue)`` and ``numpy.array(ortvalue)`` via the
    `numpy __array__ protocol <https://numpy.org/devdocs/user/basics.interoperability.html>`_.

    Valid only for OrtValues holding Tensors on CPU.

    :param dtype: Optional numpy dtype to cast the result to.
    :param copy: Optional bool (passed by numpy >= 2.0).

        * ``None`` (default): the array returned by ``self.numpy()`` is
          returned as is, or run through ``numpy.asarray`` when ``dtype`` is
          given, so a copy happens only if the dtype has to change.
        * ``True``: a copy is always made.
        * ``False``: the request is forwarded to
          ``numpy.array(..., copy=False)``. On numpy >= 2.0 this raises
          ``ValueError`` when a copy cannot be avoided (e.g. ``dtype``
          differs from the tensor's dtype); older numpy versions copy
          only if needed instead of raising.
    :return: A numpy array with the same data as the OrtValue, cast to
        ``dtype`` when one is given.
    :raises ValueError: On numpy >= 2.0, when ``copy=False`` is requested
        but the conversion requires a copy.
    """
    import numpy as np  # noqa: PLC0415

    tensor = self.numpy()

    if copy is None:
        if dtype is None:
            # Nothing to convert: hand back exactly what numpy() produced.
            return tensor
        # np.asarray returns its input untouched when the dtype already
        # matches, so no extra copy is introduced on top of numpy().
        return np.asarray(tensor, dtype=dtype)

    # Explicit copy request. numpy >= 2.0 passes this kwarg through
    # np.asarray/np.array; np.array has always accepted it, but with
    # weaker (never-raising) copy=False semantics before 2.0.
    return np.array(tensor, dtype=dtype, copy=copy)
1229+
1230+
def __dlpack__(self, *, stream=None):
    """
    Export the tensor as a DLPack capsule (producer side of the
    `DLPack protocol <https://dmlc.github.io/dlpack/latest/>`_).

    Consumers such as ``torch.from_dlpack``, ``jax.dlpack.from_dlpack`` or
    ``numpy.from_dlpack`` call this method when handed an OrtValue.

    The OrtValue must hold a contiguous tensor. No data is copied; the
    consumer shares memory with this OrtValue, which must remain alive
    while the capsule is in use.

    :param stream: Optional stream on which the tensor data is accessible.
        Forwarded unchanged to the C-level ``__dlpack__``; documented as
        currently unused and accepted for protocol compliance.
    :return: A PyCapsule holding a DLManagedTensor.
    """
    # The wrapped C-level OrtValue does the actual export.
    capsule = self._ortvalue.__dlpack__(stream=stream)
    return capsule
1248+
1249+
def __dlpack_device__(self) -> tuple[int, int]:
    """
    Report where the tensor data lives, as required by the
    `DLPack protocol <https://dmlc.github.io/dlpack/latest/>`_.

    :return: ``(device_type, device_id)`` as ints, using the DLPack
        ``DLDeviceType`` enum values for ``device_type``.
    """
    # The wrapped C-level OrtValue knows the device.
    device = self._ortvalue.__dlpack_device__()
    return device
1259+
1260+
@classmethod
def from_dlpack(cls, data, /) -> OrtValue:
    """
    Build an OrtValue from a DLPack-capable object or a raw DLPack capsule.

    ``data`` may be either:

    * An object with ``__dlpack__`` / ``__dlpack_device__`` methods
      (e.g. a PyTorch tensor, JAX array, or numpy array).
    * A raw DLPack PyCapsule (legacy path), which is passed through as is.

    Boolean tensors are recognised automatically when the source is an
    ``OrtValue`` or exposes a ``dtype`` attribute (numpy, PyTorch, etc.).
    A raw capsule carries no inspectable dtype, so a bool tensor that an
    older DLPack version encoded as uint8 cannot be told apart from a true
    uint8 tensor and is imported as uint8.

    No data is copied; the new OrtValue shares memory with the source.

    :param data: A tensor object supporting the DLPack protocol, or a raw
        DLPack PyCapsule.
    :return: An OrtValue wrapping the tensor data.
    """
    # The dtype must be inspected before the source is turned into a capsule:
    # DLPack encodes bool as uint8, and the capsule alone cannot distinguish
    # between the two.
    if isinstance(data, OrtValue):
        treat_as_bool = data.data_type() == "tensor(bool)"
    elif hasattr(data, "dtype"):
        source_dtype = data.dtype
        # numpy, cupy and tensorflow dtypes expose ``name``; frameworks that
        # don't (e.g. PyTorch) are matched through their string form.
        label = getattr(source_dtype, "name", str(source_dtype))
        treat_as_bool = label in ("bool", "bool_", "torch.bool")
    else:
        treat_as_bool = False

    # Protocol objects produce their own capsule; anything else is assumed
    # to already be one.
    capsule = data.__dlpack__() if hasattr(data, "__dlpack__") else data

    return cls(C.OrtValue.from_dlpack(capsule, treat_as_bool))
1304+
12021305
def update_inplace(self, np_arr) -> None:
12031306
"""
12041307
Update the OrtValue in place with a new Numpy array. The numpy contents

onnxruntime/test/python/onnxruntime_test_python.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,97 @@ def test_ort_value_dlpack_zero_size(self):
14751475
ortvalue2 = C.OrtValue.from_dlpack(dlp2, False)
14761476
self.assertEqual(list(shape), list(ortvalue2.shape()))
14771477

1478+
def test_ort_value_array_protocol(self):
    """Check that an OrtValue can be consumed through numpy's __array__ protocol."""
    expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    value = onnxrt.OrtValue.ortvalue_from_numpy(expected)

    # np.asarray goes through __array__ and must alias the OrtValue buffer.
    viewed = np.asarray(value)
    np.testing.assert_equal(expected, viewed)
    self.assertEqual(viewed.dtype, np.float32)
    self.assertEqual(value.data_ptr(), viewed.ctypes.data)

    # np.array is served by the same protocol.
    copied = np.array(value)
    np.testing.assert_equal(expected, copied)

    # Requesting the dtype the tensor already has must not trigger a copy.
    viewed_f32 = np.asarray(value, dtype=np.float32)
    np.testing.assert_equal(expected, viewed_f32)
    self.assertEqual(value.data_ptr(), viewed_f32.ctypes.data)

    # Requesting a different dtype yields converted values.
    as_f64 = np.asarray(value, dtype=np.float64)
    np.testing.assert_equal(expected.astype(np.float64), as_f64)
    self.assertEqual(as_f64.dtype, np.float64)

    # Integer and boolean tensors keep both their values and their dtype.
    sources = (
        (np.array([1, 2, 3], dtype=np.int64), np.int64),
        (np.array([True, False, True], dtype=np.bool_), np.bool_),
    )
    for source, expected_dtype in sources:
        holder = onnxrt.OrtValue.ortvalue_from_numpy(source)
        roundtrip = np.asarray(holder)
        np.testing.assert_equal(source, roundtrip)
        self.assertEqual(roundtrip.dtype, expected_dtype)
1516+
1517+
@unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
def test_ort_value_dlpack_protocol(self):
    """Check that the public OrtValue exposes __dlpack__ and __dlpack_device__."""
    expected = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
    source = onnxrt.OrtValue.ortvalue_from_numpy(expected)

    # A CPU tensor reports (device_type, device_id) == (1, 0).
    self.assertEqual((1, 0), source.__dlpack_device__())

    # The capsule produced by __dlpack__ must be accepted by from_dlpack.
    capsule = source.__dlpack__()
    imported = onnxrt.OrtValue.from_dlpack(capsule)
    np.testing.assert_equal(expected, imported.numpy())
1531+
1532+
@unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
def test_ort_value_from_dlpack_protocol_object(self):
    """Check OrtValue.from_dlpack with objects that implement the __dlpack__ protocol."""
    expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    # numpy arrays implement __dlpack__ since numpy 1.22; older versions skip this part.
    if hasattr(expected, "__dlpack__"):
        from_numpy = onnxrt.OrtValue.from_dlpack(expected)
        np.testing.assert_equal(expected, from_numpy.numpy())
        self.assertEqual(list(expected.shape), list(from_numpy.shape()))

    # Round-trip: numpy -> OrtValue -> OrtValue (via __dlpack__).
    producer = onnxrt.OrtValue.ortvalue_from_numpy(expected)
    consumer = onnxrt.OrtValue.from_dlpack(producer)
    np.testing.assert_equal(expected, consumer.numpy())
    # Both OrtValues must point at the same buffer (no copy).
    self.assertEqual(producer.data_ptr(), consumer.data_ptr())
1549+
1550+
@unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
def test_ort_value_from_dlpack_bool(self):
    """Test that from_dlpack auto-detects boolean tensors and leaves uint8 alone."""
    bool_arr = np.array([True, False, True, False], dtype=np.bool_)
    ortvalue_src = onnxrt.OrtValue.ortvalue_from_numpy(bool_arr)

    # Round-trip through DLPack should preserve bool dtype.
    ortvalue_dst = onnxrt.OrtValue.from_dlpack(ortvalue_src)
    result = ortvalue_dst.numpy()
    np.testing.assert_equal(bool_arr, result)
    # assert_equal only compares values, so a uint8 [1, 0, 1, 0] result would
    # also pass the check above; the dtype has to be asserted explicitly.
    self.assertEqual(result.dtype, np.bool_)
    self.assertEqual(ortvalue_dst.data_type(), "tensor(bool)")

    # Ensure uint8 is NOT falsely detected as bool.
    uint8_arr = np.array([1, 2, 255], dtype=np.uint8)
    ortvalue_uint8 = onnxrt.OrtValue.ortvalue_from_numpy(uint8_arr)
    ortvalue_uint8_dst = onnxrt.OrtValue.from_dlpack(ortvalue_uint8)
    result_uint8 = ortvalue_uint8_dst.numpy()
    np.testing.assert_equal(uint8_arr, result_uint8)
    self.assertEqual(result_uint8.dtype, np.uint8)
1568+
14781569
def test_sparse_tensor_coo_format(self):
14791570
cpu_device = onnxrt.OrtDevice.make("cpu", 0)
14801571
shape = [9, 9]

0 commit comments

Comments
 (0)