Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3914.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `zarr.testing.models` (`ArrayNode`, `GroupNode`) for declarative zarr hierarchy descriptions with `materialize()` and `from_store()` round-trip helpers, and `zarr.testing.strategies.trees()`, a Hypothesis strategy that produces realistic hierarchies with prefix-colliding sibling names — useful when testing custom zarr stores against a reference implementation.
149 changes: 149 additions & 0 deletions src/zarr/testing/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Models for comparison testing.

The tree descriptors (GroupNode / ArrayNode) are pure data structures.
Materialization writes it into any zarr store.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

import numpy as np

import zarr
import zarr.abc.store
import zarr.api.asynchronous
from zarr.core.buffer import default_buffer_prototype
from zarr.core.sync import sync

if TYPE_CHECKING:
from collections.abc import Iterator

_PROTOTYPE = default_buffer_prototype()


@dataclass(frozen=True)
class ArrayNode:
shape: tuple[int, ...]
dtype: np.dtype


@dataclass(frozen=True)
class GroupNode:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this approach to "structurally model a group" reminds me of GroupSpec in pydantic-zarr. this is clearly a useful representation, so we should figure out how we can fit it into zarr-python proper

children: dict[str, ArrayNode | GroupNode] = field(default_factory=dict)

def walk(self, prefix: str = "") -> Iterator[tuple[str, Node]]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i feel like "iter_nodes"or similar is a bit more expressive than "walk"

"""Yield ``(path, child)`` for every node, depth-first."""
for name, child in self.children.items():
p = f"{prefix}/{name}" if prefix else name
yield p, child
if isinstance(child, GroupNode):
yield from child.walk(p)

def nodes(self, prefix: str = "", *, include_root: bool = False) -> list[str]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all these iteration routines can do IO. so returning list[stuff] requires doing all the IO before the first value is available. Iterator[stuff] or Generator[stuff] gives us a bit more flexibility if we want to stream results instead of doing all the IO up-front.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this applies in a world where we put these routines on the group class that's bound to a store)

"""Return paths of all nodes, optionally including root."""
root = [prefix] if include_root else []
return root + [p for p, _ in self.walk(prefix)]

def groups(self, prefix: str = "", *, include_root: bool = False) -> list[str]:
"""Return paths of all group nodes, optionally including root."""
root = [prefix] if include_root else []
return root + [p for p, c in self.walk(prefix) if isinstance(c, GroupNode)]

def arrays(self, prefix: str = "") -> list[str]:
"""Return paths of all array nodes."""
return [p for p, c in self.walk(prefix) if isinstance(c, ArrayNode)]

def materialize(
self,
store: zarr.abc.store.Store,
*,
zarr_format: Literal[2, 3] = 3,
mode: Literal["w", "a"] = "w",
) -> zarr.Group:
"""Write this tree into *store* and return the root group.

``mode`` is forwarded to :func:`zarr.open_group` when opening the root.
"""
root = zarr.open_group(store, mode=mode, zarr_format=zarr_format)

def _write(group: zarr.Group, node: GroupNode) -> None:
for name, child in node.children.items():
if isinstance(child, ArrayNode):
group.create_array(name, shape=child.shape, dtype=child.dtype)
else:
_write(group.create_group(name), child)

_write(root, self)
return root

@classmethod
def from_dict(cls, d: dict[str, Any]) -> GroupNode:
"""Convert a nested dict (with ArrayNode leaves) to a GroupNode tree."""
children: dict[str, ArrayNode | GroupNode] = {}
for name, value in d.items():
if isinstance(value, ArrayNode):
children[name] = value
else:
children[name] = cls.from_dict(value)
return cls(children=children)

@classmethod
def from_paths(cls, arrays: set[str], groups: set[str]) -> GroupNode:
"""Build a GroupNode from flat sets of array and group paths.

Example::

GroupNode.from_paths(
arrays={"a/x", "b"},
groups={"a"},
)
"""
tree: dict[str, Any] = {}
for path in sorted(groups - {""}):
current = tree
for part in path.split("/"):
current = current.setdefault(part, {})
for path in sorted(arrays):
parts = path.split("/")
current = tree
for part in parts[:-1]:
current = current.setdefault(part, {})
current[parts[-1]] = ArrayNode(shape=(1,), dtype=np.dtype("i4"))
return cls.from_dict(tree)

