@@ -1467,6 +1467,78 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14671467 return tuple (out )
14681468
14691469
1470+ # Magic numbers for 2D Morton decode: extract every 2nd bit and compact them.
1471+ # Bits for dimension d are at positions d, d+2, d+4, ...
1472+ _MORTON_2D_MASKS : tuple [int , ...] = (
1473+ 0x5555555555555555 , # Extract mask: bits 0, 2, 4, ...
1474+ 0x3333333333333333 , # Compact stage 1
1475+ 0x0F0F0F0F0F0F0F0F , # Compact stage 2
1476+ 0x00FF00FF00FF00FF , # Compact stage 3
1477+ 0x0000FFFF0000FFFF , # Compact stage 4
1478+ )
1479+
1480+ # Magic numbers for 3D Morton decode: extract every 3rd bit and compact them.
1481+ # Bits for dimension d are at positions d, d+3, d+6, d+9, ...
1482+ # Uses uint64 to handle large mask values correctly with numpy.
1483+ _MORTON_3D_MASKS : tuple [np .uint64 , ...] = (
1484+ np .uint64 (0x1249249249249249 ), # Extract mask: bits 0, 3, 6, 9, ...
1485+ np .uint64 (0x30C30C30C30C30C3 ), # Compact stage 1
1486+ np .uint64 (0xF00F00F00F00F00F ), # Compact stage 2
1487+ np .uint64 (0x00FF0000FF0000FF ), # Compact stage 3
1488+ np .uint64 (0x00FF00000000FFFF ), # Compact stage 4
1489+ np .uint64 (0x00000000001FFFFF ), # Compact stage 5
1490+ )
1491+
1492+
1493+ def _decode_morton_2d (z : npt .NDArray [np .intp ]) -> npt .NDArray [np .intp ]:
1494+ """Decode 2D Morton codes using magic number bit manipulation.
1495+
1496+ This extracts interleaved x,y coordinates from Morton codes using
1497+ parallel bit operations instead of bit-by-bit loops.
1498+ """
1499+ out = np .zeros ((len (z ), 2 ), dtype = np .intp )
1500+
1501+ # Extract x (bits 0, 2, 4, ...) and compact to (bits 0, 1, 2, ...)
1502+ x = z & _MORTON_2D_MASKS [0 ]
1503+ x = (x | (x >> 1 )) & _MORTON_2D_MASKS [1 ]
1504+ x = (x | (x >> 2 )) & _MORTON_2D_MASKS [2 ]
1505+ x = (x | (x >> 4 )) & _MORTON_2D_MASKS [3 ]
1506+ x = (x | (x >> 8 )) & _MORTON_2D_MASKS [4 ]
1507+ out [:, 0 ] = x
1508+
1509+ # Extract y (bits 1, 3, 5, ...) and compact
1510+ y = (z >> 1 ) & _MORTON_2D_MASKS [0 ]
1511+ y = (y | (y >> 1 )) & _MORTON_2D_MASKS [1 ]
1512+ y = (y | (y >> 2 )) & _MORTON_2D_MASKS [2 ]
1513+ y = (y | (y >> 4 )) & _MORTON_2D_MASKS [3 ]
1514+ y = (y | (y >> 8 )) & _MORTON_2D_MASKS [4 ]
1515+ out [:, 1 ] = y
1516+
1517+ return out
1518+
1519+
1520+ def _decode_morton_3d (z : npt .NDArray [np .intp ]) -> npt .NDArray [np .intp ]:
1521+ """Decode 3D Morton codes using magic number bit manipulation.
1522+
1523+ This extracts interleaved x,y,z coordinates from Morton codes using
1524+ parallel bit operations instead of bit-by-bit loops.
1525+ """
1526+ # Convert to uint64 for bitwise operations with large masks
1527+ z_u64 = z .astype (np .uint64 )
1528+ out = np .zeros ((len (z ), 3 ), dtype = np .intp )
1529+
1530+ for dim in range (3 ):
1531+ x = (z_u64 >> dim ) & _MORTON_3D_MASKS [0 ]
1532+ x = (x ^ (x >> 2 )) & _MORTON_3D_MASKS [1 ]
1533+ x = (x ^ (x >> 4 )) & _MORTON_3D_MASKS [2 ]
1534+ x = (x ^ (x >> 8 )) & _MORTON_3D_MASKS [3 ]
1535+ x = (x ^ (x >> 16 )) & _MORTON_3D_MASKS [4 ]
1536+ x = (x ^ (x >> 32 )) & _MORTON_3D_MASKS [5 ]
1537+ out [:, dim ] = x
1538+
1539+ return out
1540+
1541+
14701542def decode_morton_vectorized (
14711543 z : npt .NDArray [np .intp ], chunk_shape : tuple [int , ...]
14721544) -> npt .NDArray [np .intp ]:
@@ -1486,9 +1558,18 @@ def decode_morton_vectorized(
14861558 """
14871559 n_dims = len (chunk_shape )
14881560 bits = tuple ((c - 1 ).bit_length () for c in chunk_shape )
1489- max_coords_bits = max (bits ) if bits else 0
14901561
1491- # Output array: each row is a decoded coordinate
1562+ # Use magic number optimization for 2D/3D with uniform bit widths.
1563+ # Magic numbers have bit limits: 2D supports up to 32 bits, 3D up to 21 bits per dimension.
1564+ if len (set (bits )) == 1 : # All dimensions have same bit width
1565+ max_bits = bits [0 ] if bits else 0
1566+ if n_dims == 2 and max_bits <= 32 :
1567+ return _decode_morton_2d (z )
1568+ if n_dims == 3 and max_bits <= 21 :
1569+ return _decode_morton_3d (z )
1570+
1571+ # Fall back to generic bit-by-bit decoding
1572+ max_coords_bits = max (bits ) if bits else 0
14921573 out = np .zeros ((len (z ), n_dims ), dtype = np .intp )
14931574
14941575 input_bit = 0
0 commit comments