-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest_numpy2p0.pyi
More file actions
100 lines (77 loc) · 2.82 KB
/
test_numpy2p0.pyi
File metadata and controls
100 lines (77 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# mypy: disable-error-code="no-redef"
from types import ModuleType
from typing import Any, TypeAlias, assert_type
import numpy as np
import numpy.typing as npt
import array_api_typing as xpt
# DType aliases: short names for the NumPy scalar types used in the
# parametrized annotations below.
F32: TypeAlias = np.float32
I32: TypeAlias = np.int32
B: TypeAlias = np.bool_
# Define NDArrays against which we can test the protocols.
# These are bare declarations with no values: the file is meant to be read by
# a static type checker (note the `# mypy:` directive at the top of the file),
# which only needs their types. `nparr` has an unspecified (`Any`) scalar
# type; the other three carry a concrete scalar type so that dtype
# propagation through the protocols can be checked.
nparr: npt.NDArray[Any]
nparr_i32: npt.NDArray[I32]
nparr_f32: npt.NDArray[F32]
nparr_b: npt.NDArray[B]
# =========================================================
# `xpt.HasArrayNamespace`
# Check assignment: each NDArray, whatever its scalar type, must structurally
# satisfy `HasArrayNamespace[ModuleType]`. `_` is re-declared with a new
# annotation on every line; the redefinition errors this would normally raise
# are disabled by the `no-redef` directive at the top of the file.
_: xpt.HasArrayNamespace[ModuleType] = nparr
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
# Check `__array_namespace__` method: called through the protocol type, its
# result must be assignable to the namespace type parameter (`ModuleType`).
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = a_ns.__array_namespace__()
# Known limitation: a wrong namespace type parameter is NOT rejected at the
# assignment itself. The line below type-checks even though `dict[str, int]`
# is not a module. Such mistakes are only expected to surface once
# `__array_namespace__()` is called and its result is used, because the wrong
# type then propagates from the variable's declared type.
_: xpt.HasArrayNamespace[dict[str, int]] = nparr  # not caught
# =========================================================
# `xpt.HasDType`
# Check DTypeT_co assignment: the dtype type parameter accepts `Any` for the
# untyped array and the exact `np.dtype[...]` of each typed array.
_: xpt.HasDType[Any] = nparr
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
_: xpt.HasDType[np.dtype[B]] = nparr_b
# =========================================================
# `xpt.Array`
# From the usages in this section, the type parameters appear to be ordered
# (dtype, device, namespace), with the later ones optional (defaulted).
# TODO confirm against the `xpt.Array` definition.
# Check NamespaceT_co assignment (third type parameter). This re-declares
# `a_ns` from the `HasArrayNamespace` section with a different type, which is
# only allowed because `no-redef` is disabled for this file.
a_ns: xpt.Array[Any, Any, ModuleType] = nparr
# Check DTypeT_co assignment (first type parameter). The typed `x_*` variables
# declared here are reused by the attribute checks that follow.
_: xpt.Array[Any] = nparr
x_f32: xpt.Array[np.dtype[F32]] = nparr_f32
x_i32: xpt.Array[np.dtype[I32]] = nparr_i32
x_b: xpt.Array[np.dtype[B]] = nparr_b
# Check Attribute `.dtype`: reading it through `xpt.Array` must give back
# exactly the dtype type parameter bound at assignment.
assert_type(x_f32.dtype, np.dtype[F32])
assert_type(x_i32.dtype, np.dtype[I32])
assert_type(x_b.dtype, np.dtype[B])
# Check DeviceT_co assignment (second type parameter).
# `object` is accepted as the device type, and `.device` is then typed as
# `object`.
x_gooddevice: xpt.Array[Any, object, Any] = nparr
assert_type(x_gooddevice.device, object)
# `int` is expected to be rejected as the device type for an NDArray. The
# resulting error is suppressed with a targeted ignore, so only the
# `assignment` error code is silenced on that line.
x_baddevice: xpt.Array[Any, int, Any] = nparr  # type: ignore[assignment]
# Despite the rejected assignment, the variable keeps its declared type, so
# `.device` is read as `int`.
_: int = x_baddevice.device
# Check Attribute `.device`: when only the dtype parameter is given (as for
# the `x_*` variables), `.device` is typed as `object`.
assert_type(x_f32.device, object)
assert_type(x_i32.device, object)
assert_type(x_b.device, object)
# Check Attribute `.mT`: the transposed array keeps the same dtype parameter.
assert_type(x_f32.mT, xpt.Array[np.dtype[F32]])
assert_type(x_i32.mT, xpt.Array[np.dtype[I32]])
assert_type(x_b.mT, xpt.Array[np.dtype[B]])
# Check Attribute `.ndim`: typed as a plain `int`.
assert_type(x_f32.ndim, int)
assert_type(x_i32.ndim, int)
assert_type(x_b.ndim, int)
# Check Attribute `.shape`: a variable-length tuple whose entries may be
# `None`. This presumably marks dimensions of unknown size, as the Array API
# standard permits; confirm against the protocol definition.
assert_type(x_f32.shape, tuple[int | None, ...])
assert_type(x_i32.shape, tuple[int | None, ...])
assert_type(x_b.shape, tuple[int | None, ...])
# Check Attribute `.size`: `int | None`, presumably `None` when the total
# size is unknown.
assert_type(x_f32.size, int | None)
assert_type(x_i32.size, int | None)
assert_type(x_b.size, int | None)
# Check Attribute `.T`: like `.mT`, the transposed array keeps the dtype
# parameter.
assert_type(x_f32.T, xpt.Array[np.dtype[F32]])
assert_type(x_i32.T, xpt.Array[np.dtype[I32]])
assert_type(x_b.T, xpt.Array[np.dtype[B]])