Skip to content

Commit ed4334b

Browse files
committed
test: cover version checkout persistence
1 parent 77a6883 commit ed4334b

1 file changed

Lines changed: 25 additions & 3 deletions

File tree

tests/test_persistence.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from io import BytesIO
55
from pathlib import Path
6+
from typing import Any
67

78
import pytest
89

@@ -11,7 +12,6 @@
1112
sys.path.insert(0, str(PACKAGE_ROOT))
1213

1314
lance = pytest.importorskip("lance")
14-
Image = pytest.importorskip("PIL.Image")
1515

1616
from lance_context.api import Context
1717

@@ -22,9 +22,9 @@ def _read_rows(uri: str, version: int | None = None) -> list[dict[str, object]]:
2222
return table.to_pylist()
2323

2424

25-
def _image_bytes(image: Image.Image, *, format: str | None = None) -> bytes:
25+
def _image_bytes(image: Any, *, format: str | None = None) -> bytes:
2626
buffer = BytesIO()
27-
image.save(buffer, format=format or image.format or "PNG")
27+
image.save(buffer, format=format or getattr(image, "format", None) or "PNG")
2828
return buffer.getvalue()
2929

3030

@@ -44,6 +44,7 @@ def test_text_round_trip(tmp_path: Path) -> None:
4444

4545

4646
def test_image_round_trip(tmp_path: Path) -> None:
47+
Image = pytest.importorskip("PIL.Image")
4748
uri = tmp_path / "context.lance"
4849
ctx = Context.create(str(uri))
4950

@@ -58,3 +59,24 @@ def test_image_round_trip(tmp_path: Path) -> None:
5859
assert record["text_payload"] is None
5960
assert record["content_type"] == "image/png"
6061
assert record["binary_payload"] == _image_bytes(image)
62+
63+
64+
def test_time_travel_checkout(tmp_path: Path) -> None:
65+
uri = tmp_path / "context.lance"
66+
ctx = Context.create(str(uri))
67+
68+
ctx.add("system", "first-entry")
69+
version_first = ctx.version()
70+
71+
ctx.add("system", "second-entry")
72+
version_second = ctx.version()
73+
assert version_second >= version_first
74+
75+
ctx.checkout(version_first)
76+
77+
rows_versioned = _read_rows(str(uri), version=ctx.version())
78+
assert len(rows_versioned) == 1
79+
assert rows_versioned[0]["text_payload"] == "first-entry"
80+
81+
latest_rows = _read_rows(str(uri))
82+
assert [row["text_payload"] for row in latest_rows] == ["first-entry", "second-entry"]

0 commit comments

Comments
 (0)