Skip to content

Commit 5f7d69a

Browse files
authored
Merge branch 'main' into feature/rectilinear-chunk-grid
2 parents 629d34a + 879e1ce commit 5f7d69a

1 file changed

Lines changed: 40 additions & 46 deletions

File tree

src/zarr/core/indexing.py

Lines changed: 40 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1609,54 +1609,48 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
16091609
out.flags.writeable = False
16101610
return out
16111611

1612-
# Optimization: Remove singleton dimensions to enable magic number usage
1613-
# for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
1614-
singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1)
1615-
if singleton_dims:
1616-
squeezed_shape = tuple(s for s in chunk_shape if s != 1)
1617-
if squeezed_shape:
1618-
# Compute Morton order on squeezed shape, then expand singleton dims (always 0)
1619-
squeezed_order = np.asarray(_morton_order(squeezed_shape))
1620-
out = np.zeros((n_total, n_dims), dtype=np.intp)
1621-
squeezed_col = 0
1622-
for full_col in range(n_dims):
1623-
if chunk_shape[full_col] != 1:
1624-
out[:, full_col] = squeezed_order[:, squeezed_col]
1625-
squeezed_col += 1
1626-
else:
1627-
# All dimensions are singletons, just return the single point
1628-
out = np.zeros((1, n_dims), dtype=np.intp)
1629-
out.flags.writeable = False
1630-
return out
1631-
1632-
# Find the largest power-of-2 hypercube that fits within chunk_shape.
1633-
# Within this hypercube, Morton codes are guaranteed to be in bounds.
1634-
min_dim = min(chunk_shape)
1635-
if min_dim >= 1:
1636-
power = min_dim.bit_length() - 1 # floor(log2(min_dim))
1637-
hypercube_size = 1 << power # 2^power
1638-
n_hypercube = hypercube_size**n_dims
1612+
# Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span
1613+
# all valid coordinates in chunk_shape. (c-1).bit_length() gives the number
1614+
# of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits
1615+
# is the size of this hypercube.
1616+
total_bits = sum((c - 1).bit_length() for c in chunk_shape)
1617+
n_z = 1 << total_bits if total_bits > 0 else 1
1618+
1619+
# Decode all Morton codes in the ceiling hypercube, then filter to valid coords.
1620+
# This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33):
1621+
# n_z=262144, n_total=35937), the argsort strategy below is used instead.
1622+
order: npt.NDArray[np.intp]
1623+
if n_z <= 4 * n_total:
1624+
# Ceiling strategy: decode all n_z codes vectorized, filter in-bounds.
1625+
# Works well when the overgeneration ratio n_z/n_total is small (≤4).
1626+
z_values = np.arange(n_z, dtype=np.intp)
1627+
all_coords = decode_morton_vectorized(z_values, chunk_shape)
1628+
shape_arr = np.array(chunk_shape, dtype=np.intp)
1629+
valid_mask = np.all(all_coords < shape_arr, axis=1)
1630+
order = all_coords[valid_mask]
16391631
else:
1640-
n_hypercube = 0
1632+
# Argsort strategy: enumerate all n_total valid coordinates directly,
1633+
# encode each to a Morton code, then sort by code. Avoids the large
1634+
# (~7.3x) overgeneration penalty for near-miss shapes like (33,33,33).
1635+
# Cost: O(n_total * bits) encode + O(n_total log n_total) sort,
1636+
# vs O(n_z * bits) for ceiling, where n_z > 4 * n_total on this branch.
1637+
grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij")
1638+
all_coords = np.stack([g.ravel() for g in grids], axis=1)
1639+
1640+
# Encode all coordinates to Morton codes (vectorized).
1641+
bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape)
1642+
max_coord_bits = max(bits_per_dim)
1643+
z_codes = np.zeros(n_total, dtype=np.intp)
1644+
output_bit = 0
1645+
for coord_bit in range(max_coord_bits):
1646+
for dim in range(n_dims):
1647+
if coord_bit < bits_per_dim[dim]:
1648+
z_codes |= ((all_coords[:, dim] >> coord_bit) & 1) << output_bit
1649+
output_bit += 1
1650+
1651+
sort_idx: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable")
1652+
order = np.asarray(all_coords[sort_idx], dtype=np.intp)
16411653

1642-
# Within the hypercube, no bounds checking needed - use vectorized decoding
1643-
if n_hypercube > 0:
1644-
z_values = np.arange(n_hypercube, dtype=np.intp)
1645-
order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape)
1646-
else:
1647-
order = np.empty((0, n_dims), dtype=np.intp)
1648-
1649-
# For remaining elements outside the hypercube, bounds checking is needed
1650-
remaining: list[tuple[int, ...]] = []
1651-
i = n_hypercube
1652-
while len(order) + len(remaining) < n_total:
1653-
m = decode_morton(i, chunk_shape)
1654-
if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
1655-
remaining.append(m)
1656-
i += 1
1657-
1658-
if remaining:
1659-
order = np.vstack([order, np.array(remaining, dtype=np.intp)])
16601654
order.flags.writeable = False
16611655
return order
16621656

0 commit comments

Comments
 (0)