Skip to content

Commit d31c2de

Browse files
committed
Update third_party/cupy/testing/_helper.py
1 parent 6a547b9 commit d31c2de

1 file changed

Lines changed: 65 additions & 36 deletions

File tree

dpnp/tests/third_party/cupy/testing/_helper.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1+
from __future__ import annotations
2+
13
import contextlib
24
import importlib.metadata
35
import inspect
46
import unittest
57
import warnings
8+
from collections.abc import Callable
69
from importlib.metadata import PackageNotFoundError
7-
from typing import Callable
810
from unittest import mock
911

1012
import numpy
1113

1214
import dpnp as cupy
13-
from dpnp.tests.third_party.cupy.testing._pytest_impl import is_available
1415

1516
# from cupy._core import internal
16-
# import cupyx
17-
# import cupyx.scipy.sparse
18-
17+
from dpnp.tests.third_party.cupy.testing._pytest_impl import is_available
1918

2019
if is_available():
2120
import pytest
@@ -25,7 +24,7 @@
2524
_skipif = unittest.skipIf
2625

2726

28-
def with_requires(*requirements):
27+
def with_requires(*requirements: str) -> Callable[[Callable], Callable]:
2928
"""Run a test case only when given requirements are satisfied.
3029
3130
.. admonition:: Example
@@ -49,7 +48,7 @@ def with_requires(*requirements):
4948
return _skipif(not installed(*requirements), reason=msg)
5049

5150

52-
def installed(*specifiers):
51+
def installed(*specifiers: str) -> bool:
5352
"""Returns True if the current environment satisfies the specified
5453
package requirement.
5554
@@ -72,13 +71,13 @@ def installed(*specifiers):
7271
return True
7372

7473

75-
def numpy_satisfies(version_range):
74+
def numpy_satisfies(version_range: str) -> bool:
7675
"""Returns True if numpy version satisfies the specified criteria.
7776
7877
Args:
7978
version_range: A version specifier (e.g., `>=1.13.0`).
8079
"""
81-
return installed("numpy{}".format(version_range))
80+
return installed(f"numpy{version_range}")
8281

8382

8483
def shaped_arange(shape, xp=cupy, dtype=numpy.float32, order="C", device=None):
@@ -162,45 +161,72 @@ def shaped_random(
162161
from uniform distribution over :math:`[0, scale)`
163162
with specified dtype.
164163
"""
165-
numpy.random.seed(seed)
164+
rng = numpy.random.RandomState(seed)
166165
dtype = numpy.dtype(dtype)
167166
if dtype == "?":
168-
a = numpy.random.randint(2, size=shape)
167+
a = rng.randint(2, size=shape)
169168
elif dtype.kind == "c":
170-
a = numpy.random.rand(*shape) + 1j * numpy.random.rand(*shape)
169+
a = rng.rand(*shape) + 1j * rng.rand(*shape)
171170
a *= scale
172171
else:
173-
a = numpy.random.rand(*shape) * scale
172+
a = rng.rand(*shape) * scale
174173
return xp.asarray(a, dtype=dtype, order=order)
175174

176175

177-
# def shaped_sparse_random(
178-
# shape, sp=cupyx.scipy.sparse, dtype=numpy.float32,
179-
# density=0.01, format='coo', seed=0):
180-
# """Returns an array filled with random values.
176+
def shaped_sparse_random(
177+
shape, sp=None, dtype=numpy.float32, density=0.01, format="", seed=0
178+
):
179+
"""Returns an array filled with random values.
180+
181+
Args:
182+
shape (tuple): Shape of returned sparse matrix.
183+
sp (scipy.sparse or cupyx.scipy.sparse): Sparse matrix module to use.
184+
dtype (dtype): Dtype of returned sparse matrix.
185+
density (float): Density of returned sparse matrix.
186+
format (str): Format of returned sparse matrix.
187+
seed (int): Random seed.
188+
189+
Returns:
190+
The sparse matrix with given shape, array module,
191+
"""
192+
import cupyx.scipy.sparse
193+
import scipy.sparse
194+
195+
if sp is None:
196+
sp = cupyx.scipy.sparse
197+
n_rows, n_cols = shape
198+
a = scipy.sparse.random(n_rows, n_cols, density, random_state=seed).astype(
199+
dtype
200+
)
201+
202+
try:
203+
return sp.coo_matrix(a).asformat(format)
204+
except AttributeError:
205+
raise ValueError(f"Module {sp} does not have the expected sparse APIs")
181206

182-
# Args:
183-
# shape (tuple): Shape of returned sparse matrix.
184-
# sp (scipy.sparse or cupyx.scipy.sparse): Sparse matrix module to use.
185-
# dtype (dtype): Dtype of returned sparse matrix.
186-
# density (float): Density of returned sparse matrix.
187-
# format (str): Format of returned sparse matrix.
188-
# seed (int): Random seed.
189207

190-
# Returns:
191-
# The sparse matrix with given shape, array module,
192-
# """
193-
# import scipy.sparse
194-
# n_rows, n_cols = shape
195-
# numpy.random.seed(seed)
196-
# a = scipy.sparse.random(n_rows, n_cols, density).astype(dtype)
208+
def shaped_linspace(start, stop, shape, xp=cupy, dtype=numpy.float32):
209+
"""Returns an array with given shape, array module, and dtype.
197210
198-
# if sp is cupyx.scipy.sparse:
199-
# a = cupyx.scipy.sparse.coo_matrix(a)
200-
# elif sp is not scipy.sparse:
201-
# raise ValueError('Unknown module: {}'.format(sp))
211+
Args:
212+
start (int): The starting value.
213+
stop (int): The end value.
214+
shape (tuple of int): Shape of returned ndarray.
215+
xp (numpy or cupy): Array module to use.
216+
dtype (dtype): Dtype of returned ndarray.
202217
203-
# return a.asformat(format)
218+
Returns:
219+
numpy.ndarray or cupy.ndarray:
220+
"""
221+
dtype = numpy.dtype(dtype)
222+
size = numpy.prod(shape)
223+
if dtype == "?":
224+
start = max(start, 0)
225+
stop = min(stop, 1)
226+
elif dtype.kind == "u":
227+
start = max(start, 0)
228+
a = numpy.linspace(start, stop, size)
229+
return xp.array(a.astype(dtype).reshape(shape))
204230

205231

206232
def generate_matrix(
@@ -276,6 +302,7 @@ def assert_warns(expected):
276302

277303

278304
class NumpyAliasTestBase(unittest.TestCase):
305+
279306
@property
280307
def func(self):
281308
raise NotImplementedError()
@@ -290,6 +317,7 @@ def numpy_func(self):
290317

291318

292319
class NumpyAliasBasicTestBase(NumpyAliasTestBase):
320+
293321
def test_argspec(self):
294322
f = inspect.signature
295323
assert f(self.cupy_func) == f(self.numpy_func)
@@ -304,6 +332,7 @@ def test_docstring(self):
304332

305333

306334
class NumpyAliasValuesTestBase(NumpyAliasTestBase):
335+
307336
def test_values(self):
308337
assert self.cupy_func(*self.args) == self.numpy_func(*self.args)
309338

0 commit comments

Comments
 (0)