def decode_morton_vectorized(
    z: npt.NDArray[np.intp], chunk_shape: tuple[int, ...]
) -> npt.NDArray[np.intp]:
    """Decode many compressed Morton (Z-order) codes in a single pass.

    Each dimension ``d`` contributes ``(chunk_shape[d] - 1).bit_length()``
    bits to a code. Bits are consumed from the least significant end of the
    code, cycling over the dimensions; a dimension stops taking bits once
    all of its own bits have been filled.

    Parameters
    ----------
    z : ndarray
        1D array (or array-like) of Morton codes to decode.
    chunk_shape : tuple of int
        Shape defining the coordinate space.

    Returns
    -------
    ndarray
        2D array of shape ``(len(z), len(chunk_shape))`` holding the decoded
        coordinates, one row per input code. Decoded coordinates are NOT
        checked against ``chunk_shape``; callers must filter out-of-bounds
        rows themselves.
    """
    # Accept any integer array-like; a no-op for arrays that are already intp.
    z = np.asarray(z, dtype=np.intp)
    n_dims = len(chunk_shape)
    bits = tuple((c - 1).bit_length() for c in chunk_shape)

    # default=0 keeps an empty shape valid: the result is simply (len(z), 0).
    max_coords_bits = max(bits, default=0)
    out = np.zeros((len(z), n_dims), dtype=np.intp)

    input_bit = 0
    for coord_bit in range(max_coords_bits):
        for dim, dim_bits in enumerate(bits):
            if coord_bit < dim_bits:
                # Move bit `input_bit` of every code to bit `coord_bit`
                # of the coordinate along `dim`.
                out[:, dim] |= ((z >> input_bit) & 1) << coord_bit
                input_bit += 1

    return out


@lru_cache(maxsize=16)
def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
    """Return every coordinate inside ``chunk_shape`` in Morton (Z-order).

    Parameters
    ----------
    chunk_shape : tuple of int
        Extent of the grid along each dimension.

    Returns
    -------
    tuple of tuple of int
        All in-bounds coordinates, sorted by ascending Morton code. Every
        element is a plain Python ``int``. The result is empty when the shape
        has a zero-length dimension and is ``((),)`` for an empty shape.
    """
    n_dims = len(chunk_shape)
    if n_dims == 0:
        # A 0-d grid has exactly one position, the empty coordinate. Without
        # this guard, min() below would raise ValueError on the empty shape.
        return ((),)

    n_total = product(chunk_shape)
    if n_total == 0:
        return ()

    # Length-1 dimensions contribute zero bits to a Morton code, so they can
    # be dropped, the order computed on the squeezed shape, and a constant 0
    # re-inserted afterwards without changing the ordering.
    kept_axes = tuple(axis for axis, size in enumerate(chunk_shape) if size != 1)
    if len(kept_axes) < n_dims:
        if not kept_axes:
            # All dimensions are singletons: the grid is a single point.
            return ((0,) * n_dims,)
        squeezed_order = _morton_order(tuple(chunk_shape[axis] for axis in kept_axes))
        expanded: list[tuple[int, ...]] = []
        for coord in squeezed_order:
            full_coord = [0] * n_dims
            for axis, value in zip(kept_axes, coord, strict=True):
                full_coord[axis] = value
            expanded.append(tuple(full_coord))
        return tuple(expanded)

    # Largest power-of-two hypercube that fits inside chunk_shape. Every
    # dimension owns at least `power` bits, so the lowest power * n_dims bits
    # of a code interleave across all dimensions. Codes below
    # hypercube_size**n_dims therefore decode to coordinates smaller than
    # hypercube_size, which are in bounds without any check.
    min_dim = min(chunk_shape)
    if min_dim >= 1:
        power = min_dim.bit_length() - 1  # floor(log2(min_dim))
        hypercube_size = 1 << power
        n_hypercube = hypercube_size**n_dims
    else:
        n_hypercube = 0

    order: list[tuple[int, ...]] = []
    if n_hypercube > 0:
        z_values = np.arange(n_hypercube, dtype=np.intp)
        hypercube_coords = decode_morton_vectorized(z_values, chunk_shape)
        # tolist() yields Python ints; iterating the ndarray directly would
        # put np.intp scalars into the cached tuples, mixing them with the
        # plain ints appended by the scalar loop below.
        order = [tuple(row) for row in hypercube_coords.tolist()]

    # Codes past the hypercube may decode outside chunk_shape, so each one
    # is bounds-checked until every in-bounds coordinate has been collected.
    i = n_hypercube
    while len(order) < n_total:
        m = decode_morton(i, chunk_shape)
        if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
            order.append(m)
        i += 1

    return tuple(order)