-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest_numpy2p0.pyi
More file actions
111 lines (84 loc) · 2.89 KB
/
test_numpy2p0.pyi
File metadata and controls
111 lines (84 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# mypy: disable-error-code="no-redef"
from types import ModuleType
from typing import Any, TypeAlias, assert_type
import numpy as np
import numpy.typing as npt
import array_api_typing as xpt
# DType aliases
# Short names for the NumPy scalar types that parametrize the NDArray
# declarations and the `np.dtype[...]` annotations used in this file.
F32: TypeAlias = np.float32
I32: TypeAlias = np.int32
B: TypeAlias = np.bool_
# Define NDArrays against which we can test the protocols
# These are bare declarations (annotation only, no value): they give the
# type checker one variable per NDArray parametrization to assign from.
# `nparr` uses `Any` as its scalar type; the others use the aliases above.
nparr: npt.NDArray[Any]
nparr_i32: npt.NDArray[I32]
nparr_f32: npt.NDArray[F32]
nparr_b: npt.NDArray[B]
# =========================================================
# `xpt.HasArrayNamespace`
# Check assignment
# Every NDArray declaration is assigned to a variable annotated as
# `xpt.HasArrayNamespace[ModuleType]`.
# NOTE(review): `_` is re-annotated on each line here and throughout the
# file; the file-level `disable-error-code="no-redef"` directive silences
# the redefinition error. Confirm the type checker still validates every
# re-annotated assignment against its own annotation, rather than only the
# first declaration of `_` - otherwise these checks are vacuous.
_: xpt.HasArrayNamespace[ModuleType] = nparr
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
# Check `__array_namespace__` method
# The result of calling the method is annotated as `ModuleType`, the same
# type argument given to `a_ns`.
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = a_ns.__array_namespace__()
# Incorrect values are caught when using `__array_namespace__` and
# backpropagated to the type of `a_ns`
# In other words (presumably - confirm): a wrong namespace type argument is
# only reported once `__array_namespace__()` is called and its result is
# checked, as in the two lines above. The bare assignment below, which uses
# an incorrect namespace type (`dict[str, int]`), is accepted by the checker.
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
# =========================================================
# `xpt.HasDType`
# Check DTypeT_co assignment
# The `Any`-typed array is assigned to `HasDType[Any]`; each typed array is
# assigned to `HasDType` parametrized with the `np.dtype` of its own scalar
# type alias (I32, F32, B).
_: xpt.HasDType[Any] = nparr
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
_: xpt.HasDType[np.dtype[B]] = nparr_b
# =========================================================
# `xpt.HasDevice`
# Every NDArray declaration is assigned to `xpt.HasDevice`, used here
# without type arguments.
# Presumably the protocol requires a `device` attribute - confirm against
# the protocol definition.
_: xpt.HasDevice = nparr
_: xpt.HasDevice = nparr_i32
_: xpt.HasDevice = nparr_f32
_: xpt.HasDevice = nparr_b
# =========================================================
# `xpt.HasMatrixTranspose`
# Every NDArray declaration is assigned to `xpt.HasMatrixTranspose`, used
# here without type arguments.
# Presumably the protocol requires a matrix-transpose attribute (`mT`) -
# confirm against the protocol definition.
_: xpt.HasMatrixTranspose = nparr
_: xpt.HasMatrixTranspose = nparr_i32
_: xpt.HasMatrixTranspose = nparr_f32
_: xpt.HasMatrixTranspose = nparr_b
# =========================================================
# `xpt.HasNDim`
# Every NDArray declaration is assigned to `xpt.HasNDim`, used here without
# type arguments.
# Presumably the protocol requires an `ndim` attribute - confirm against
# the protocol definition.
_: xpt.HasNDim = nparr
_: xpt.HasNDim = nparr_i32
_: xpt.HasNDim = nparr_f32
_: xpt.HasNDim = nparr_b
# =========================================================
# `xpt.HasShape`
# Every NDArray declaration is assigned to `xpt.HasShape`, used here
# without type arguments.
# Presumably the protocol requires a `shape` attribute - confirm against
# the protocol definition.
_: xpt.HasShape = nparr
_: xpt.HasShape = nparr_i32
_: xpt.HasShape = nparr_f32
_: xpt.HasShape = nparr_b
# =========================================================
# `xpt.HasSize`
# Every NDArray declaration is assigned to `xpt.HasSize`, used here without
# type arguments.
# Presumably the protocol requires a `size` attribute - confirm against
# the protocol definition.
_: xpt.HasSize = nparr
_: xpt.HasSize = nparr_i32
_: xpt.HasSize = nparr_f32
_: xpt.HasSize = nparr_b
# =========================================================
# `xpt.HasTranspose`
# Every NDArray declaration is assigned to `xpt.HasTranspose`, used here
# without type arguments.
# Presumably the protocol requires a transpose attribute (`T`) - confirm
# against the protocol definition.
_: xpt.HasTranspose = nparr
_: xpt.HasTranspose = nparr_i32
_: xpt.HasTranspose = nparr_f32
_: xpt.HasTranspose = nparr_b
# =========================================================
# `xpt.Array`
# Check NamespaceT_co assignment
# NOTE(review): `a_ns` is already declared earlier in this file with the
# annotation `xpt.HasArrayNamespace[ModuleType]`. This line re-declares it
# with a different annotation and relies on `no-redef` being disabled.
# Confirm the checker really evaluates this assignment against
# `xpt.Array[Any, ModuleType]`; a unique variable name would remove the
# doubt.
a_ns: xpt.Array[Any, ModuleType] = nparr
# Check DTypeT_co assignment
# Only the first type argument is supplied on these lines. The typed arrays
# are bound to distinct names so their `.dtype` can be inspected below.
_: xpt.Array[Any] = nparr
x_f32: xpt.Array[np.dtype[F32]] = nparr_f32
x_i32: xpt.Array[np.dtype[I32]] = nparr_i32
x_b: xpt.Array[np.dtype[B]] = nparr_b
# Check Attribute `.dtype`
# `assert_type` asks the type checker to confirm that `.dtype` on each
# `xpt.Array` variable has the same `np.dtype[...]` type the variable was
# parametrized with.
assert_type(x_f32.dtype, np.dtype[F32])
assert_type(x_i32.dtype, np.dtype[I32])
assert_type(x_b.dtype, np.dtype[B])