@classmethod
async def from_store_async(cls, store: zarr.abc.store.Store) -> GroupNode:
"""Build a GroupNode by reading a zarr store's structure.

Example::

await GroupNode.from_store_async(some_memory_store)
"""
root = await zarr.api.asynchronous.open_group(store, mode="r")
tree: dict[str, Any] = {}
async for path, obj in root.members(max_depth=None):
parts = path.split("/")
current = tree
if isinstance(obj, zarr.AsyncArray):
for part in parts[:-1]:
current = current.setdefault(part, {})
current[parts[-1]] = ArrayNode(shape=obj.shape, dtype=obj.dtype)
else:
for part in parts:
current = current.setdefault(part, {})
return cls.from_dict(tree)

@classmethod
def from_store(cls, store: zarr.abc.store.Store) -> GroupNode:
"""Build a GroupNode by reading a zarr store's structure.

Example::

GroupNode.from_store(some_memory_store)
"""
return sync(cls.from_store_async(store))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using sync here forces any consumer to use the zarr-python event loop. flagging this for visibility



Node = ArrayNode | GroupNode
229 changes: 229 additions & 0 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,22 @@
from zarr.storage import MemoryStore, StoreLike
from zarr.storage._common import _dereference_path
from zarr.storage._utils import normalize_path
from zarr.testing.models import ArrayNode, GroupNode, Node
from zarr.types import AnyArray

TrueOrFalse = Literal[True, False]


def default_fs_case_insensitive() -> bool:
"""Return whether the current platform defaults to a case-insensitive filesystem.

macOS APFS and Windows NTFS are case-insensitive by default; Linux
filesystems are typically case-sensitive. Used as the default for the
``trees()`` strategy so sibling names won't collide when the tree is
materialized into a filesystem-backed store.
"""
return sys.platform in ("darwin", "win32")

