@@ -198,7 +198,9 @@ async def from_bytes(
198198
199199 @classmethod
200200 def create_empty (
201- cls , chunks_per_shard : ChunkCoords , buffer_prototype : BufferPrototype | None = None
201+ cls ,
202+ chunks_per_shard : ChunkCoords ,
203+ buffer_prototype : BufferPrototype | None = None ,
202204 ) -> _ShardReader :
203205 if buffer_prototype is None :
204206 buffer_prototype = default_buffer_prototype ()
@@ -248,7 +250,9 @@ def merge_with_morton_order(
248250
249251 @classmethod
250252 def create_empty (
251- cls , chunks_per_shard : ChunkCoords , buffer_prototype : BufferPrototype | None = None
253+ cls ,
254+ chunks_per_shard : ChunkCoords ,
255+ buffer_prototype : BufferPrototype | None = None ,
252256 ) -> _ShardBuilder :
253257 if buffer_prototype is None :
254258 buffer_prototype = default_buffer_prototype ()
@@ -330,13 +334,17 @@ async def finalize(
330334
331335
332336class _ChunkCoordsByteSlice (NamedTuple ):
337+ """Holds a chunk's coordinates and it's byte range in a serialized shard."""
338+
333339 coords : ChunkCoords
334340 byte_slice : slice
335341
336342
337343@dataclass (frozen = True )
338344class ShardingCodec (
339- ArrayBytesCodec , ArrayBytesCodecPartialDecodeMixin , ArrayBytesCodecPartialEncodeMixin
345+ ArrayBytesCodec ,
346+ ArrayBytesCodecPartialDecodeMixin ,
347+ ArrayBytesCodecPartialEncodeMixin ,
340348):
341349 chunk_shape : ChunkCoords
342350 codecs : tuple [Codec , ...]
@@ -446,7 +454,10 @@ async def _decode_single(
446454
447455 # setup output array
448456 out = chunk_spec .prototype .nd_buffer .create (
449- shape = shard_shape , dtype = shard_spec .dtype , order = shard_spec .order , fill_value = 0
457+ shape = shard_shape ,
458+ dtype = shard_spec .dtype ,
459+ order = shard_spec .order ,
460+ fill_value = 0 ,
450461 )
451462 shard_dict = await _ShardReader .from_bytes (shard_bytes , self , chunks_per_shard )
452463
@@ -490,7 +501,10 @@ async def _decode_partial_single(
490501
491502 # setup output array
492503 out = shard_spec .prototype .nd_buffer .create (
493- shape = indexer .shape , dtype = shard_spec .dtype , order = shard_spec .order , fill_value = 0
504+ shape = indexer .shape ,
505+ dtype = shard_spec .dtype ,
506+ order = shard_spec .order ,
507+ fill_value = 0 ,
494508 )
495509
496510 indexed_chunks = list (indexer )
@@ -533,101 +547,6 @@ async def _decode_partial_single(
533547 else :
534548 return out
535549
536- async def _load_partial_shard_maybe (
537- self ,
538- byte_getter : ByteGetter ,
539- prototype : BufferPrototype ,
540- chunks_per_shard : ChunkCoords ,
541- all_chunk_coords : set [ChunkCoords ],
542- ) -> ShardMapping | None :
543- shard_index = await self ._load_shard_index_maybe (byte_getter , chunks_per_shard )
544- if shard_index is None :
545- return None
546-
547- chunks = [
548- _ChunkCoordsByteSlice (chunk_coords , slice (* chunk_byte_slice ))
549- for chunk_coords in all_chunk_coords
550- if (chunk_byte_slice := shard_index .get_chunk_slice (chunk_coords ))
551- ]
552- if len (chunks ) == 0 :
553- return {}
554-
555- groups = self ._coalesce_chunks (chunks )
556-
557- shard_dicts = await concurrent_map (
558- [(group , byte_getter , prototype ) for group in groups ],
559- self ._get_group_bytes ,
560- config .get ("async.concurrency" ),
561- )
562-
563- shard_dict : ShardMutableMapping = {}
564- for d in shard_dicts :
565- shard_dict .update (d )
566-
567- return shard_dict
568-
569- def _coalesce_chunks (
570- self ,
571- chunks : list [_ChunkCoordsByteSlice ],
572- ) -> list [list [_ChunkCoordsByteSlice ]]:
573- """
574- Combine chunks from a single shard into groups that should be read together
575- in a single request.
576-
577- Respects the following configuration options:
578- - `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
579- chunks to coalesce into a single group.
580- - `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
581- """
582- max_gap_bytes = config .get ("sharding.read.coalesce_max_gap_bytes" )
583- coalesce_max_bytes = config .get ("sharding.read.coalesce_max_bytes" )
584-
585- sorted_chunks = sorted (chunks , key = lambda c : c .byte_slice .start )
586-
587- groups = []
588- current_group = [sorted_chunks [0 ]]
589-
590- for chunk in sorted_chunks [1 :]:
591- gap_to_chunk = chunk .byte_slice .start - current_group [- 1 ].byte_slice .stop
592- size_if_coalesced = chunk .byte_slice .stop - current_group [0 ].byte_slice .start
593- if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes :
594- current_group .append (chunk )
595- else :
596- groups .append (current_group )
597- current_group = [chunk ]
598-
599- groups .append (current_group )
600-
601- return groups
602-
603- async def _get_group_bytes (
604- self ,
605- group : list [_ChunkCoordsByteSlice ],
606- byte_getter : ByteGetter ,
607- prototype : BufferPrototype ,
608- ) -> ShardMapping :
609- group_start = group [0 ].byte_slice .start
610- group_end = group [- 1 ].byte_slice .stop
611-
612- # A single call to retrieve the bytes for the entire group.
613- group_bytes = await byte_getter .get (
614- prototype = prototype ,
615- byte_range = RangeByteRequest (group_start , group_end ),
616- )
617- if group_bytes is None :
618- return {}
619-
620- # Extract the bytes corresponding to each chunk in group from group_bytes.
621- shard_dict = {}
622- for chunk in group :
623- chunk_slice = slice (
624- chunk .byte_slice .start - group_start ,
625- chunk .byte_slice .stop - group_start ,
626- )
627- shard_dict [chunk .coords ] = group_bytes [chunk_slice ]
628-
629- return shard_dict
630-
631550 async def _encode_single (
632551 self ,
633552 shard_array : NDBuffer ,
@@ -688,7 +607,9 @@ async def _encode_partial_single(
688607
689608 indexer = list (
690609 get_indexer (
691- selection , shape = shard_shape , chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape )
610+ selection ,
611+ shape = shard_shape ,
612+ chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape ),
692613 )
693614 )
694615
@@ -762,7 +683,8 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
762683 get_pipeline_class ()
763684 .from_codecs (self .index_codecs )
764685 .compute_encoded_size (
765- 16 * product (chunks_per_shard ), self ._get_index_chunk_spec (chunks_per_shard )
686+ 16 * product (chunks_per_shard ),
687+ self ._get_index_chunk_spec (chunks_per_shard ),
766688 )
767689 )
768690
@@ -807,7 +729,8 @@ async def _load_shard_index_maybe(
807729 )
808730 else :
809731 index_bytes = await byte_getter .get (
810- prototype = numpy_buffer_prototype (), byte_range = SuffixByteRequest (shard_index_size )
732+ prototype = numpy_buffer_prototype (),
733+ byte_range = SuffixByteRequest (shard_index_size ),
811734 )
812735 if index_bytes is not None :
813736 return await self ._decode_shard_index (index_bytes , chunks_per_shard )
@@ -821,7 +744,10 @@ async def _load_shard_index(
821744 ) or _ShardIndex .create_empty (chunks_per_shard )
822745
823746 async def _load_full_shard_maybe (
824- self , byte_getter : ByteGetter , prototype : BufferPrototype , chunks_per_shard : ChunkCoords
747+ self ,
748+ byte_getter : ByteGetter ,
749+ prototype : BufferPrototype ,
750+ chunks_per_shard : ChunkCoords ,
825751 ) -> _ShardReader | None :
826752 shard_bytes = await byte_getter .get (prototype = prototype )
827753
@@ -831,6 +757,110 @@ async def _load_full_shard_maybe(
831757 else None
832758 )
833759
760+ async def _load_partial_shard_maybe (
761+ self ,
762+ byte_getter : ByteGetter ,
763+ prototype : BufferPrototype ,
764+ chunks_per_shard : ChunkCoords ,
765+ all_chunk_coords : set [ChunkCoords ],
766+ ) -> ShardMapping | None :
767+ """
768+ Read bytes from `byte_getter` for the case where the read is less than a full shard.
769+ Returns a mapping of chunk coordinates to bytes.
770+ """
771+ shard_index = await self ._load_shard_index_maybe (byte_getter , chunks_per_shard )
772+ if shard_index is None :
773+ return None
774+
775+ chunks = [
776+ _ChunkCoordsByteSlice (chunk_coords , slice (* chunk_byte_slice ))
777+ for chunk_coords in all_chunk_coords
778+ # Drop chunks where index lookup fails
779+ if (chunk_byte_slice := shard_index .get_chunk_slice (chunk_coords ))
780+ ]
781+ if len (chunks ) == 0 :
782+ return {}
783+
784+ groups = self ._coalesce_chunks (chunks )
785+
786+ shard_dicts = await concurrent_map (
787+ [(group , byte_getter , prototype ) for group in groups ],
788+ self ._get_group_bytes ,
789+ config .get ("async.concurrency" ),
790+ )
791+
792+ shard_dict : ShardMutableMapping = {}
793+ for d in shard_dicts :
794+ shard_dict .update (d )
795+
796+ return shard_dict
797+
798+ def _coalesce_chunks (
799+ self ,
800+ chunks : list [_ChunkCoordsByteSlice ],
801+ ) -> list [list [_ChunkCoordsByteSlice ]]:
802+ """
803+ Combine chunks from a single shard into groups that should be read together
804+ in a single request.
805+
806+ Respects the following configuration options:
807+ - `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
808+ chunks to coalesce into a single group.
809+ - `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
810+ """
811+ max_gap_bytes = config .get ("sharding.read.coalesce_max_gap_bytes" )
812+ coalesce_max_bytes = config .get ("sharding.read.coalesce_max_bytes" )
813+
814+ sorted_chunks = sorted (chunks , key = lambda c : c .byte_slice .start )
815+
816+ groups = []
817+ current_group = [sorted_chunks [0 ]]
818+
819+ for chunk in sorted_chunks [1 :]:
820+ gap_to_chunk = chunk .byte_slice .start - current_group [- 1 ].byte_slice .stop
821+ size_if_coalesced = chunk .byte_slice .stop - current_group [0 ].byte_slice .start
822+ if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes :
823+ current_group .append (chunk )
824+ else :
825+ groups .append (current_group )
826+ current_group = [chunk ]
827+
828+ groups .append (current_group )
829+
830+ return groups
831+
832+ async def _get_group_bytes (
833+ self ,
834+ group : list [_ChunkCoordsByteSlice ],
835+ byte_getter : ByteGetter ,
836+ prototype : BufferPrototype ,
837+ ) -> ShardMapping :
838+ """
839+ Reads a possibly coalesced group of one or more chunks from a shard.
840+ Returns a mapping of chunk coordinates to bytes.
841+ """
842+ group_start = group [0 ].byte_slice .start
843+ group_end = group [- 1 ].byte_slice .stop
844+
845+ # A single call to retrieve the bytes for the entire group.
846+ group_bytes = await byte_getter .get (
847+ prototype = prototype ,
848+ byte_range = RangeByteRequest (group_start , group_end ),
849+ )
850+ if group_bytes is None :
851+ return {}
852+
853+ # Extract the bytes corresponding to each chunk in group from group_bytes.
854+ shard_dict = {}
855+ for chunk in group :
856+ chunk_slice = slice (
857+ chunk .byte_slice .start - group_start ,
858+ chunk .byte_slice .stop - group_start ,
859+ )
860+ shard_dict [chunk .coords ] = group_bytes [chunk_slice ]
861+
862+ return shard_dict
863+
834864 def compute_encoded_size (self , input_byte_length : int , shard_spec : ArraySpec ) -> int :
835865 chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
836866 return input_byte_length + self ._shard_index_size (chunks_per_shard )
0 commit comments