Skip to content

Commit e27ce5a

Browse files
committed
Fix: list_prefix no longer does an overeager string match
1 parent a02d996 commit e27ce5a

File tree

5 files changed

+51
-20
lines changed

5 files changed

+51
-20
lines changed

src/zarr/storage/_memory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from zarr.core.buffer import Buffer, gpu
88
from zarr.core.buffer.core import default_buffer_prototype
99
from zarr.core.common import concurrent_map
10-
from zarr.storage._utils import _normalize_byte_range_index
10+
from zarr.storage._utils import _normalize_byte_range_index, _normalize_prefix
1111

1212
if TYPE_CHECKING:
1313
from collections.abc import AsyncIterator, Iterable, MutableMapping
@@ -152,6 +152,7 @@ async def list(self) -> AsyncIterator[str]:
152152
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
153153
# docstring inherited
154154
# note: we materialize all dict keys into a list here so we can mutate the dict in-place (e.g. in delete_prefix)
155+
prefix = _normalize_prefix(prefix)
155156
for key in list(self._store_dict):
156157
if key.startswith(prefix):
157158
yield key

src/zarr/storage/_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
from zarr.core.buffer import Buffer
1414

1515

16+
def _normalize_prefix(prefix: str) -> str:
17+
"""Normalize a non-empty store prefix so it ends with a trailing slash; an empty prefix is returned unchanged.
18+
19+
This ensures that prefix matching uses directory-like semantics,
20+
so that e.g. prefix "a" does not match keys under "a_extra/".
21+
"""
22+
if prefix != "" and not prefix.endswith("/"):
23+
return prefix + "/"
24+
return prefix
25+
26+
1627
def normalize_path(path: str | bytes | Path | None) -> str:
1728
if path is None:
1829
result = ""

src/zarr/storage/_zip.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SuffixByteRequest,
1717
)
1818
from zarr.core.buffer import Buffer, BufferPrototype
19+
from zarr.storage._utils import _normalize_prefix
1920

2021
if TYPE_CHECKING:
2122
from collections.abc import AsyncIterator, Iterable
@@ -261,6 +262,7 @@ async def list(self) -> AsyncIterator[str]:
261262

262263
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
263264
# docstring inherited
265+
prefix = _normalize_prefix(prefix)
264266
async for key in self.list():
265267
if key.startswith(prefix):
266268
yield key

src/zarr/testing/store.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -442,23 +442,46 @@ async def test_list(self, store: S) -> None:
442442
async def test_list_prefix(self, store: S) -> None:
443443
"""
444444
Test that the `list_prefix` method works as intended. Given a prefix, it should return
445-
all the keys in storage that start with this prefix.
445+
all the keys under that prefix, treating the prefix as a directory path.
446446
"""
447-
prefixes = ("", "a/", "a/b/", "a/b/c/")
448447
data = self.buffer_cls.from_bytes(b"")
449-
fname = "zarr.json"
450-
store_dict = {p + fname: data for p in prefixes}
451-
448+
store_dict = {
449+
"zarr.json": data,
450+
"a/zarr.json": data,
451+
"a/b/zarr.json": data,
452+
"a/b/c/zarr.json": data,
453+
"a_extra/zarr.json": data,
454+
}
452455
await store._set_many(store_dict.items())
456+
all_keys = sorted(store_dict.keys())
457+
458+
a_keys = ["a/b/c/zarr.json", "a/b/zarr.json", "a/zarr.json"]
459+
ab_keys = ["a/b/c/zarr.json", "a/b/zarr.json"]
460+
461+
# query prefix -> expected keys
462+
test_cases: dict[str, list[str]] = {
463+
# empty prefix returns everything
464+
"": all_keys,
465+
# with trailing /
466+
"a/": a_keys,
467+
"a/b/": ab_keys,
468+
"a/b/c/": ["a/b/c/zarr.json"],
469+
"a_extra/": ["a_extra/zarr.json"],
470+
# without trailing / should behave the same as with /
471+
"a": a_keys,
472+
"a/b": ab_keys,
473+
"a/b/c": ["a/b/c/zarr.json"],
474+
"a_extra": ["a_extra/zarr.json"],
475+
# partial prefix that doesn't match any directory
476+
"a_e": [],
477+
# prefix that doesn't match anything
478+
"b": [],
479+
"b/": [],
480+
}
453481

454-
for prefix in prefixes:
455-
observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix))))
456-
expected: tuple[str, ...] = ()
457-
for key in store_dict:
458-
if key.startswith(prefix):
459-
expected += (key,)
460-
expected = tuple(sorted(expected))
461-
assert observed == expected
482+
for prefix, expected in test_cases.items():
483+
observed = sorted(await _collect_aiterator(store.list_prefix(prefix)))
484+
assert observed == expected, f"list_prefix({prefix!r}): {observed} != {expected}"
462485

463486
async def test_list_empty_path(self, store: S) -> None:
464487
"""

tests/test_store/test_memory.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ def test_store_supports_writes(self, store: MemoryStore) -> None:
5757
def test_store_supports_listing(self, store: MemoryStore) -> None:
5858
assert store.supports_listing
5959

60-
async def test_list_prefix(self, store: MemoryStore) -> None:
61-
assert True
62-
6360
@pytest.mark.parametrize("dtype", ["uint8", "float32", "int64"])
6461
@pytest.mark.parametrize("zarr_format", [2, 3])
6562
async def test_deterministic_size(
@@ -163,9 +160,6 @@ def test_store_supports_writes(self, store: GpuMemoryStore) -> None:
163160
def test_store_supports_listing(self, store: GpuMemoryStore) -> None:
164161
assert store.supports_listing
165162

166-
async def test_list_prefix(self, store: GpuMemoryStore) -> None:
167-
assert True
168-
169163
def test_dict_reference(self, store: GpuMemoryStore) -> None:
170164
store_dict: dict[str, Any] = {}
171165
result = GpuMemoryStore(store_dict=store_dict)

0 commit comments

Comments
 (0)