|
1 | 1 | from bisect import bisect |
| 2 | +from functools import reduce |
| 3 | +from itertools import accumulate, chain |
2 | 4 | from operator import add, mul |
3 | 5 | from typing import Iterator |
4 | 6 |
|
5 | | -import tlz |
6 | | -from toolz import reduce |
7 | | - |
8 | 7 | from cubed.array_api.creation_functions import empty |
9 | 8 | from cubed.backend_array_api import IS_IMMUTABLE_ARRAY |
10 | 9 | from cubed.backend_array_api import namespace as nxp |
@@ -39,7 +38,7 @@ def broadcast_arrays(*arrays): |
39 | 38 |
|
40 | 39 | # Unify uneven chunking |
41 | 40 | inds = [list(reversed(range(x.ndim))) for x in arrays] |
42 | | - uc_args = tlz.concat(zip(arrays, inds)) |
| 41 | + uc_args = chain.from_iterable(zip(arrays, inds)) |
43 | 42 | _, args = unify_chunks(*uc_args, warn=False) |
44 | 43 |
|
45 | 44 | shape = broadcast_shapes(*(e.shape for e in args)) |
@@ -133,11 +132,11 @@ def concat(arrays, /, *, axis=0, chunks=None): |
133 | 132 | inds = [list(range(x.ndim)) for x in arrays] |
134 | 133 | for i, ind in enumerate(inds): |
135 | 134 | ind[axis] = -(i + 1) |
136 | | - uc_args = tlz.concat(zip(arrays, inds)) |
| 135 | + uc_args = chain.from_iterable(zip(arrays, inds)) |
137 | 136 | chunkss, arrays = unify_chunks(*uc_args, warn=False) |
138 | 137 |
|
139 | 138 | # offsets along axis for the start of each array |
140 | | - offsets = [0] + list(tlz.accumulate(add, [a.shape[axis] for a in arrays])) |
| 139 | + offsets = [0] + list(accumulate([a.shape[axis] for a in arrays], add)) |
141 | 140 | in_shapes = tuple(array.shape for array in arrays) |
142 | 141 |
|
143 | 142 | axis = validate_axis(axis, ndim) |
|
0 commit comments