Commit 879e1ce

mkitti, claude, and d-v-b authored
perf: Fix near-miss penalty in _morton_order with hybrid ceiling+argsort strategy (#3718)
* tests: Add non-power-of-2 shard shapes to benchmarks

  Add (30,30,30) to large_morton_shards and (10,10,10), (20,20,20), (30,30,30) to morton_iter_shapes to benchmark the scalar fallback path for non-power-of-2 shapes, which are not fully covered by the vectorized hypercube path.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* tests: Add near-miss power-of-2 shape (33,33,33) to benchmarks

  Documents the performance penalty when a shard shape is just above a power-of-2 boundary, causing n_z to jump from 32,768 to 262,144.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: Apply ruff format to benchmark file

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* changes: Add changelog entry for PR #3717

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* perf: Fix near-miss penalty in _morton_order with hybrid ceiling+argsort strategy

  For shapes just above a power-of-2 (e.g. (33,33,33)), the ceiling-only approach generates n_z=262,144 Morton codes for only 35,937 valid coordinates (7.3× overgeneration). The floor+scalar approach is even worse since the scalar loop iterates n_z-n_floor times (229,376 for (33,33,33)), not n_total-n_floor.

  The fix: when n_z > 4*n_total, use an argsort strategy that enumerates all n_total valid coordinates via meshgrid, encodes each to a Morton code using vectorized bit manipulation, then sorts by Morton code. This avoids the large overgeneration while remaining fully vectorized.

  Result for test_morton_order_iter:
    (30,30,30): 24ms (ceiling, ratio=1.21)
    (32,32,32): 28ms (ceiling, ratio=1.00)
    (33,33,33): 32ms (argsort, ratio=7.3 → fixed from ~820ms with scalar)

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: Address pre-commit CI failures in _morton_order

  - Replace Unicode multiplication sign × with ASCII x in comment (RUF003)
  - Add explicit type annotation for np.argsort result to satisfy mypy

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: Cast argsort result via np.asarray to resolve mypy no-any-return

  np.stack returns Any in mypy's view, so indexing into it also returns Any. Using np.asarray(..., dtype=np.intp) makes the type explicit and avoids the no-any-return error at the return site.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: Pre-declare order type to resolve mypy no-any-return in _morton_order

  np.asarray and np.stack return Any with numpy 2.1 type stubs, causing mypy to infer the return type as Any. Pre-declaring order as npt.NDArray[np.intp] before the if/else makes the intended type explicit.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
1 parent 32c7ab9 commit 879e1ce
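
The overgeneration ratios quoted in the commit message can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: `pick_strategy` and its `threshold` parameter are hypothetical names, not part of zarr, and it assumes every dimension is at least 1. It mirrors the `n_z <= 4 * n_total` test described in the message.

```python
from math import prod


def pick_strategy(shape: tuple[int, ...], threshold: int = 4) -> tuple[int, int, str]:
    """Return (n_z, n_total, strategy) for a chunk shape.

    Hypothetical helper for illustration; assumes all dimensions are >= 1.
    """
    # (c - 1).bit_length() is the number of bits needed to index c values.
    total_bits = sum((c - 1).bit_length() for c in shape)
    n_z = 1 << total_bits        # size of the ceiling power-of-2 hypercube
    n_total = prod(shape)        # number of valid coordinates
    strategy = "ceiling" if n_z <= threshold * n_total else "argsort"
    return n_z, n_total, strategy


for shape in [(30, 30, 30), (32, 32, 32), (33, 33, 33)]:
    n_z, n_total, strategy = pick_strategy(shape)
    print(shape, n_z, n_total, f"ratio={n_z / n_total:.2f}", strategy)
```

For (33,33,33) each dimension needs 6 bits, so n_z is 2**18 = 262,144 against 35,937 valid coordinates, which is what pushes the shape over the 4x threshold.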

src/zarr/core/indexing.py

Lines changed: 40 additions & 46 deletions
@@ -1512,54 +1512,48 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
         out.flags.writeable = False
         return out

-    # Optimization: Remove singleton dimensions to enable magic number usage
-    # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
-    singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1)
-    if singleton_dims:
-        squeezed_shape = tuple(s for s in chunk_shape if s != 1)
-        if squeezed_shape:
-            # Compute Morton order on squeezed shape, then expand singleton dims (always 0)
-            squeezed_order = np.asarray(_morton_order(squeezed_shape))
-            out = np.zeros((n_total, n_dims), dtype=np.intp)
-            squeezed_col = 0
-            for full_col in range(n_dims):
-                if chunk_shape[full_col] != 1:
-                    out[:, full_col] = squeezed_order[:, squeezed_col]
-                    squeezed_col += 1
-        else:
-            # All dimensions are singletons, just return the single point
-            out = np.zeros((1, n_dims), dtype=np.intp)
-        out.flags.writeable = False
-        return out
-
-    # Find the largest power-of-2 hypercube that fits within chunk_shape.
-    # Within this hypercube, Morton codes are guaranteed to be in bounds.
-    min_dim = min(chunk_shape)
-    if min_dim >= 1:
-        power = min_dim.bit_length() - 1  # floor(log2(min_dim))
-        hypercube_size = 1 << power  # 2^power
-        n_hypercube = hypercube_size**n_dims
+    # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span
+    # all valid coordinates in chunk_shape. (c-1).bit_length() gives the number
+    # of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits
+    # is the size of this hypercube.
+    total_bits = sum((c - 1).bit_length() for c in chunk_shape)
+    n_z = 1 << total_bits if total_bits > 0 else 1
+
+    # Decode all Morton codes in the ceiling hypercube, then filter to valid coords.
+    # This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33):
+    # n_z=262144, n_total=35937), consider the argsort strategy below.
+    order: npt.NDArray[np.intp]
+    if n_z <= 4 * n_total:
+        # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds.
+        # Works well when the overgeneration ratio n_z/n_total is small (≤4).
+        z_values = np.arange(n_z, dtype=np.intp)
+        all_coords = decode_morton_vectorized(z_values, chunk_shape)
+        shape_arr = np.array(chunk_shape, dtype=np.intp)
+        valid_mask = np.all(all_coords < shape_arr, axis=1)
+        order = all_coords[valid_mask]
     else:
-        n_hypercube = 0
+        # Argsort strategy: enumerate all n_total valid coordinates directly,
+        # encode each to a Morton code, then sort by code. Avoids the 8x or
+        # larger overgeneration penalty for near-miss shapes like (33,33,33).
+        # Cost: O(n_total * bits) encode + O(n_total log n_total) sort,
+        # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling.
+        grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij")
+        all_coords = np.stack([g.ravel() for g in grids], axis=1)
+
+        # Encode all coordinates to Morton codes (vectorized).
+        bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape)
+        max_coord_bits = max(bits_per_dim)
+        z_codes = np.zeros(n_total, dtype=np.intp)
+        output_bit = 0
+        for coord_bit in range(max_coord_bits):
+            for dim in range(n_dims):
+                if coord_bit < bits_per_dim[dim]:
+                    z_codes |= ((all_coords[:, dim] >> coord_bit) & 1) << output_bit
+                    output_bit += 1
+
+        sort_idx: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable")
+        order = np.asarray(all_coords[sort_idx], dtype=np.intp)

-    # Within the hypercube, no bounds checking needed - use vectorized decoding
-    if n_hypercube > 0:
-        z_values = np.arange(n_hypercube, dtype=np.intp)
-        order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape)
-    else:
-        order = np.empty((0, n_dims), dtype=np.intp)
-
-    # For remaining elements outside the hypercube, bounds checking is needed
-    remaining: list[tuple[int, ...]] = []
-    i = n_hypercube
-    while len(order) + len(remaining) < n_total:
-        m = decode_morton(i, chunk_shape)
-        if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
-            remaining.append(m)
-        i += 1
-
-    if remaining:
-        order = np.vstack([order, np.array(remaining, dtype=np.intp)])
     order.flags.writeable = False
     return order

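
To check that the two branches in the diff above yield the same ordering, the following standalone sketch re-implements both outside of zarr. All function names here are hypothetical stand-ins. In particular, `morton_decode_scalar` is a simple per-code loop written for this example and is not zarr's `decode_morton_vectorized`. It assumes a non-empty shape with every dimension at least 1.

```python
import numpy as np


def _bits_per_dim(shape: tuple[int, ...]) -> tuple[int, ...]:
    # Bits needed to index c values along each dimension (0 for singletons).
    return tuple((c - 1).bit_length() for c in shape)


def morton_encode(coords: np.ndarray, shape: tuple[int, ...]) -> np.ndarray:
    """Interleave coordinate bits into Morton codes (vectorized over rows)."""
    bits = _bits_per_dim(shape)
    codes = np.zeros(coords.shape[0], dtype=np.intp)
    output_bit = 0
    for coord_bit in range(max(bits)):
        for dim, nbits in enumerate(bits):
            if coord_bit < nbits:  # skip dimensions that have run out of bits
                codes |= ((coords[:, dim] >> coord_bit) & 1) << output_bit
                output_bit += 1
    return codes


def morton_decode_scalar(z: int, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Inverse of morton_encode for a single code (illustrative, not fast)."""
    bits = _bits_per_dim(shape)
    out = [0] * len(shape)
    input_bit = 0
    for coord_bit in range(max(bits)):
        for dim, nbits in enumerate(bits):
            if coord_bit < nbits:
                out[dim] |= ((z >> input_bit) & 1) << coord_bit
                input_bit += 1
    return tuple(out)


def morton_order_argsort(shape: tuple[int, ...]) -> np.ndarray:
    """Enumerate valid coordinates, encode them, and sort by Morton code."""
    grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in shape], indexing="ij")
    coords = np.stack([g.ravel() for g in grids], axis=1)
    return coords[np.argsort(morton_encode(coords, shape), kind="stable")]


def morton_order_ceiling(shape: tuple[int, ...]) -> np.ndarray:
    """Decode every code in the ceiling hypercube and keep in-bounds points."""
    n_z = 1 << sum(_bits_per_dim(shape))
    decoded = (morton_decode_scalar(z, shape) for z in range(n_z))
    valid = [c for c in decoded if all(x < s for x, s in zip(c, shape))]
    return np.array(valid, dtype=np.intp)


print(morton_order_argsort((2, 2)).tolist())
print(np.array_equal(morton_order_argsort((3, 5)), morton_order_ceiling((3, 5))))
```

Because the encoding is a bijection between coordinates in the ceiling hypercube and codes in `range(n_z)`, sorting valid coordinates by code and filtering decoded codes in ascending order visit the same points in the same sequence. The difference between the strategies is cost, not output.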
