Skip to content

Commit 82a7e30

Browse files
committed
Reorder methods in sharding.py, add docstring + commenting
1 parent 6133e78 commit 82a7e30

1 file changed

Lines changed: 134 additions & 104 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 134 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ async def from_bytes(
198198

199199
@classmethod
200200
def create_empty(
201-
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
201+
cls,
202+
chunks_per_shard: ChunkCoords,
203+
buffer_prototype: BufferPrototype | None = None,
202204
) -> _ShardReader:
203205
if buffer_prototype is None:
204206
buffer_prototype = default_buffer_prototype()
@@ -248,7 +250,9 @@ def merge_with_morton_order(
248250

249251
@classmethod
250252
def create_empty(
251-
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
253+
cls,
254+
chunks_per_shard: ChunkCoords,
255+
buffer_prototype: BufferPrototype | None = None,
252256
) -> _ShardBuilder:
253257
if buffer_prototype is None:
254258
buffer_prototype = default_buffer_prototype()
@@ -330,13 +334,17 @@ async def finalize(
330334

331335

332336
class _ChunkCoordsByteSlice(NamedTuple):
337+
"""Holds a chunk's coordinates and it's byte range in a serialized shard."""
338+
333339
coords: ChunkCoords
334340
byte_slice: slice
335341

336342

337343
@dataclass(frozen=True)
338344
class ShardingCodec(
339-
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
345+
ArrayBytesCodec,
346+
ArrayBytesCodecPartialDecodeMixin,
347+
ArrayBytesCodecPartialEncodeMixin,
340348
):
341349
chunk_shape: ChunkCoords
342350
codecs: tuple[Codec, ...]
@@ -446,7 +454,10 @@ async def _decode_single(
446454

447455
# setup output array
448456
out = chunk_spec.prototype.nd_buffer.create(
449-
shape=shard_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
457+
shape=shard_shape,
458+
dtype=shard_spec.dtype,
459+
order=shard_spec.order,
460+
fill_value=0,
450461
)
451462
shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)
452463

@@ -490,7 +501,10 @@ async def _decode_partial_single(
490501

491502
# setup output array
492503
out = shard_spec.prototype.nd_buffer.create(
493-
shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
504+
shape=indexer.shape,
505+
dtype=shard_spec.dtype,
506+
order=shard_spec.order,
507+
fill_value=0,
494508
)
495509

496510
indexed_chunks = list(indexer)
@@ -533,101 +547,6 @@ async def _decode_partial_single(
533547
else:
534548
return out
535549

536-
async def _load_partial_shard_maybe(
537-
self,
538-
byte_getter: ByteGetter,
539-
prototype: BufferPrototype,
540-
chunks_per_shard: ChunkCoords,
541-
all_chunk_coords: set[ChunkCoords],
542-
) -> ShardMapping | None:
543-
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
544-
if shard_index is None:
545-
return None
546-
547-
chunks = [
548-
_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
549-
for chunk_coords in all_chunk_coords
550-
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
551-
]
552-
if len(chunks) == 0:
553-
return {}
554-
555-
groups = self._coalesce_chunks(chunks)
556-
557-
shard_dicts = await concurrent_map(
558-
[(group, byte_getter, prototype) for group in groups],
559-
self._get_group_bytes,
560-
config.get("async.concurrency"),
561-
)
562-
563-
shard_dict: ShardMutableMapping = {}
564-
for d in shard_dicts:
565-
shard_dict.update(d)
566-
567-
return shard_dict
568-
569-
def _coalesce_chunks(
570-
self,
571-
chunks: list[_ChunkCoordsByteSlice],
572-
) -> list[list[_ChunkCoordsByteSlice]]:
573-
"""
574-
Combine chunks from a single shard into groups that should be read together
575-
in a single request.
576-
577-
Respects the following configuration options:
578-
- `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
579-
chunks to coalesce into a single group.
580-
- `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
581-
"""
582-
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
583-
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")
584-
585-
sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
586-
587-
groups = []
588-
current_group = [sorted_chunks[0]]
589-
590-
for chunk in sorted_chunks[1:]:
591-
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
592-
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
593-
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
594-
current_group.append(chunk)
595-
else:
596-
groups.append(current_group)
597-
current_group = [chunk]
598-
599-
groups.append(current_group)
600-
601-
return groups
602-
603-
async def _get_group_bytes(
604-
self,
605-
group: list[_ChunkCoordsByteSlice],
606-
byte_getter: ByteGetter,
607-
prototype: BufferPrototype,
608-
) -> ShardMapping:
609-
group_start = group[0].byte_slice.start
610-
group_end = group[-1].byte_slice.stop
611-
612-
# A single call to retrieve the bytes for the entire group.
613-
group_bytes = await byte_getter.get(
614-
prototype=prototype,
615-
byte_range=RangeByteRequest(group_start, group_end),
616-
)
617-
if group_bytes is None:
618-
return {}
619-
620-
# Extract the bytes corresponding to each chunk in group from group_bytes.
621-
shard_dict = {}
622-
for chunk in group:
623-
chunk_slice = slice(
624-
chunk.byte_slice.start - group_start,
625-
chunk.byte_slice.stop - group_start,
626-
)
627-
shard_dict[chunk.coords] = group_bytes[chunk_slice]
628-
629-
return shard_dict
630-
631550
async def _encode_single(
632551
self,
633552
shard_array: NDBuffer,
@@ -688,7 +607,9 @@ async def _encode_partial_single(
688607

689608
indexer = list(
690609
get_indexer(
691-
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
610+
selection,
611+
shape=shard_shape,
612+
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
692613
)
693614
)
694615

@@ -762,7 +683,8 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
762683
get_pipeline_class()
763684
.from_codecs(self.index_codecs)
764685
.compute_encoded_size(
765-
16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard)
686+
16 * product(chunks_per_shard),
687+
self._get_index_chunk_spec(chunks_per_shard),
766688
)
767689
)
768690

@@ -807,7 +729,8 @@ async def _load_shard_index_maybe(
807729
)
808730
else:
809731
index_bytes = await byte_getter.get(
810-
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
732+
prototype=numpy_buffer_prototype(),
733+
byte_range=SuffixByteRequest(shard_index_size),
811734
)
812735
if index_bytes is not None:
813736
return await self._decode_shard_index(index_bytes, chunks_per_shard)
@@ -821,7 +744,10 @@ async def _load_shard_index(
821744
) or _ShardIndex.create_empty(chunks_per_shard)
822745

823746
async def _load_full_shard_maybe(
824-
self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords
747+
self,
748+
byte_getter: ByteGetter,
749+
prototype: BufferPrototype,
750+
chunks_per_shard: ChunkCoords,
825751
) -> _ShardReader | None:
826752
shard_bytes = await byte_getter.get(prototype=prototype)
827753

@@ -831,6 +757,110 @@ async def _load_full_shard_maybe(
831757
else None
832758
)
833759

760+
async def _load_partial_shard_maybe(
761+
self,
762+
byte_getter: ByteGetter,
763+
prototype: BufferPrototype,
764+
chunks_per_shard: ChunkCoords,
765+
all_chunk_coords: set[ChunkCoords],
766+
) -> ShardMapping | None:
767+
"""
768+
Read bytes from `byte_getter` for the case where the read is less than a full shard.
769+
Returns a mapping of chunk coordinates to bytes.
770+
"""
771+
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
772+
if shard_index is None:
773+
return None
774+
775+
chunks = [
776+
_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
777+
for chunk_coords in all_chunk_coords
778+
# Drop chunks where index lookup fails
779+
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
780+
]
781+
if len(chunks) == 0:
782+
return {}
783+
784+
groups = self._coalesce_chunks(chunks)
785+
786+
shard_dicts = await concurrent_map(
787+
[(group, byte_getter, prototype) for group in groups],
788+
self._get_group_bytes,
789+
config.get("async.concurrency"),
790+
)
791+
792+
shard_dict: ShardMutableMapping = {}
793+
for d in shard_dicts:
794+
shard_dict.update(d)
795+
796+
return shard_dict
797+
798+
def _coalesce_chunks(
799+
self,
800+
chunks: list[_ChunkCoordsByteSlice],
801+
) -> list[list[_ChunkCoordsByteSlice]]:
802+
"""
803+
Combine chunks from a single shard into groups that should be read together
804+
in a single request.
805+
806+
Respects the following configuration options:
807+
- `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
808+
chunks to coalesce into a single group.
809+
- `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
810+
"""
811+
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
812+
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")
813+
814+
sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)
815+
816+
groups = []
817+
current_group = [sorted_chunks[0]]
818+
819+
for chunk in sorted_chunks[1:]:
820+
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
821+
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
822+
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
823+
current_group.append(chunk)
824+
else:
825+
groups.append(current_group)
826+
current_group = [chunk]
827+
828+
groups.append(current_group)
829+
830+
return groups
831+
832+
async def _get_group_bytes(
833+
self,
834+
group: list[_ChunkCoordsByteSlice],
835+
byte_getter: ByteGetter,
836+
prototype: BufferPrototype,
837+
) -> ShardMapping:
838+
"""
839+
Reads a possibly coalesced group of one or more chunks from a shard.
840+
Returns a mapping of chunk coordinates to bytes.
841+
"""
842+
group_start = group[0].byte_slice.start
843+
group_end = group[-1].byte_slice.stop
844+
845+
# A single call to retrieve the bytes for the entire group.
846+
group_bytes = await byte_getter.get(
847+
prototype=prototype,
848+
byte_range=RangeByteRequest(group_start, group_end),
849+
)
850+
if group_bytes is None:
851+
return {}
852+
853+
# Extract the bytes corresponding to each chunk in group from group_bytes.
854+
shard_dict = {}
855+
for chunk in group:
856+
chunk_slice = slice(
857+
chunk.byte_slice.start - group_start,
858+
chunk.byte_slice.stop - group_start,
859+
)
860+
shard_dict[chunk.coords] = group_bytes[chunk_slice]
861+
862+
return shard_dict
863+
834864
def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
835865
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
836866
return input_byte_length + self._shard_index_size(chunks_per_shard)

0 commit comments

Comments
 (0)