@@ -1452,7 +1452,7 @@ def make_slice_selection(selection: Any) -> list[slice]:
14521452def decode_morton (z : int , chunk_shape : tuple [int , ...]) -> tuple [int , ...]:
14531453 # Inspired by compressed morton code as implemented in Neuroglancer
14541454 # https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
1455- bits = tuple (math . ceil ( math . log2 ( c ) ) for c in chunk_shape )
1455+ bits = tuple (( c - 1 ). bit_length ( ) for c in chunk_shape )
14561456 max_coords_bits = max (bits )
14571457 input_bit = 0
14581458 input_value = z
@@ -1467,16 +1467,102 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14671467 return tuple (out )
14681468
14691469
1470- @lru_cache
1470+ def decode_morton_vectorized (
1471+ z : npt .NDArray [np .intp ], chunk_shape : tuple [int , ...]
1472+ ) -> npt .NDArray [np .intp ]:
1473+ """Vectorized Morton code decoding for multiple z values.
1474+
1475+ Parameters
1476+ ----------
1477+ z : ndarray
1478+ 1D array of Morton codes to decode.
1479+ chunk_shape : tuple of int
1480+ Shape defining the coordinate space.
1481+
1482+ Returns
1483+ -------
1484+ ndarray
1485+ 2D array of shape (len(z), len(chunk_shape)) containing decoded coordinates.
1486+ """
1487+ n_dims = len (chunk_shape )
1488+ bits = tuple ((c - 1 ).bit_length () for c in chunk_shape )
1489+
1490+ max_coords_bits = max (bits ) if bits else 0
1491+ out = np .zeros ((len (z ), n_dims ), dtype = np .intp )
1492+
1493+ input_bit = 0
1494+ for coord_bit in range (max_coords_bits ):
1495+ for dim in range (n_dims ):
1496+ if coord_bit < bits [dim ]:
1497+ # Extract bit at position input_bit from all z values
1498+ bit_values = (z >> input_bit ) & 1
1499+ # Place bit at coord_bit position in dimension dim
1500+ out [:, dim ] |= bit_values << coord_bit
1501+ input_bit += 1
1502+
1503+ return out
1504+
1505+
1506+ @lru_cache (maxsize = 16 )
14711507def _morton_order (chunk_shape : tuple [int , ...]) -> tuple [tuple [int , ...], ...]:
14721508 n_total = product (chunk_shape )
1473- order : list [tuple [int , ...]] = []
1474- i = 0
1509+ if n_total == 0 :
1510+ return ()
1511+
1512+ # Optimization: Remove singleton dimensions to enable magic number usage
1513+ # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
1514+ singleton_dims = tuple (i for i , s in enumerate (chunk_shape ) if s == 1 )
1515+ if singleton_dims :
1516+ squeezed_shape = tuple (s for s in chunk_shape if s != 1 )
1517+ if squeezed_shape :
1518+ # Compute Morton order on squeezed shape
1519+ squeezed_order = _morton_order (squeezed_shape )
1520+ # Expand coordinates to include singleton dimensions (always 0)
1521+ expanded : list [tuple [int , ...]] = []
1522+ for coord in squeezed_order :
1523+ full_coord : list [int ] = []
1524+ squeezed_idx = 0
1525+ for i in range (len (chunk_shape )):
1526+ if chunk_shape [i ] == 1 :
1527+ full_coord .append (0 )
1528+ else :
1529+ full_coord .append (coord [squeezed_idx ])
1530+ squeezed_idx += 1
1531+ expanded .append (tuple (full_coord ))
1532+ return tuple (expanded )
1533+ else :
1534+ # All dimensions are singletons, just return the single point
1535+ return ((0 ,) * len (chunk_shape ),)
1536+
1537+ n_dims = len (chunk_shape )
1538+
1539+ # Find the largest power-of-2 hypercube that fits within chunk_shape.
1540+ # Within this hypercube, Morton codes are guaranteed to be in bounds.
1541+ min_dim = min (chunk_shape )
1542+ if min_dim >= 1 :
1543+ power = min_dim .bit_length () - 1 # floor(log2(min_dim))
1544+ hypercube_size = 1 << power # 2^power
1545+ n_hypercube = hypercube_size ** n_dims
1546+ else :
1547+ n_hypercube = 0
1548+
1549+ # Within the hypercube, no bounds checking needed - use vectorized decoding
1550+ order : list [tuple [int , ...]]
1551+ if n_hypercube > 0 :
1552+ z_values = np .arange (n_hypercube , dtype = np .intp )
1553+ hypercube_coords = decode_morton_vectorized (z_values , chunk_shape )
1554+ order = [tuple (row ) for row in hypercube_coords ]
1555+ else :
1556+ order = []
1557+
1558+ # For remaining elements, bounds checking is needed
1559+ i = n_hypercube
14751560 while len (order ) < n_total :
14761561 m = decode_morton (i , chunk_shape )
14771562 if all (x < y for x , y in zip (m , chunk_shape , strict = False )):
14781563 order .append (m )
14791564 i += 1
1565+
14801566 return tuple (order )
14811567
14821568
0 commit comments