@@ -1489,6 +1489,16 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14891489 np .uint64 (0x00000000001FFFFF ), # Compact stage 5
14901490)
14911491
1492+ # Magic numbers for 4D Morton decode: extract every 4th bit and compact them.
1493+ # Bits for dimension d are at positions d, d+4, d+8, d+12, ...
1494+ _MORTON_4D_MASKS : tuple [np .uint64 , ...] = (
1495+ np .uint64 (0x1111111111111111 ), # Extract mask: bits 0, 4, 8, 12, ...
1496+ np .uint64 (0x0303030303030303 ), # Compact stage 1
1497+ np .uint64 (0x000F000F000F000F ), # Compact stage 2
1498+ np .uint64 (0x000000FF000000FF ), # Compact stage 3
1499+ np .uint64 (0x000000000000FFFF ), # Compact stage 4
1500+ )
1501+
14921502
14931503def _decode_morton_2d (z : npt .NDArray [np .intp ]) -> npt .NDArray [np .intp ]:
14941504 """Decode 2D Morton codes using magic number bit manipulation.
@@ -1539,6 +1549,27 @@ def _decode_morton_3d(z: npt.NDArray[np.intp]) -> npt.NDArray[np.intp]:
15391549 return out
15401550
15411551
1552+ def _decode_morton_4d (z : npt .NDArray [np .intp ]) -> npt .NDArray [np .intp ]:
1553+ """Decode 4D Morton codes using magic number bit manipulation.
1554+
1555+ This extracts interleaved x,y,z,w coordinates from Morton codes using
1556+ parallel bit operations instead of bit-by-bit loops.
1557+ """
1558+ # Convert to uint64 for bitwise operations with large masks
1559+ z_u64 = z .astype (np .uint64 )
1560+ out = np .zeros ((len (z ), 4 ), dtype = np .intp )
1561+
1562+ for dim in range (4 ):
1563+ x = (z_u64 >> dim ) & _MORTON_4D_MASKS [0 ]
1564+ x = (x ^ (x >> 3 )) & _MORTON_4D_MASKS [1 ]
1565+ x = (x ^ (x >> 6 )) & _MORTON_4D_MASKS [2 ]
1566+ x = (x ^ (x >> 12 )) & _MORTON_4D_MASKS [3 ]
1567+ x = (x ^ (x >> 24 )) & _MORTON_4D_MASKS [4 ]
1568+ out [:, dim ] = x
1569+
1570+ return out
1571+
1572+
15421573def decode_morton_vectorized (
15431574 z : npt .NDArray [np .intp ], chunk_shape : tuple [int , ...]
15441575) -> npt .NDArray [np .intp ]:
@@ -1559,14 +1590,16 @@ def decode_morton_vectorized(
15591590 n_dims = len (chunk_shape )
15601591 bits = tuple ((c - 1 ).bit_length () for c in chunk_shape )
15611592
1562- # Use magic number optimization for 2D/3D with uniform bit widths.
1563- # Magic numbers have bit limits: 2D supports up to 32 bits , 3D up to 21 bits per dimension.
1593+ # Use magic number optimization for 2D/3D/4D with uniform bit widths.
1594+ # Magic numbers have bit limits: 2D up to 32, 3D up to 21, 4D up to 16 bits per dimension.
15641595 if len (set (bits )) == 1 : # All dimensions have same bit width
15651596 max_bits = bits [0 ] if bits else 0
15661597 if n_dims == 2 and max_bits <= 32 :
15671598 return _decode_morton_2d (z )
15681599 if n_dims == 3 and max_bits <= 21 :
15691600 return _decode_morton_3d (z )
1601+ if n_dims == 4 and max_bits <= 16 :
1602+ return _decode_morton_4d (z )
15701603
15711604 # Fall back to generic bit-by-bit decoding
15721605 max_coords_bits = max (bits ) if bits else 0
0 commit comments