# Copied from Xarray
_attr_keys = st.text(st.characters(), min_size=1)
_attr_values = st.recursive(
Expand Down Expand Up @@ -645,3 +657,220 @@ def chunk_paths(draw: st.DrawFn, ndim: int, numblocks: tuple[int, ...], subset:
)
subset_slicer = slice(draw(st.integers(min_value=0, max_value=ndim))) if subset else slice(None)
return "/".join(map(str, blockidx[subset_slicer]))


# ---------------------------------------------------------------------------
# Name strategies — pool-based derivation for prefix collisions
# ---------------------------------------------------------------------------

# Short affix drawn from the zarr key alphabet for prefix/suffix collisions.
affix = st.text(zarr_key_chars, min_size=1, max_size=3)
separators = st.sampled_from(["_", "-", "."])


def similar_name(
non_sibling_names: set[str],
sibling_names: set[str],
*,
case_insensitive: bool | None = None,
) -> st.SearchStrategy[str]:
"""Strategy that picks a name similar to existing ones, for prefix collisions.

Either an affixed variant of a sibling name (e.g. ``"foo"`` → ``"foo-bar"``)
or an exact copy of a non-sibling (e.g. cousin) name.

Parameters
----------
non_sibling_names : set[str]
Names from elsewhere in the tree (cousins, ancestors, etc.).
These may be reused exactly as the generated name.
sibling_names : set[str]
Names of nodes at the same level as the one being generated.
These are used as bases for affixed variants (prefix/suffix collisions).
case_insensitive : bool | None
If ``True``, produced names will not differ from ``sibling_names``
only in letter case. Required when the target store backs onto a
case-insensitive filesystem (macOS APFS, Windows NTFS default),
where ``foo`` and ``FOO`` resolve to the same path. Defaults to the
current platform's filesystem behavior when ``None``.

Examples
--------
Given a tree like::

/
├── alpha/
│ ├── x
│ └── y
└── beta/
├── z
└── ? ← generating a new name here

``sibling_names = {"z"}``, ``non_sibling_names = {"alpha", "x", "y", "beta"}``.
The strategy might produce ``"z_0"`` (affixed sibling) or ``"x"`` (reused cousin).
or ``beta`` or a new random name entirely.
"""
if case_insensitive is None:
case_insensitive = default_fs_case_insensitive()
siblings = sorted(sibling_names)
non_siblings = sorted(non_sibling_names - sibling_names)

strategies = []
if bool(siblings):
# if there are any named siblings we can affix a sibling name
# choosing to not affix all names in the tree (e.g. cousin names) because
# that doesn't seem likely to bring a bug, and would expand the search space.
strategies.append(
st.sampled_from(siblings).flatmap(
lambda base: st.one_of(
separators.flatmap(lambda sep: affix.map(lambda afx: base + sep + afx)),
separators.flatmap(lambda sep: affix.map(lambda afx: afx + sep + base)),
)
)
)
if bool(non_siblings):
strategies.append(st.sampled_from(non_siblings))
key = str.casefold if case_insensitive else (lambda n: n)
forbidden = {key(n) for n in sibling_names}
return st.one_of(*strategies).filter(lambda name: key(name) not in forbidden)


@st.composite
def unique_sibling_names(
draw: st.DrawFn,
existing_names: set[str],
num_names: int,
existing_siblings: set[str] | None = None,
*,
case_insensitive: bool | None = None,
) -> list[str]:
"""Draw *num_names* unique names, biased toward collisions with existing ones.

Parameters
----------
existing_names : set[str]
All names already present in the tree. Used to generate
similar-looking candidates (affixed siblings, reused cousins).
num_names : int
Number of unique names to generate.
existing_siblings : set[str] | None
Names already present at the destination that must not be reused.
Used by valid_moves to avoid collisions with existing children.

Returns
-------
list[str]
The generated names, unique among themselves and not in existing_siblings.
"""
if case_insensitive is None:
case_insensitive = default_fs_case_insensitive()
generated_names: set[str] = set()
already_taken = existing_siblings or set()
key = str.casefold if case_insensitive else (lambda n: n)

for _ in range(num_names):
excluded = generated_names | already_taken
forbidden = {key(n) for n in excluded}
# Filter the whole strategy — similar_name can produce collisions.
generated_names.add(
draw(
(
st.one_of(
node_names,
similar_name(
existing_names, excluded, case_insensitive=case_insensitive
),
)
if bool(existing_names) | bool(generated_names)
else node_names
).filter(
lambda name_, f=forbidden, k=key: k(name_) not in f # type: ignore[misc]
)
)
)
return list(generated_names)


# ---------------------------------------------------------------------------
# Tree skeleton + naming
# ---------------------------------------------------------------------------


def skeletons(*, max_leaves: int = 50, max_children: int = 4) -> st.SearchStrategy[GroupNode]:
"""Unnamed tree skeletons via st.recursive.

Always returns a GroupNode (the root group). Child names are placeholder
indices ("0", "1", ...); real names are assigned later by ``trees``.
"""
leaves = st.just(ArrayNode(shape=(1,), dtype=np.dtype("i4")))

def extend(children: st.SearchStrategy[Node]) -> st.SearchStrategy[GroupNode]:
return st.lists(children, min_size=1, max_size=max_children).map(
lambda child_list: GroupNode(
children={str(i): child for i, child in enumerate(child_list)}
)
)

# Wrap in extend so the top level is always a GroupNode (the root group).
return extend(st.recursive(leaves, extend, max_leaves=max_leaves))


@st.composite
def trees(
draw: st.DrawFn,
*,
max_leaves: st.SearchStrategy[int] = st.integers(min_value=5, max_value=50), # noqa: B008
max_children: st.SearchStrategy[int] = st.integers(min_value=1, max_value=4), # noqa: B008
case_insensitive: bool | None = None,
) -> GroupNode:
"""Strategy producing a GroupNode tree descriptor.

Uses st.recursive for the tree structure (good structural shrinking)
and @composite for name assignment (pool-based prefix collisions).

Parameters
----------
case_insensitive : bool | None
If ``True``, sibling names will not differ only in letter case, so
the tree can be materialized into a store backed by a case-insensitive
filesystem (macOS APFS, Windows NTFS default), where ``foo`` and
``FOO`` resolve to the same path. Defaults to the current platform's
filesystem behavior when ``None``.

Examples
--------
>>> @given(tree=trees())
... def test_something(tree):
... reference = tree.materialize(zarr.storage.MemoryStore())
... under_test = tree.materialize(my_icechunk_store())
... # compare...
"""
if case_insensitive is None:
case_insensitive = default_fs_case_insensitive()

def rebuild_with_names(
group: GroupNode, existing_names: set[str]
) -> tuple[GroupNode, set[str]]:
new_names = draw(
unique_sibling_names(
existing_names,
num_names=len(group.children),
case_insensitive=case_insensitive,
)
)
existing_names = existing_names | set(new_names)
children: dict[str, Node] = {}
for name, child in zip(new_names, group.children.values(), strict=True):
if isinstance(child, GroupNode):
child, existing_names = rebuild_with_names(child, existing_names)
children[name] = child
return GroupNode(children=children), existing_names

# Two-step generation: first draw the tree structure (skeletons uses
# st.recursive which gives good structural shrinking), then assign real
# names via @composite (which allows pool-based derivation for realistic
# prefix collisions). Doing both in one step would sacrifice either
# structural shrinking or name similarity.
skeleton = draw(skeletons(max_leaves=draw(max_leaves), max_children=draw(max_children)))
result, _ = rebuild_with_names(skeleton, set())
return result
Loading
Loading