1+ from cubed .array_api .creation_functions import full
12from cubed .array_api .manipulation_functions import concat
23
34
4- def pad (x , pad_width , mode = None , chunks = None ):
5+ def pad (x , pad_width , mode = None , constant_values = 0 , chunks = None ):
56 """Pad an array."""
67 if len (pad_width ) != x .ndim :
78 raise ValueError ("`pad_width` must have as many entries as array dimensions" )
9+
10+ if mode == "constant" :
11+ return _pad_constant (x , pad_width , constant_values , chunks )
12+ elif mode == "symmetric" :
13+ return _pad_symmetric (x , pad_width , chunks )
14+ else :
15+ raise ValueError (f"Mode is not supported: { mode } " )
16+
17+
18+ def _pad_constant (x , pad_width , constant_values , chunks ):
19+ cv = _normalize_constant_values (constant_values , x .ndim )
20+ result = x
21+ for axis , ((pad_before , pad_after ), (val_before , val_after )) in enumerate (
22+ zip (pad_width , cv )
23+ ):
24+ if pad_before == 0 and pad_after == 0 :
25+ continue
26+ arrays = []
27+ if pad_before > 0 :
28+ shape = list (result .shape )
29+ shape [axis ] = pad_before
30+ c = list (result .chunksize )
31+ c [axis ] = min (pad_before , result .chunksize [axis ])
32+ arrays .append (
33+ full (
34+ tuple (shape ),
35+ val_before ,
36+ dtype = result .dtype ,
37+ chunks = tuple (c ),
38+ spec = result .spec ,
39+ )
40+ )
41+ arrays .append (result )
42+ if pad_after > 0 :
43+ shape = list (result .shape )
44+ shape [axis ] = pad_after
45+ c = list (result .chunksize )
46+ c [axis ] = min (pad_after , result .chunksize [axis ])
47+ arrays .append (
48+ full (
49+ tuple (shape ),
50+ val_after ,
51+ dtype = result .dtype ,
52+ chunks = tuple (c ),
53+ spec = result .spec ,
54+ )
55+ )
56+ result = concat (arrays , axis = axis , chunks = chunks or x .chunksize )
57+ return result
58+
59+
60+ def _normalize_constant_values (constant_values , ndim ):
61+ """Normalize constant_values to a list of (before, after) per axis.
62+
63+ Accepts a scalar, a (before, after) pair, or a sequence of ndim pairs.
64+ """
65+ try :
66+ iter (constant_values )
67+ except TypeError :
68+ # scalar
69+ return [(constant_values , constant_values )] * ndim
70+
71+ cv = list (constant_values )
72+ if len (cv ) == 2 and not hasattr (cv [0 ], "__len__" ):
73+ # (before, after) pair applied to every axis
74+ return [(cv [0 ], cv [1 ])] * ndim
75+ if len (cv ) == ndim :
76+ # per-axis sequence of (before, after) pairs
77+ return [(pair [0 ], pair [1 ]) for pair in cv ]
78+ raise ValueError (f"Invalid constant_values for ndim={ ndim } : { constant_values } " )
79+
80+
81+ def _pad_symmetric (x , pad_width , chunks ):
882 axis = tuple (
983 i
1084 for (i , (before , after )) in enumerate (pad_width )
@@ -15,15 +89,7 @@ def pad(x, pad_width, mode=None, chunks=None):
1589 axis = axis [0 ]
1690 if pad_width [axis ] != (1 , 0 ):
1791 raise ValueError ("only a pad width of (1, 0) is allowed" )
18- if mode != "symmetric" :
19- raise ValueError (f"Mode is not supported: { mode } " )
2092
21- select = []
22- for i in range (x .ndim ):
23- if i == axis :
24- select .append (slice (0 , 1 ))
25- else :
26- select .append (slice (None ))
27- select = tuple (select )
93+ select = tuple (slice (0 , 1 ) if i == axis else slice (None ) for i in range (x .ndim ))
2894 a = x [select ]
2995 return concat ([a , x ], axis = axis , chunks = chunks or x .chunksize )
0 commit comments