|
2 | 2 |
|
3 | 3 | from abc import abstractmethod |
4 | 4 | from collections.abc import Mapping |
5 | | -from typing import TYPE_CHECKING, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable |
6 | 7 |
|
7 | 8 | from typing_extensions import ReadOnly, TypedDict |
8 | 9 |
|
|
19 | 20 | from zarr.core.array_spec import ArraySpec |
20 | 21 | from zarr.core.chunk_grids import ChunkGrid |
21 | 22 | from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType |
22 | | - from zarr.core.indexing import SelectorTuple |
| 23 | + from zarr.core.indexing import ChunkProjection, SelectorTuple |
23 | 24 | from zarr.core.metadata import ArrayMetadata |
24 | 25 |
|
25 | 26 | __all__ = [ |
|
32 | 33 | "CodecInput", |
33 | 34 | "CodecOutput", |
34 | 35 | "CodecPipeline", |
| 36 | + "PreparedWrite", |
35 | 37 | "SupportsSyncCodec", |
36 | 38 | ] |
37 | 39 |
|
@@ -200,9 +202,188 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): |
200 | 202 | """Base class for array-to-array codecs.""" |
201 | 203 |
|
202 | 204 |
|
| 205 | +def _is_complete_selection(selection: Any, shape: tuple[int, ...]) -> bool: |
| 206 | + """Check whether a chunk selection covers the entire chunk shape.""" |
| 207 | + if not isinstance(selection, tuple): |
| 208 | + selection = (selection,) |
| 209 | + for sel, dim_len in zip(selection, shape, strict=False): |
| 210 | + if isinstance(sel, int): |
| 211 | + if dim_len != 1: |
| 212 | + return False |
| 213 | + elif isinstance(sel, slice): |
| 214 | + start, stop, step = sel.indices(dim_len) |
| 215 | + if not (start == 0 and stop == dim_len and step == 1): |
| 216 | + return False |
| 217 | + else: |
| 218 | + return False |
| 219 | + return True |
| 220 | + |
| 221 | + |
@dataclass
class PreparedWrite:
    """Result of ``prepare_write``: existing encoded chunk bytes plus the
    selection bookkeeping the pipeline needs to decode, merge new data, and
    re-encode before ``finalize_write`` persists the result.
    """

    # Per-inner-chunk encoded buffers keyed by inner-chunk coordinates
    # (a single (0,) entry for non-sharded codecs); None marks an absent chunk.
    chunk_dict: dict[tuple[int, ...], Buffer | None]
    # Codec chain used to decode/encode the inner chunks. Typed Any because the
    # concrete type is CodecChain — presumably kept untyped to avoid an import
    # cycle; TODO confirm.
    inner_codec_chain: Any
    # Array spec describing a single inner chunk.
    inner_chunk_spec: ArraySpec
    # One projection per inner chunk touched by the write; for the non-sharded
    # default this is [((0,), chunk_selection, out_selection, is_complete)].
    indexer: list[ChunkProjection]
    value_selection: SelectorTuple | None = None
    # If not None, slice value with this before using inner out_selections.
    # For sharding: the outer out_selection from batch_info.
    # For non-sharded: None (inner out_selection IS the outer out_selection).
    write_full_shard: bool = True
    # True when the entire shard blob will be written from scratch (either
    # because the shard doesn't exist yet or because the selection is complete).
    # Used by ShardingCodec.finalize_write to decide between set vs set_range.
    is_complete_shard: bool = False
    # True when the outer selection covers the entire shard. When True,
    # the indexer is empty and finalize_write receives the shard value
    # via shard_data. The codec then encodes the full shard in one shot
    # rather than iterating over individual inner chunks.
    shard_data: NDBuffer | None = None
    # The full shard value for complete-selection writes. Set by the pipeline
    # when is_complete_shard is True, before calling finalize_write.
