@@ -1452,7 +1452,7 @@ def make_slice_selection(selection: Any) -> list[slice]:
14521452def decode_morton (z : int , chunk_shape : tuple [int , ...]) -> tuple [int , ...]:
14531453 # Inspired by compressed morton code as implemented in Neuroglancer
14541454 # https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
1455- bits = tuple (math . ceil ( math . log2 ( c ) ) for c in chunk_shape )
1455+ bits = tuple (( c - 1 ). bit_length ( ) for c in chunk_shape )
14561456 max_coords_bits = max (bits )
14571457 input_bit = 0
14581458 input_value = z
@@ -1467,21 +1467,110 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14671467 return tuple (out )
14681468
14691469
1470- @lru_cache
1471- def _morton_order (chunk_shape : tuple [int , ...]) -> tuple [tuple [int , ...], ...]:
def decode_morton_vectorized(
    z: npt.NDArray[np.intp], chunk_shape: tuple[int, ...]
) -> npt.NDArray[np.intp]:
    """Decode many Morton (Z-order) codes into coordinates in one pass.

    Each dimension ``d`` is allotted ``(chunk_shape[d] - 1).bit_length()``
    bits. The bits of every code are consumed from least significant upwards
    and dealt out round-robin over the dimensions, skipping any dimension
    whose bit budget is already exhausted.

    Parameters
    ----------
    z : ndarray
        1D array (or other integer array-like) of Morton codes to decode.
        The input is converted to ``np.intp`` before any bit manipulation.
    chunk_shape : tuple of int
        Shape defining the coordinate space.

    Returns
    -------
    ndarray
        2D ``np.intp`` array of shape ``(len(z), len(chunk_shape))``
        containing the decoded coordinates, one row per input code.
        Rows are not checked against ``chunk_shape``.
    """
    # Normalise the input once: lists and unsigned arrays would otherwise
    # fail on the shift / in-place OR against the intp output buffer.
    codes = np.asarray(z, dtype=np.intp)

    n_dims = len(chunk_shape)
    bits = tuple((c - 1).bit_length() for c in chunk_shape)
    max_coords_bits = max(bits, default=0)

    out = np.zeros((len(codes), n_dims), dtype=np.intp)

    input_bit = 0
    for coord_bit in range(max_coords_bits):
        for dim, dim_bits in enumerate(bits):
            if coord_bit >= dim_bits:
                # This dimension has no more bits; it consumes no input bit.
                continue
            # Move bit `input_bit` of every code to position `coord_bit`
            # of the coordinate for dimension `dim`.
            out[:, dim] |= ((codes >> input_bit) & 1) << coord_bit
            input_bit += 1

    return out
1504+
1505+
@lru_cache(maxsize=16)
def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
    """Return every coordinate of ``chunk_shape`` in Morton (Z-order) order.

    Parameters
    ----------
    chunk_shape : tuple of int
        Extent of each dimension. Extents are assumed to be non-negative.

    Returns
    -------
    ndarray
        Read-only ``np.intp`` array of shape
        ``(product(chunk_shape), len(chunk_shape))``; row ``k`` is the
        ``k``-th in-bounds coordinate in ascending Morton-code order.
        The array is shared through the LRU cache, hence read-only.
    """
    n_total = product(chunk_shape)
    n_dims = len(chunk_shape)
    if n_total == 0:
        empty = np.empty((0, n_dims), dtype=np.intp)
        empty.flags.writeable = False
        return empty

    # Singleton dimensions receive no Morton bits (their coordinate is always
    # 0), so compute the order on the squeezed shape and re-insert zero
    # columns. This lets shapes such as (1, 1, 32, 32, 32) use the fast
    # hypercube path below.
    kept_dims = [dim for dim, size in enumerate(chunk_shape) if size != 1]
    if not kept_dims:
        # Every dimension is a singleton, or there are no dimensions at all:
        # a single all-zero coordinate. This guard also keeps an empty
        # chunk_shape from reaching min() on an empty sequence.
        single = np.zeros((1, n_dims), dtype=np.intp)
        single.flags.writeable = False
        return single
    if len(kept_dims) < n_dims:
        squeezed_order = _morton_order(tuple(chunk_shape[dim] for dim in kept_dims))
        expanded = np.zeros((n_total, n_dims), dtype=np.intp)
        expanded[:, kept_dims] = squeezed_order
        expanded.flags.writeable = False
        return expanded

    # From here on every extent is >= 2 (zero-size and singleton shapes
    # returned above). Find the largest power-of-2 hypercube that fits within
    # chunk_shape: each dimension has at least `power` bits, so all Morton
    # codes below hypercube_size ** n_dims decode to in-bounds coordinates
    # and need no bounds check.
    power = min(chunk_shape).bit_length() - 1  # floor(log2(min extent))
    n_hypercube = (1 << power) ** n_dims
    order: npt.NDArray[np.intp] = decode_morton_vectorized(
        np.arange(n_hypercube, dtype=np.intp), chunk_shape
    )

    # Codes beyond the hypercube may decode outside chunk_shape, so walk them
    # one at a time and keep only in-bounds coordinates until all are found.
    n_missing = n_total - n_hypercube
    remaining: list[tuple[int, ...]] = []
    i = n_hypercube
    while len(remaining) < n_missing:
        m = decode_morton(i, chunk_shape)
        if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
            remaining.append(m)
        i += 1

    if remaining:
        order = np.vstack([order, np.array(remaining, dtype=np.intp)])
    order.flags.writeable = False
    return order
1565+
1566+
@lru_cache(maxsize=16)
def _morton_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
    """Return the Morton order of ``chunk_shape`` as a tuple of int tuples.

    This is the coordinate array produced by ``_morton_order`` with each row
    converted to a tuple of plain Python ints. The result is cached per
    ``chunk_shape``.
    """
    coordinate_rows = _morton_order(chunk_shape).tolist()
    return tuple(map(tuple, coordinate_rows))
14811570
14821571
def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
    """Iterate over the coordinates of ``chunk_shape`` in Morton order.

    ``chunk_shape`` is converted to a tuple before the cached lookup, so any
    sequence of ints can be passed.
    """
    hashable_shape = tuple(chunk_shape)
    cached_coords = _morton_order_keys(hashable_shape)
    return iter(cached_coords)
14851574
14861575
14871576def c_order_iter (chunks_per_shard : tuple [int , ...]) -> Iterator [tuple [int , ...]]:
0 commit comments