forked from lance-format/lance-context
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_persistence.py
More file actions
82 lines (57 loc) · 2.35 KB
/
test_persistence.py
File metadata and controls
82 lines (57 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import annotations
import sys
from io import BytesIO
from pathlib import Path
from typing import Any
import pytest
# Make the in-repo ``lance_context`` package importable without installing it.
# NOTE(review): assumes the package sources live under ``python/python`` in the
# directory two levels above this file's directory — confirm if tests are moved.
PACKAGE_ROOT = Path(__file__).resolve().parents[2] / "python" / "python"
if str(PACKAGE_ROOT) not in sys.path:
    # Prepend so the local checkout is found before any installed copy.
    sys.path.insert(0, str(PACKAGE_ROOT))
# Skip this whole test module when the ``lance`` package is not installed.
lance = pytest.importorskip("lance")
# Must come after the sys.path insertion above so the local package resolves.
from lance_context.api import Context  # noqa: E402
def _read_rows(uri: str, version: int | None = None) -> list[dict[str, object]]:
    """Return every row of the Lance dataset at ``uri`` as a list of dicts.

    Args:
        uri: Location of the Lance dataset.
        version: Dataset version to open. When ``None``, the dataset is opened
            without a version argument, i.e. at its default (latest) version.

    Returns:
        The dataset's rows, one ``dict`` per row, as produced by
        ``Table.to_pylist()``.
    """
    if version is None:
        opened = lance.dataset(uri)
    else:
        opened = lance.dataset(uri, version=version)
    return opened.to_table().to_pylist()
def _image_bytes(image: Any, *, format: str | None = None) -> bytes:
buffer = BytesIO()
image.save(buffer, format=format or getattr(image, "format", None) or "PNG")
return buffer.getvalue()
def test_text_round_trip(tmp_path: Path) -> None:
    """A plain-text entry is persisted as one text row with no binary payload."""
    dataset_uri = str(tmp_path / "context.lance")
    context = Context.create(dataset_uri)
    context.add("user", "hello world")

    stored = _read_rows(dataset_uri)
    assert len(stored) == 1
    (entry,) = stored
    assert entry["role"] == "user"
    assert entry["text_payload"] == "hello world"
    assert entry["binary_payload"] is None
    assert entry["content_type"] == "text/plain"
def test_image_round_trip(tmp_path: Path) -> None:
    """An image entry is persisted as one binary row tagged ``image/png``.

    Skipped when Pillow is not installed.
    """
    pil_image = pytest.importorskip("PIL.Image")
    dataset_uri = str(tmp_path / "context.lance")
    context = Context.create(dataset_uri)
    swatch = pil_image.new("RGB", (4, 4), color="magenta")
    context.add("assistant", swatch)

    stored = _read_rows(dataset_uri)
    assert len(stored) == 1
    (entry,) = stored
    assert entry["role"] == "assistant"
    assert entry["text_payload"] is None
    assert entry["content_type"] == "image/png"
    # Stored bytes must equal a fresh encoding of the same image by the helper.
    assert entry["binary_payload"] == _image_bytes(swatch)
def test_time_travel_checkout(tmp_path: Path) -> None:
    """Checking out an earlier version exposes only the rows written up to it.

    Two entries are appended, and the version reported after each append is
    recorded. The context is then checked out at the first version:
    - reading at that version must show only the first entry;
    - reading the latest dataset must still show both entries, in order.
    """
    uri = str(tmp_path / "context.lance")
    ctx = Context.create(uri)

    ctx.add("system", "first-entry")
    version_first = ctx.version()
    ctx.add("system", "second-entry")
    version_second = ctx.version()
    # Each add must advance the version; with ``>=`` an unchanged (or stale)
    # version number would pass and the checkout below would prove nothing.
    assert version_second > version_first

    ctx.checkout(version_first)
    # The context itself must report the version it was checked out at.
    assert ctx.version() == version_first
    rows_versioned = _read_rows(uri, version=version_first)
    assert len(rows_versioned) == 1
    assert rows_versioned[0]["text_payload"] == "first-entry"

    # Checking out an old version must not rewrite the latest version.
    latest_rows = _read_rows(uri)
    assert [row["text_payload"] for row in latest_rows] == ["first-entry", "second-entry"]