Skip to content

Commit 16ab8df

Browse files
committed
make arrays pickleable
1 parent 237e1f5 commit 16ab8df

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

array_api_strict/_array_object.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections.abc import Iterator
2121
from enum import IntEnum
2222
from types import EllipsisType, ModuleType
23-
from typing import Any, Final, Literal, SupportsIndex
23+
from typing import Any, Final, Literal, SupportsIndex, Callable
2424

2525
import numpy as np
2626
import numpy.typing as npt
@@ -125,6 +125,9 @@ def __new__(cls, *args: object, **kwargs: object) -> Array:
125125
raise TypeError(
126126
"The array_api_strict Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead."
127127
)
128+
129+
def __reduce__(self) -> tuple[Callable, tuple[npt.NDArray[Any], Device]]:
130+
return (self._new, (self._array, self._device))
128131

129132
# These functions are not required by the spec, but are implemented for
130133
# the sake of usability.

array_api_strict/tests/test_array_object.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import warnings
33
import operator
4+
import pickle
45
from builtins import all as all_
56

67
from numpy.testing import assert_raises
@@ -747,3 +748,15 @@ def test_dlpack_2023_12(api_version):
747748
a.__dlpack__(copy=False)
748749
a.__dlpack__(copy=True)
749750
a.__dlpack__(copy=None)
751+
752+
def test_pickle():
753+
"""Check that arrays are pickleable (despite raising on `__new__`)"""
754+
a = ones(2)
755+
min_supported_protocol = 2
756+
for protocol in range(min_supported_protocol, pickle.HIGHEST_PROTOCOL + 1):
757+
bytes = pickle.dumps(a, protocol=protocol)
758+
a_from_pickle = pickle.loads(bytes)
759+
assert a_from_pickle.device == a.device
760+
assert a_from_pickle.dtype == a.dtype
761+
assert a_from_pickle.shape == a.shape
762+
assert all(a_from_pickle == a)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ classifiers = [
2323
"Operating System :: OS Independent",
2424
]
2525

26+
[project.optional-dependencies]
27+
test = ["pytest", "hypothesis"]
28+
2629
[project.urls]
2730
Homepage = "https://data-apis.org/array-api-strict/"
2831
Repository = "https://github.com/data-apis/array-api-strict"

0 commit comments

Comments
 (0)