| 246 | + |
| 247 | + |
class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]):
    """Base class for array-to-bytes codecs.

    The default hooks implement the single-chunk case: the stored blob is the
    encoded chunk itself, keyed at inner-chunk coordinate ``(0,)``.
    ShardingCodec overrides ``inner_codec_chain``, ``deserialize``,
    ``serialize`` and the finalize hooks to map one stored blob onto many
    inner chunks. The sync/async read and write entry points share their
    pure-compute steps via private helpers so the two variants cannot drift.
    """

    @property
    def inner_codec_chain(self) -> Any:
        """The codec chain for decoding inner chunks after deserialization.

        Returns None by default — the pipeline should use its own codec_chain.
        ShardingCodec overrides to return its inner codec chain.
        """
        return None

    def deserialize(
        self, raw: Buffer | None, chunk_spec: ArraySpec
    ) -> dict[tuple[int, ...], Buffer | None]:
        """Pure compute: unpack stored bytes into per-inner-chunk buffers.

        Default implementation: single chunk at (0,).
        ShardingCodec overrides to decode shard index and slice blob into
        per-chunk buffers.
        """
        return {(0,): raw}

    def serialize(
        self, chunk_dict: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec
    ) -> Buffer | None:
        """Pure compute: pack per-inner-chunk buffers into a storage blob.

        Default implementation: return the single chunk's bytes (or None if absent).
        ShardingCodec overrides to concatenate chunks + build index.
        Returns None if all chunks are empty (caller should delete the key).
        """
        return chunk_dict.get((0,))

    def _decode_selected(
        self,
        raw: Buffer | None,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        codec_chain: Any,
        aa_chain: Any,
        ab_pair: Any,
        bb_chain: Any,
    ) -> NDBuffer | None:
        """Pure compute shared by prepare_read / prepare_read_sync: decode
        ``raw`` through ``codec_chain`` and slice out the requested region.
        Returns None when the chunk is absent (decode yields None)."""
        chunk_array: NDBuffer | None = codec_chain.decode_chunk(
            raw, chunk_spec, aa_chain, ab_pair, bb_chain
        )
        if chunk_array is None:
            return None
        return chunk_array[chunk_selection]

    def _build_prepared_write(
        self,
        existing: Buffer | None,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        out_selection: SelectorTuple,
        codec_chain: Any,
        is_complete: bool,
    ) -> PreparedWrite:
        """Pure compute shared by prepare_write / prepare_write_sync:
        deserialize any existing bytes and wrap them for the pipeline."""
        chunk_dict = self.deserialize(existing, chunk_spec)
        # Falsy (None) inner chain means this codec has no nested chain; fall
        # back to the pipeline's own chain.
        inner_chain = self.inner_codec_chain or codec_chain
        return PreparedWrite(
            chunk_dict=chunk_dict,
            inner_codec_chain=inner_chain,
            inner_chunk_spec=chunk_spec,
            indexer=[((0,), chunk_selection, out_selection, is_complete)],  # type: ignore[list-item]
        )

    def prepare_read_sync(
        self,
        byte_getter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        codec_chain: Any,
        aa_chain: Any,
        ab_pair: Any,
        bb_chain: Any,
    ) -> NDBuffer | None:
        """IO + full decode for the selected region. Returns decoded sub-array."""
        raw = byte_getter.get_sync(prototype=chunk_spec.prototype)
        return self._decode_selected(
            raw, chunk_spec, chunk_selection, codec_chain, aa_chain, ab_pair, bb_chain
        )

    async def prepare_read(
        self,
        byte_getter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        codec_chain: Any,
        aa_chain: Any,
        ab_pair: Any,
        bb_chain: Any,
    ) -> NDBuffer | None:
        """Async IO + full decode for the selected region. Returns decoded sub-array."""
        raw = await byte_getter.get(prototype=chunk_spec.prototype)
        return self._decode_selected(
            raw, chunk_spec, chunk_selection, codec_chain, aa_chain, ab_pair, bb_chain
        )

    def prepare_write_sync(
        self,
        byte_setter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        out_selection: SelectorTuple,
        codec_chain: Any,
    ) -> PreparedWrite:
        """IO + deserialize. Returns PreparedWrite for the pipeline to decode/merge/encode."""
        is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape)
        # A complete selection overwrites everything, so skip reading the
        # existing bytes entirely.
        existing: Buffer | None = None
        if not is_complete:
            existing = byte_setter.get_sync(prototype=chunk_spec.prototype)
        return self._build_prepared_write(
            existing, chunk_spec, chunk_selection, out_selection, codec_chain, is_complete
        )

    async def prepare_write(
        self,
        byte_setter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        out_selection: SelectorTuple,
        codec_chain: Any,
    ) -> PreparedWrite:
        """Async IO + deserialize. Returns PreparedWrite for the pipeline to decode/merge/encode."""
        is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape)
        # A complete selection overwrites everything, so skip reading the
        # existing bytes entirely.
        existing: Buffer | None = None
        if not is_complete:
            existing = await byte_setter.get(prototype=chunk_spec.prototype)
        return self._build_prepared_write(
            existing, chunk_spec, chunk_selection, out_selection, codec_chain, is_complete
        )

    def finalize_write_sync(
        self, prepared: PreparedWrite, chunk_spec: ArraySpec, byte_setter: Any
    ) -> None:
        """Serialize prepared chunk_dict and write to store.

        Default: serialize to a single blob and call set (or delete if all empty).
        ShardingCodec overrides this for byte-range writes when inner codecs are fixed-size.
        """
        blob = self.serialize(prepared.chunk_dict, chunk_spec)
        if blob is None:
            # All chunks empty: remove the key rather than storing an empty blob.
            byte_setter.delete_sync()
        else:
            byte_setter.set_sync(blob)

    async def finalize_write(
        self, prepared: PreparedWrite, chunk_spec: ArraySpec, byte_setter: Any
    ) -> None:
        """Async version of finalize_write_sync."""
        blob = self.serialize(prepared.chunk_dict, chunk_spec)
        if blob is None:
            # All chunks empty: remove the key rather than storing an empty blob.
            await byte_setter.delete()
        else:
            await byte_setter.set(blob)
206 | 387 |
|
207 | 388 | class BytesBytesCodec(BaseCodec[Buffer, Buffer]): |
208 | 389 | """Base class for bytes-to-bytes codecs.""" |
|
0 commit comments