@@ -90,9 +90,9 @@ async def get(
         self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
     ) -> Buffer | None:
         assert byte_range is None, "byte_range is not supported within shards"
-        assert (
-            prototype == default_buffer_prototype()
-        ), f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
+        assert prototype == default_buffer_prototype(), (
+            f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
+        )
         return self.shard_dict.get(self.chunk_coords)


@@ -124,9 +124,7 @@ def chunks_per_shard(self) -> ChunkCoords:
     def _localize_chunk(self, chunk_coords: ChunkCoords) -> ChunkCoords:
         return tuple(
             chunk_i % shard_i
-            for chunk_i, shard_i in zip(
-                chunk_coords, self.offsets_and_lengths.shape, strict=False
-            )
+            for chunk_i, shard_i in zip(chunk_coords, self.offsets_and_lengths.shape, strict=False)
         )

     def is_all_empty(self) -> bool:
@@ -143,9 +141,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
         else:
             return (int(chunk_start), int(chunk_start + chunk_len))

-    def set_chunk_slice(
-        self, chunk_coords: ChunkCoords, chunk_slice: slice | None
-    ) -> None:
+    def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None:
         localized_chunk = self._localize_chunk(chunk_coords)
         if chunk_slice is None:
             self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
@@ -167,11 +163,7 @@ def is_dense(self, chunk_byte_length: int) -> bool:

         # Are all non-empty offsets unique?
         if len(
-            {
-                offset
-                for offset, _ in sorted_offsets_and_lengths
-                if offset != MAX_UINT_64
-            }
+            {offset for offset, _ in sorted_offsets_and_lengths if offset != MAX_UINT_64}
         ) != len(sorted_offsets_and_lengths):
             return False

@@ -275,9 +267,7 @@ def __setitem__(self, chunk_coords: ChunkCoords, value: Buffer) -> None:
         chunk_start = len(self.buf)
         chunk_length = len(value)
         self.buf += value
-        self.index.set_chunk_slice(
-            chunk_coords, slice(chunk_start, chunk_start + chunk_length)
-        )
+        self.index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))

     def __delitem__(self, chunk_coords: ChunkCoords) -> None:
         raise NotImplementedError
@@ -291,9 +281,7 @@ async def finalize(
         if index_location == ShardingCodecIndexLocation.start:
             empty_chunks_mask = self.index.offsets_and_lengths[..., 0] == MAX_UINT_64
             self.index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
-            index_bytes = await index_encoder(
-                self.index
-            )  # encode again with corrected offsets
+            index_bytes = await index_encoder(self.index)  # encode again with corrected offsets
             out_buf = index_bytes + self.buf
         else:
             out_buf = self.buf + index_bytes
@@ -371,8 +359,7 @@ def __init__(
         chunk_shape: ChunkCoordsLike,
         codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
         index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
-        index_location: ShardingCodecIndexLocation
-        | str = ShardingCodecIndexLocation.end,
+        index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
     ) -> None:
         chunk_shape_parsed = parse_shapelike(chunk_shape)
         codecs_parsed = parse_codecs(codecs)
@@ -402,9 +389,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
         object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"]))
         object.__setattr__(self, "codecs", parse_codecs(config["codecs"]))
         object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
-        object.__setattr__(
-            self, "index_location", parse_index_location(config["index_location"])
-        )
+        object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))

         # Use instance-local lru_cache to avoid memory leaks
         # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
@@ -433,9 +418,7 @@ def to_dict(self) -> dict[str, JSON]:

     def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
         shard_spec = self._get_chunk_spec(array_spec)
-        evolved_codecs = tuple(
-            c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs
-        )
+        evolved_codecs = tuple(c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs)
         if evolved_codecs != self.codecs:
             return replace(self, codecs=evolved_codecs)
         return self
@@ -610,9 +593,7 @@ async def _encode_single(
             shard_array,
         )

-        return await shard_builder.finalize(
-            self.index_location, self._encode_shard_index
-        )
+        return await shard_builder.finalize(self.index_location, self._encode_shard_index)

     async def _encode_partial_single(
         self,
@@ -672,8 +653,7 @@ def _is_total_shard(
         self, all_chunk_coords: set[ChunkCoords], chunks_per_shard: ChunkCoords
     ) -> bool:
         return len(all_chunk_coords) == product(chunks_per_shard) and all(
-            chunk_coords in all_chunk_coords
-            for chunk_coords in c_order_iter(chunks_per_shard)
+            chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard)
         )

     async def _decode_shard_index(
@@ -699,9 +679,7 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:
             .encode(
                 [
                     (
-                        get_ndbuffer_class().from_numpy_array(
-                            index.offsets_and_lengths
-                        ),
+                        get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths),
                         self._get_index_chunk_spec(index.chunks_per_shard),
                     )
                 ],
@@ -810,9 +788,10 @@ async def _load_partial_shard_maybe(
             _ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
             for chunk_coords in all_chunk_coords
             # Drop chunks where index lookup fails
+            # e.g. when write_empty_chunks = False and the chunk is empty
             if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
         ]
-        if len(chunks) < len(all_chunk_coords):
+        if len(chunks) == 0:
             return None

         groups = self._coalesce_chunks(chunks)
@@ -854,9 +833,7 @@ def _coalesce_chunks(

        for chunk in sorted_chunks[1:]:
            gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
-            size_if_coalesced = (
-                chunk.byte_slice.stop - current_group[0].byte_slice.start
-            )
+            size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
            if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
                current_group.append(chunk)
            else:
@@ -899,9 +876,7 @@ async def _get_group_bytes(

         return shard_dict

-    def compute_encoded_size(
-        self, input_byte_length: int, shard_spec: ArraySpec
-    ) -> int:
+    def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
         chunks_per_shard = self._get_chunks_per_shard(shard_spec)
         return input_byte_length + self._shard_index_size(chunks_per_shard)

0 commit comments