|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from collections.abc import Iterable, Mapping, MutableMapping |
| 3 | +from collections.abc import Iterable, Mapping, MutableMapping, Sequence |
4 | 4 | from dataclasses import dataclass, replace |
5 | 5 | from enum import Enum |
6 | 6 | from functools import lru_cache |
|
45 | 45 | from zarr.core.dtype.npy.int import UInt64 |
46 | 46 | from zarr.core.indexing import ( |
47 | 47 | BasicIndexer, |
| 48 | + ChunkProjection, |
48 | 49 | SelectorTuple, |
49 | 50 | _morton_order, |
50 | 51 | _morton_order_keys, |
@@ -574,21 +575,26 @@ async def _encode_partial_single( |
574 | 575 | chunks_per_shard = self._get_chunks_per_shard(shard_spec) |
575 | 576 | chunk_spec = self._get_chunk_spec(shard_spec) |
576 | 577 |
|
577 | | - shard_reader = await self._load_full_shard_maybe( |
578 | | - byte_getter=byte_setter, |
579 | | - prototype=chunk_spec.prototype, |
580 | | - chunks_per_shard=chunks_per_shard, |
581 | | - ) |
582 | | - shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) |
583 | | - # Use vectorized lookup for better performance |
584 | | - shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard))) |
585 | | - |
586 | 578 | indexer = list( |
587 | 579 | get_indexer( |
588 | 580 | selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape) |
589 | 581 | ) |
590 | 582 | ) |
591 | 583 |
|
| 584 | + if self._is_complete_shard_write(indexer, chunks_per_shard): |
| 585 | + shard_dict = dict.fromkeys(morton_order_iter(chunks_per_shard)) |
| 586 | + else: |
| 587 | + shard_reader = await self._load_full_shard_maybe( |
| 588 | + byte_getter=byte_setter, |
| 589 | + prototype=chunk_spec.prototype, |
| 590 | + chunks_per_shard=chunks_per_shard, |
| 591 | + ) |
| 592 | + shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) |
| 593 | + # Use vectorized lookup for better performance |
| 594 | + shard_dict = shard_reader.to_dict_vectorized( |
| 595 | + np.asarray(_morton_order(chunks_per_shard)) |
| 596 | + ) |
| 597 | + |
592 | 598 | await self.codec_pipeline.write( |
593 | 599 | [ |
594 | 600 | ( |
@@ -661,6 +667,16 @@ def _is_total_shard( |
661 | 667 | chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard) |
662 | 668 | ) |
663 | 669 |
|
def _is_complete_shard_write(
    self,
    indexed_chunks: Sequence[ChunkProjection],
    chunks_per_shard: tuple[int, ...],
) -> bool:
    """Tell whether a write replaces every chunk of the shard in full.

    The result is truthy only when both conditions hold:

    * the set of chunk coordinates found in ``indexed_chunks`` satisfies
      ``_is_total_shard`` for ``chunks_per_shard``, and
    * the trailing completeness flag of every projection is truthy.

    Parameters
    ----------
    indexed_chunks
        Chunk projections produced by the indexer for this write. The first
        field of each is read as the chunk coordinates and the last field
        as its completeness flag.
    chunks_per_shard
        Number of chunks along each dimension of the shard.
    """
    # Gather the coordinates of every chunk the selection touches.
    touched_coords = set()
    for coords, *_ in indexed_chunks:
        touched_coords.add(coords)

    covers_every_chunk = self._is_total_shard(touched_coords, chunks_per_shard)
    # `and` short-circuits: the per-chunk flags are scanned only when the
    # coverage check has already passed.
    return covers_every_chunk and all(whole for *_, whole in indexed_chunks)
| 679 | + |
664 | 680 | async def _decode_shard_index( |
665 | 681 | self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...] |
666 | 682 | ) -> _ShardIndex: |
|
0 commit comments