Skip to content

Commit 999416c

Browse files
authored
Implement constant_values in pad (#903)
1 parent 94efb66 commit 999416c

2 files changed

Lines changed: 100 additions & 11 deletions

File tree

cubed/array/pad.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,84 @@
1+
from cubed.array_api.creation_functions import full
12
from cubed.array_api.manipulation_functions import concat
23

34

4-
def pad(x, pad_width, mode=None, chunks=None):
5+
def pad(x, pad_width, mode=None, constant_values=0, chunks=None):
56
"""Pad an array."""
67
if len(pad_width) != x.ndim:
78
raise ValueError("`pad_width` must have as many entries as array dimensions")
9+
10+
if mode == "constant":
11+
return _pad_constant(x, pad_width, constant_values, chunks)
12+
elif mode == "symmetric":
13+
return _pad_symmetric(x, pad_width, chunks)
14+
else:
15+
raise ValueError(f"Mode is not supported: {mode}")
16+
17+
18+
def _pad_constant(x, pad_width, constant_values, chunks):
19+
cv = _normalize_constant_values(constant_values, x.ndim)
20+
result = x
21+
for axis, ((pad_before, pad_after), (val_before, val_after)) in enumerate(
22+
zip(pad_width, cv)
23+
):
24+
if pad_before == 0 and pad_after == 0:
25+
continue
26+
arrays = []
27+
if pad_before > 0:
28+
shape = list(result.shape)
29+
shape[axis] = pad_before
30+
c = list(result.chunksize)
31+
c[axis] = min(pad_before, result.chunksize[axis])
32+
arrays.append(
33+
full(
34+
tuple(shape),
35+
val_before,
36+
dtype=result.dtype,
37+
chunks=tuple(c),
38+
spec=result.spec,
39+
)
40+
)
41+
arrays.append(result)
42+
if pad_after > 0:
43+
shape = list(result.shape)
44+
shape[axis] = pad_after
45+
c = list(result.chunksize)
46+
c[axis] = min(pad_after, result.chunksize[axis])
47+
arrays.append(
48+
full(
49+
tuple(shape),
50+
val_after,
51+
dtype=result.dtype,
52+
chunks=tuple(c),
53+
spec=result.spec,
54+
)
55+
)
56+
result = concat(arrays, axis=axis, chunks=chunks or x.chunksize)
57+
return result
58+
59+
60+
def _normalize_constant_values(constant_values, ndim):
61+
"""Normalize constant_values to a list of (before, after) per axis.
62+
63+
Accepts a scalar, a (before, after) pair, or a sequence of ndim pairs.
64+
"""
65+
try:
66+
iter(constant_values)
67+
except TypeError:
68+
# scalar
69+
return [(constant_values, constant_values)] * ndim
70+
71+
cv = list(constant_values)
72+
if len(cv) == 2 and not hasattr(cv[0], "__len__"):
73+
# (before, after) pair applied to every axis
74+
return [(cv[0], cv[1])] * ndim
75+
if len(cv) == ndim:
76+
# per-axis sequence of (before, after) pairs
77+
return [(pair[0], pair[1]) for pair in cv]
78+
raise ValueError(f"Invalid constant_values for ndim={ndim}: {constant_values}")
79+
80+
81+
def _pad_symmetric(x, pad_width, chunks):
882
axis = tuple(
983
i
1084
for (i, (before, after)) in enumerate(pad_width)
@@ -15,15 +89,7 @@ def pad(x, pad_width, mode=None, chunks=None):
1589
axis = axis[0]
1690
if pad_width[axis] != (1, 0):
1791
raise ValueError("only a pad width of (1, 0) is allowed")
18-
if mode != "symmetric":
19-
raise ValueError(f"Mode is not supported: {mode}")
2092

21-
select = []
22-
for i in range(x.ndim):
23-
if i == axis:
24-
select.append(slice(0, 1))
25-
else:
26-
select.append(slice(None))
27-
select = tuple(select)
93+
select = tuple(slice(0, 1) if i == axis else slice(None) for i in range(x.ndim))
2894
a = x[select]
2995
return concat([a, x], axis=axis, chunks=chunks or x.chunksize)

cubed/tests/array/test_pad.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,34 @@ def spec(tmp_path):
1111
return cubed.Spec(tmp_path, allowed_mem=100000)
1212

1313

14-
def test_pad(spec):
14+
def test_pad_symmetric(spec):
1515
an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1616

1717
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
1818
b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric")
1919
assert b.chunks == ((2, 2), (2, 1))
2020

2121
assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric"))
22+
23+
24+
def test_pad_constant(spec):
25+
an = np.arange(12).reshape(3, 4).astype(float)
26+
a = xp.asarray(an, chunks=(2, 2), spec=spec)
27+
28+
# scalar constant_values, padding on both axes
29+
b = cubed.pad(a, ((2, 1), (0, 3)), mode="constant", constant_values=0)
30+
assert_array_equal(
31+
b, np.pad(an, ((2, 1), (0, 3)), mode="constant", constant_values=0)
32+
)
33+
34+
# nan fill (xarray shift/rolling use case)
35+
b = cubed.pad(a, ((2, 0), (0, 0)), mode="constant", constant_values=float("nan"))
36+
assert_array_equal(
37+
b, np.pad(an, ((2, 0), (0, 0)), mode="constant", constant_values=float("nan"))
38+
)
39+
40+
# single-axis padding (no-op on one axis)
41+
b = cubed.pad(a, ((0, 0), (1, 1)), mode="constant", constant_values=99)
42+
assert_array_equal(
43+
b, np.pad(an, ((0, 0), (1, 1)), mode="constant", constant_values=99)
44+
)

0 commit comments

Comments
 (0)