@@ -1609,54 +1609,48 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
16091609 out .flags .writeable = False
16101610 return out
16111611
1612- # Optimization: Remove singleton dimensions to enable magic number usage
1613- # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
1614- singleton_dims = tuple (i for i , s in enumerate (chunk_shape ) if s == 1 )
1615- if singleton_dims :
1616- squeezed_shape = tuple (s for s in chunk_shape if s != 1 )
1617- if squeezed_shape :
1618- # Compute Morton order on squeezed shape, then expand singleton dims (always 0)
1619- squeezed_order = np .asarray (_morton_order (squeezed_shape ))
1620- out = np .zeros ((n_total , n_dims ), dtype = np .intp )
1621- squeezed_col = 0
1622- for full_col in range (n_dims ):
1623- if chunk_shape [full_col ] != 1 :
1624- out [:, full_col ] = squeezed_order [:, squeezed_col ]
1625- squeezed_col += 1
1626- else :
1627- # All dimensions are singletons, just return the single point
1628- out = np .zeros ((1 , n_dims ), dtype = np .intp )
1629- out .flags .writeable = False
1630- return out
1631-
1632- # Find the largest power-of-2 hypercube that fits within chunk_shape.
1633- # Within this hypercube, Morton codes are guaranteed to be in bounds.
1634- min_dim = min (chunk_shape )
1635- if min_dim >= 1 :
1636- power = min_dim .bit_length () - 1 # floor(log2(min_dim))
1637- hypercube_size = 1 << power # 2^power
1638- n_hypercube = hypercube_size ** n_dims
1612+ # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span
1613+ # all valid coordinates in chunk_shape. (c-1).bit_length() gives the number
1614+ # of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits
1615+ # is the size of this hypercube.
1616+ total_bits = sum ((c - 1 ).bit_length () for c in chunk_shape )
1617+ n_z = 1 << total_bits if total_bits > 0 else 1
1618+
1619+ # Decode all Morton codes in the ceiling hypercube, then filter to valid coords.
1620+ # This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33):
1621+ # n_z=262144, n_total=35937), consider the argsort strategy below.
1622+ order : npt .NDArray [np .intp ]
1623+ if n_z <= 4 * n_total :
1624+ # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds.
1625+ # Works well when the overgeneration ratio n_z/n_total is small (≤4).
1626+ z_values = np .arange (n_z , dtype = np .intp )
1627+ all_coords = decode_morton_vectorized (z_values , chunk_shape )
1628+ shape_arr = np .array (chunk_shape , dtype = np .intp )
1629+ valid_mask = np .all (all_coords < shape_arr , axis = 1 )
1630+ order = all_coords [valid_mask ]
16391631 else :
1640- n_hypercube = 0
1632+ # Argsort strategy: enumerate all n_total valid coordinates directly,
1633+ # encode each to a Morton code, then sort by code. Avoids the 8x or
1634+ # larger overgeneration penalty for near-miss shapes like (33,33,33).
1635+ # Cost: O(n_total * bits) encode + O(n_total log n_total) sort,
1636+ # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling.
1637+ grids = np .meshgrid (* [np .arange (c , dtype = np .intp ) for c in chunk_shape ], indexing = "ij" )
1638+ all_coords = np .stack ([g .ravel () for g in grids ], axis = 1 )
1639+
1640+ # Encode all coordinates to Morton codes (vectorized).
1641+ bits_per_dim = tuple ((c - 1 ).bit_length () for c in chunk_shape )
1642+ max_coord_bits = max (bits_per_dim )
1643+ z_codes = np .zeros (n_total , dtype = np .intp )
1644+ output_bit = 0
1645+ for coord_bit in range (max_coord_bits ):
1646+ for dim in range (n_dims ):
1647+ if coord_bit < bits_per_dim [dim ]:
1648+ z_codes |= ((all_coords [:, dim ] >> coord_bit ) & 1 ) << output_bit
1649+ output_bit += 1
1650+
1651+ sort_idx : npt .NDArray [np .intp ] = np .argsort (z_codes , kind = "stable" )
1652+ order = np .asarray (all_coords [sort_idx ], dtype = np .intp )
16411653
1642- # Within the hypercube, no bounds checking needed - use vectorized decoding
1643- if n_hypercube > 0 :
1644- z_values = np .arange (n_hypercube , dtype = np .intp )
1645- order : npt .NDArray [np .intp ] = decode_morton_vectorized (z_values , chunk_shape )
1646- else :
1647- order = np .empty ((0 , n_dims ), dtype = np .intp )
1648-
1649- # For remaining elements outside the hypercube, bounds checking is needed
1650- remaining : list [tuple [int , ...]] = []
1651- i = n_hypercube
1652- while len (order ) + len (remaining ) < n_total :
1653- m = decode_morton (i , chunk_shape )
1654- if all (x < y for x , y in zip (m , chunk_shape , strict = False )):
1655- remaining .append (m )
1656- i += 1
1657-
1658- if remaining :
1659- order = np .vstack ([order , np .array (remaining , dtype = np .intp )])
16601654 order .flags .writeable = False
16611655 return order
16621656
0 commit comments