Skip to content

Commit 21197ee

Browse files
Tests for scalar, refactoring and comments
1 parent 37e7148 commit 21197ee

File tree

2 files changed

+60
-38
lines changed

2 files changed

+60
-38
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,9 @@ def convolve(a, v, mode="full", method="auto"):
433433
Notes
434434
-----
435435
The discrete convolution operation is defined as
436-
436+
437437
.. math:: (a * v)_n = \sum_{m = -\infty}^{\infty} a_m v_{n - m}
438-
438+
439439
It can be shown that a convolution :math:`x(t) * y(t)` in time/space
440440
is equivalent to the multiplication :math:`X(f) Y(f)` in the Fourier
441441
domain, after appropriate padding (padding is necessary to prevent
@@ -458,14 +458,14 @@ def convolve(a, v, mode="full", method="auto"):
458458
>>> v = np.array([0, 1, 0.5], dtype=np.float32)
459459
>>> np.convolve(a, v)
460460
array([0. , 1. , 2.5, 4. , 1.5], dtype=float32)
461-
461+
462462
Only return the middle values of the convolution.
463463
Contains boundary effects, where zeros are taken
464464
into account:
465465
466466
>>> np.convolve(a, v, 'same')
467467
array([1. , 2.5, 4. ], dtype=float32)
468-
468+
469469
The two arrays are of the same length, so there
470470
is only one position where they completely overlap:
471471

dpnp/tests/test_statistics.py

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ def test_avg_error(self):
128128

129129

130130
class TestConvolve:
131+
@staticmethod
132+
def _get_kwargs(mode=None, method=None):
133+
dpnp_kwargs = {}
134+
numpy_kwargs = {}
135+
if mode is not None:
136+
dpnp_kwargs["mode"] = mode
137+
numpy_kwargs["mode"] = mode
138+
if method is not None:
139+
dpnp_kwargs["method"] = method
140+
return dpnp_kwargs, numpy_kwargs
141+
131142
def setup_method(self):
132143
numpy.random.seed(0)
133144

@@ -143,13 +154,7 @@ def test_convolve(self, a, v, mode, dtype, method):
143154
ad = dpnp.array(an)
144155
vd = dpnp.array(vn)
145156

146-
dpnp_kwargs = {}
147-
numpy_kwargs = {}
148-
if mode is not None:
149-
dpnp_kwargs["mode"] = mode
150-
numpy_kwargs["mode"] = mode
151-
if method is not None:
152-
dpnp_kwargs["method"] = method
157+
dpnp_kwargs, numpy_kwargs = self._get_kwargs(mode, method)
153158

154159
expected = numpy.convolve(an, vn, **numpy_kwargs)
155160
result = dpnp.convolve(ad, vd, **dpnp_kwargs)
@@ -164,42 +169,27 @@ def test_convolve(self, a, v, mode, dtype, method):
164169
def test_convolve_random(self, a_size, v_size, mode, dtype, method):
165170
if dtype in [numpy.int8, numpy.uint8, numpy.int16, numpy.uint16]:
166171
pytest.skip("avoid overflow.")
167-
if dtype == dpnp.bool:
168-
an = numpy.random.rand(a_size) > 0.9
169-
vn = numpy.random.rand(v_size) > 0.9
170-
else:
171-
an = (100 * numpy.random.rand(a_size)).astype(dtype)
172-
vn = (100 * numpy.random.rand(v_size)).astype(dtype)
173-
174-
if dpnp.issubdtype(dtype, dpnp.complexfloating):
175-
an = an + 1j * (100 * numpy.random.rand(a_size)).astype(dtype)
176-
vn = vn + 1j * (100 * numpy.random.rand(v_size)).astype(dtype)
172+
an = generate_random_numpy_array(
173+
a_size, dtype, low=-3, high=3, probability=0.9
174+
)
175+
vn = generate_random_numpy_array(
176+
v_size, dtype, low=-3, high=3, probability=0.9
177+
)
177178

178179
ad = dpnp.array(an)
179180
vd = dpnp.array(vn)
180181

181-
dpnp_kwargs = {}
182-
numpy_kwargs = {}
183-
if mode is not None:
184-
dpnp_kwargs["mode"] = mode
185-
numpy_kwargs["mode"] = mode
186-
if method is not None:
187-
dpnp_kwargs["method"] = method
182+
dpnp_kwargs, numpy_kwargs = self._get_kwargs(mode, method)
188183

189184
result = dpnp.convolve(ad, vd, **dpnp_kwargs)
190185
expected = numpy.convolve(an, vn, **numpy_kwargs)
191186

192-
rdtype = result.dtype
193-
if dpnp.issubdtype(rdtype, dpnp.integer):
194-
rdtype = dpnp.default_float_type(ad.device)
195-
196187
if method != "fft" and (
197188
dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
198189
):
199190
# For 'direct' and 'auto' methods, we expect exact results for integer types
200191
assert_array_equal(result, expected)
201192
else:
202-
result = result.astype(rdtype)
203193
if method == "direct":
204194
# For 'direct' method we can use standard validation
205195
# acceptable error depends on the kernel size
@@ -210,6 +200,16 @@ def test_convolve_random(self, a_size, v_size, mode, dtype, method):
210200
factor = int(40 * (min(a_size, v_size) ** 0.5))
211201
assert_dtype_allclose(result, expected, factor=factor)
212202
else:
203+
rdtype = result.dtype
204+
if dpnp.issubdtype(rdtype, dpnp.integer):
205+
# 'fft' do its calculations in float
206+
# and 'auto' could use fft
207+
# also assert_dtype_allclose for integer types is
208+
# always check for exact match
209+
rdtype = dpnp.default_float_type(ad.device)
210+
211+
result = result.astype(rdtype)
212+
213213
rtol = 1e-3
214214
atol = 1e-10
215215

@@ -231,6 +231,8 @@ def test_convolve_random(self, a_size, v_size, mode, dtype, method):
231231
# We can tolerate a few such outliers.
232232
max_outliers = 8 if expected.size > 1 else 0
233233
if invalid.sum() > max_outliers:
234+
# we already failed check,
235+
# call assert_dtype_allclose just to report error nicely
234236
assert_dtype_allclose(result, expected, factor=1000)
235237

236238
def test_convolve_mode_error(self):
@@ -249,6 +251,19 @@ def test_convolve_empty(self, a, v):
249251
with pytest.raises(ValueError):
250252
dpnp.convolve(a, v)
251253

254+
@pytest.mark.parametrize("a, v", [([1], 2), (3, [4]), (5, 6)])
255+
def test_convolve_scalar(self, a, v):
256+
an = numpy.asarray(a, dtype=numpy.float32)
257+
vn = numpy.asarray(v, dtype=numpy.float32)
258+
259+
ad = dpnp.asarray(a, dtype=numpy.float32)
260+
vd = dpnp.asarray(v, dtype=numpy.float32)
261+
262+
expected = numpy.convolve(an, vn)
263+
result = dpnp.convolve(ad, vd)
264+
265+
assert_dtype_allclose(result, expected)
266+
252267
@pytest.mark.parametrize(
253268
"a, v",
254269
[
@@ -404,17 +419,12 @@ def test_correlate_random(self, a_size, v_size, mode, dtype, method):
404419
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
405420
expected = numpy.correlate(an, vn, **numpy_kwargs)
406421

407-
rdtype = result.dtype
408-
if dpnp.issubdtype(rdtype, dpnp.integer):
409-
rdtype = dpnp.default_float_type(ad.device)
410-
411422
if method != "fft" and (
412423
dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
413424
):
414425
# For 'direct' and 'auto' methods, we expect exact results for integer types
415426
assert_array_equal(result, expected)
416427
else:
417-
result = result.astype(rdtype)
418428
if method == "direct":
419429
expected = numpy.correlate(an, vn, **numpy_kwargs)
420430
# For 'direct' method we can use standard validation
@@ -426,6 +436,16 @@ def test_correlate_random(self, a_size, v_size, mode, dtype, method):
426436
factor = int(40 * (min(a_size, v_size) ** 0.5))
427437
assert_dtype_allclose(result, expected, factor=factor)
428438
else:
439+
rdtype = result.dtype
440+
if dpnp.issubdtype(rdtype, dpnp.integer):
441+
# 'fft' do its calculations in float
442+
# and 'auto' could use fft
443+
# also assert_dtype_allclose for integer types is
444+
# always check for exact match
445+
rdtype = dpnp.default_float_type(ad.device)
446+
447+
result = result.astype(rdtype)
448+
429449
rtol = 1e-3
430450
atol = 1e-3
431451

@@ -447,6 +467,8 @@ def test_correlate_random(self, a_size, v_size, mode, dtype, method):
447467
# We can tolerate a few such outliers.
448468
max_outliers = 10 if expected.size > 1 else 0
449469
if invalid.sum() > max_outliers:
470+
# we already failed check,
471+
# call assert_dtype_allclose just to report error nicely
450472
assert_dtype_allclose(result, expected, factor=1000)
451473

452474
def test_correlate_mode_error(self):

0 commit comments

Comments
 (0)