@@ -1504,70 +1504,67 @@ def decode_morton_vectorized(
15041504
15051505
15061506@lru_cache (maxsize = 16 )
1507- def _morton_order (chunk_shape : tuple [int , ...]) -> tuple [ tuple [ int , ...], ... ]:
1507+ def _morton_order (chunk_shape : tuple [int , ...]) -> npt . NDArray [ np . intp ]:
15081508 n_total = product (chunk_shape )
1509- if n_total == 0 :
1510- return ()
1511-
1512- # Optimization: Remove singleton dimensions to enable magic number usage
1513- # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
1514- singleton_dims = tuple (i for i , s in enumerate (chunk_shape ) if s == 1 )
1515- if singleton_dims :
1516- squeezed_shape = tuple (s for s in chunk_shape if s != 1 )
1517- if squeezed_shape :
1518- # Compute Morton order on squeezed shape
1519- squeezed_order = _morton_order (squeezed_shape )
1520- # Expand coordinates to include singleton dimensions (always 0)
1521- expanded : list [tuple [int , ...]] = []
1522- for coord in squeezed_order :
1523- full_coord : list [int ] = []
1524- squeezed_idx = 0
1525- for i in range (len (chunk_shape )):
1526- if chunk_shape [i ] == 1 :
1527- full_coord .append (0 )
1528- else :
1529- full_coord .append (coord [squeezed_idx ])
1530- squeezed_idx += 1
1531- expanded .append (tuple (full_coord ))
1532- return tuple (expanded )
1533- else :
1534- # All dimensions are singletons, just return the single point
1535- return ((0 ,) * len (chunk_shape ),)
1536-
15371509 n_dims = len (chunk_shape )
1538-
1539- # Find the largest power-of-2 hypercube that fits within chunk_shape.
1540- # Within this hypercube, Morton codes are guaranteed to be in bounds.
1541- min_dim = min (chunk_shape )
1542- if min_dim >= 1 :
1543- power = min_dim .bit_length () - 1 # floor(log2(min_dim))
1544- hypercube_size = 1 << power # 2^power
1545- n_hypercube = hypercube_size ** n_dims
1546- else :
1547- n_hypercube = 0
1548-
1549- # Within the hypercube, no bounds checking needed - use vectorized decoding
1550- order : list [tuple [int , ...]]
1551- if n_hypercube > 0 :
1552- z_values = np .arange (n_hypercube , dtype = np .intp )
1553- hypercube_coords = decode_morton_vectorized (z_values , chunk_shape )
1554- order = [tuple (row ) for row in hypercube_coords ]
1510+ if n_total == 0 :
1511+ out = np .empty ((0 , n_dims ), dtype = np .intp )
1512+ out .flags .writeable = False
1513+ return out
1514+
1515+ # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span
1516+ # all valid coordinates in chunk_shape. (c-1).bit_length() gives the number
1517+ # of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits
1518+ # is the size of this hypercube.
1519+ total_bits = sum ((c - 1 ).bit_length () for c in chunk_shape )
1520+ n_z = 1 << total_bits if total_bits > 0 else 1
1521+
1522+ # Decode all Morton codes in the ceiling hypercube, then filter to valid coords.
1523+ # This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33):
1524+ # n_z=262144, n_total=35937), consider the argsort strategy below.
1525+ order : npt .NDArray [np .intp ]
1526+ if n_z <= 4 * n_total :
1527+ # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds.
1528+ # Works well when the overgeneration ratio n_z/n_total is small (≤4).
1529+ z_values = np .arange (n_z , dtype = np .intp )
1530+ all_coords = decode_morton_vectorized (z_values , chunk_shape )
1531+ shape_arr = np .array (chunk_shape , dtype = np .intp )
1532+ valid_mask = np .all (all_coords < shape_arr , axis = 1 )
1533+ order = all_coords [valid_mask ]
15551534 else :
1556- order = []
1535+ # Argsort strategy: enumerate all n_total valid coordinates directly,
1536+ # encode each to a Morton code, then sort by code. Avoids the 8x or
1537+ # larger overgeneration penalty for near-miss shapes like (33,33,33).
1538+ # Cost: O(n_total * bits) encode + O(n_total log n_total) sort,
1539+ # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling.
1540+ grids = np .meshgrid (* [np .arange (c , dtype = np .intp ) for c in chunk_shape ], indexing = "ij" )
1541+ all_coords = np .stack ([g .ravel () for g in grids ], axis = 1 )
1542+
1543+ # Encode all coordinates to Morton codes (vectorized).
1544+ bits_per_dim = tuple ((c - 1 ).bit_length () for c in chunk_shape )
1545+ max_coord_bits = max (bits_per_dim )
1546+ z_codes = np .zeros (n_total , dtype = np .intp )
1547+ output_bit = 0
1548+ for coord_bit in range (max_coord_bits ):
1549+ for dim in range (n_dims ):
1550+ if coord_bit < bits_per_dim [dim ]:
1551+ z_codes |= ((all_coords [:, dim ] >> coord_bit ) & 1 ) << output_bit
1552+ output_bit += 1
1553+
1554+ sort_idx : npt .NDArray [np .intp ] = np .argsort (z_codes , kind = "stable" )
1555+ order = np .asarray (all_coords [sort_idx ], dtype = np .intp )
1556+
1557+ order .flags .writeable = False
1558+ return order
15571559
1558- # For remaining elements, bounds checking is needed
1559- i = n_hypercube
1560- while len (order ) < n_total :
1561- m = decode_morton (i , chunk_shape )
1562- if all (x < y for x , y in zip (m , chunk_shape , strict = False )):
1563- order .append (m )
1564- i += 1
15651560
1566- return tuple (order )
1561+ @lru_cache (maxsize = 16 )
1562+ def _morton_order_keys (chunk_shape : tuple [int , ...]) -> tuple [tuple [int , ...], ...]:
1563+ return tuple (tuple (int (x ) for x in row ) for row in _morton_order (chunk_shape ))
15671564
15681565
15691566def morton_order_iter (chunk_shape : tuple [int , ...]) -> Iterator [tuple [int , ...]]:
1570- return iter (_morton_order (tuple (chunk_shape )))
1567+ return iter (_morton_order_keys (tuple (chunk_shape )))
15711568
15721569
15731570def c_order_iter (chunks_per_shard : tuple [int , ...]) -> Iterator [tuple [int , ...]]:
0 commit comments