diff --git a/src/ntops/kernels/element_wise.py b/src/ntops/kernels/element_wise.py index 4ed3c10..05aae1e 100644 --- a/src/ntops/kernels/element_wise.py +++ b/src/ntops/kernels/element_wise.py @@ -1,12 +1,12 @@ import ninetoothed -def arrangement(*tensors): +def arrangement(*tensors, block_size=ninetoothed.block_size()): ndim = max(tensor.ndim for tensor in tensors) assert all(tensor.ndim == ndim or tensor.ndim == 0 for tensor in tensors) - block_shape = tuple(1 for _ in range(ndim - 1)) + (ninetoothed.block_size(),) + block_shape = tuple(1 for _ in range(ndim - 1)) + (block_size,) return tuple( tensor.tile(block_shape) if tensor.ndim != 0 else tensor for tensor in tensors