Skip to content

Commit caf77a6

Browse files
committed
feat: add python context store bindings
1 parent e3e49d5 commit caf77a6

7 files changed

Lines changed: 236 additions & 12 deletions

File tree

python/Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ name = "_internal"
1010
crate-type = ["cdylib"]
1111

1212
[dependencies]
13+
chrono = { version = "0.4", default-features = false, features = ["clock"] }
1314
lance-context = { path = "../rust/lance-context" }
1415
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39", "py-clone"] }
15-
16+
tokio = { version = "1", features = ["rt-multi-thread"] }

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ lint.select = ["F", "E", "W", "I", "G", "TCH", "PERF", "B019"]
4949

5050
[tool.pyright]
5151
pythonVersion = "3.13"
52-
include = ["python/lance_context/__init__.py"]
52+
include = ["python/lance_context/__init__.py", "python/lance_context/api.py"]
5353
reportMissingTypeStubs = "warning"
5454
reportImportCycles = "error"
5555
reportUnusedImport = "error"
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from __future__ import annotations
22

3-
from ._internal import Context # pyright: ignore[reportMissingImports]
4-
from ._internal import version as _version # pyright: ignore[reportMissingImports]
3+
from .api import Context, __version__ # pyright: ignore[reportMissingImports]
54

65
__all__ = ["Context", "__version__"]
7-
8-
__version__ = _version()

python/python/lance_context/api.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from __future__ import annotations
2+
3+
from io import BytesIO
4+
from typing import Any
5+
6+
from ._internal import Context as _Context # pyright: ignore[reportMissingImports]
7+
from ._internal import version as _version # pyright: ignore[reportMissingImports]
8+
9+
__all__ = ["Context", "__version__"]
10+
11+
__version__ = _version()
12+
13+
_ARROW_STREAM_MIME = "application/vnd.apache.arrow.stream"
14+
15+
16+
def _is_module(value: Any, prefix: str) -> bool:
17+
return type(value).__module__.startswith(prefix)
18+
19+
20+
def _get_pyarrow():
21+
try:
22+
import pyarrow as pa # pyright: ignore[reportMissingImports]
23+
except ImportError as exc: # pragma: no cover - optional dependency
24+
raise ImportError(
25+
"pyarrow is required to serialize pandas/polars dataframes"
26+
) from exc
27+
return pa
28+
29+
30+
def _coerce_arrow_table(value: Any):
    """Convert a table-like object into a ``pyarrow.Table``.

    pyarrow is an optional dependency, so it is imported only when ``value``
    could actually be converted. Values that are not table-like (strings,
    bytes, images, ...) return ``None`` without requiring pyarrow.

    Args:
        value: A pyarrow ``Table``/``RecordBatch``, an object from the pandas
            or polars packages, or any object exposing ``to_arrow()``.

    Returns:
        A ``pyarrow.Table``, or ``None`` when ``value`` is not table-like.

    Raises:
        ImportError: if ``value`` is table-like but pyarrow is not installed.
        TypeError: if ``to_arrow()`` returns something other than a pyarrow
            ``Table`` or ``RecordBatch``.
    """
    import sys

    # An instance of a pyarrow class can only exist once pyarrow has been
    # imported, so the isinstance checks are skipped (and no import is
    # attempted) when the module is not loaded.
    if sys.modules.get("pyarrow") is not None:
        pa = _get_pyarrow()
        if isinstance(value, pa.Table):
            return value
        if isinstance(value, pa.RecordBatch):
            return pa.Table.from_batches([value])

    from_pandas = _is_module(value, "pandas.")
    convertible = (
        from_pandas or _is_module(value, "polars.") or hasattr(value, "to_arrow")
    )
    if not convertible:
        return None

    pa = _get_pyarrow()
    # pandas objects go through pyarrow's converter; everything else is
    # expected to provide its own ``to_arrow()``.
    table = pa.Table.from_pandas(value) if from_pandas else value.to_arrow()

    if isinstance(table, pa.RecordBatch):
        return pa.Table.from_batches([table])
    if not isinstance(table, pa.Table):
        raise TypeError("to_arrow() did not return a pyarrow Table or RecordBatch")
    return table
50+
51+
52+
def _serialize_dataframe(value: Any):
    """Encode a table-like ``value`` as Arrow IPC stream bytes.

    Returns:
        ``(payload, mime)`` where ``payload`` is the serialized stream and
        ``mime`` is ``_ARROW_STREAM_MIME``; ``None`` when
        ``_coerce_arrow_table`` does not recognise ``value``.
    """
    arrow_table = _coerce_arrow_table(value)
    if arrow_table is not None:
        pa = _get_pyarrow()
        out = pa.BufferOutputStream()
        # The writer must be closed (end of the ``with``) before the buffer
        # is read back, so the stream is complete.
        with pa.ipc.new_stream(out, arrow_table.schema) as stream_writer:
            stream_writer.write_table(arrow_table)
        payload = out.getvalue().to_pybytes()
        return payload, _ARROW_STREAM_MIME
    return None
61+
62+
63+
def _serialize_image(value: Any):
    """Encode a Pillow image into bytes plus a MIME type.

    Args:
        value: Candidate object; anything that is not a ``PIL.Image.Image``
            is ignored.

    Returns:
        ``(payload, mime)`` for Pillow images, otherwise ``None``. The image
        is saved in its own ``format`` when set, else as PNG. ``mime`` falls
        back to ``"application/octet-stream"`` when Pillow reports none.

    Raises:
        ImportError: if ``value`` comes from a ``PIL.`` module but Pillow
            cannot be imported.
    """
    if not _is_module(value, "PIL."):
        return None
    try:
        from PIL import Image  # pyright: ignore[reportMissingImports]
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError("Pillow is required to serialize images") from exc
    if not isinstance(value, Image.Image):
        return None

    image_format = value.format or "PNG"

    # Encode before looking up the MIME type. Pillow loads its format plugins
    # lazily, so for an image created in memory ``Image.MIME`` may still be
    # empty until ``save()`` has run; looking it up first could yield the
    # generic fallback instead of e.g. "image/png".
    buffer = BytesIO()
    value.save(buffer, format=image_format)

    mime = None
    if hasattr(value, "get_format_mimetype"):
        mime = value.get_format_mimetype()
    if not mime:
        mime = Image.MIME.get(image_format.upper())
    if not mime:
        mime = "application/octet-stream"

    return buffer.getvalue(), mime
85+
86+
87+
def _normalize_content(value: Any, content_type: str | None):
    """Turn rich Python objects into a storable payload and content type.

    Serializers are tried in order: dataframe first, then image. The first
    one that recognises ``value`` supplies the payload. An explicit
    ``content_type`` wins over the type the serializer inferred.

    Returns:
        ``(payload, content_type)``; ``value`` and ``content_type`` are
        returned unchanged when no serializer applies.
    """
    for serializer in (_serialize_dataframe, _serialize_image):
        outcome = serializer(value)
        if outcome is not None:
            data, detected = outcome
            return data, content_type or detected
    return value, content_type
97+
98+
99+
class Context:
    """Python wrapper around the native ``_Context`` binding.

    Every method delegates to the wrapped native object held in
    ``self._inner``. ``add`` additionally runs the content through
    ``_normalize_content`` before handing it over.
    """

    def __init__(self, uri: str) -> None:
        """Create the native context for ``uri`` and wrap it."""
        self._inner = _Context.create(uri)

    @classmethod
    def create(cls, uri: str) -> Context:
        """Alternate constructor; equivalent to ``Context(uri)``."""
        return cls(uri)

    def uri(self) -> str:
        """Return the URI reported by the native context."""
        return self._inner.uri()

    def branch(self) -> str:
        """Return the branch name reported by the native context."""
        return self._inner.branch()

    def entries(self) -> int:
        """Return the entry count reported by the native context."""
        return self._inner.entries()

    def add(
        self,
        role: str,
        content: Any,
        content_type: str | None = None,
        data_type: str | None = None,
    ) -> None:
        """Append ``content`` under ``role``.

        ``content_type`` and ``data_type`` are two names for the same
        setting; at most one may be given.

        Raises:
            ValueError: if both ``content_type`` and ``data_type`` are set.
        """
        both_given = content_type is not None and data_type is not None
        if both_given:
            raise ValueError("Specify only one of content_type or data_type")
        requested = data_type if content_type is None else content_type
        payload, resolved_type = _normalize_content(content, requested)
        self._inner.add(role, payload, resolved_type)

    def snapshot(self, label: str | None = None) -> str:
        """Ask the native context for a snapshot and return its result."""
        return self._inner.snapshot(label)

    def fork(self, branch_name: str) -> Context:
        """Fork the native context and wrap the result in a new ``Context``."""
        forked = self._inner.fork(branch_name)
        return self._from_inner(forked)

    def checkout(self, snapshot_id: str) -> None:
        """Forward a checkout of ``snapshot_id`` to the native context."""
        self._inner.checkout(snapshot_id)

    def __repr__(self) -> str:
        inner = self._inner
        fields = (
            f"uri={inner.uri()!r}",
            f"branch={inner.branch()!r}",
            f"entries={inner.entries()}",
        )
        return "Context(" + ", ".join(fields) + ")"

    @classmethod
    def _from_inner(cls, inner: _Context) -> Context:
        """Wrap an existing native context without running ``__init__``."""
        instance = cls.__new__(cls)
        instance._inner = inner
        return instance

python/src/lib.rs

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1+
use std::sync::Arc;
2+
3+
use chrono::Utc;
4+
use pyo3::exceptions::PyRuntimeError;
15
use pyo3::prelude::*;
26
use pyo3::types::PyType;
7+
use tokio::runtime::Runtime;
8+
9+
use lance_context::serde::CONTENT_TYPE_TEXT;
10+
use lance_context::{Context as RustContext, ContextRecord, ContextStore};
311

4-
use lance_context::Context as RustContext;
12+
const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream";
13+
const BINARY_PLACEHOLDER: &str = "[binary]";
514

615
#[pyfunction]
716
fn version() -> &'static str {
@@ -11,15 +20,26 @@ fn version() -> &'static str {
1120
/// Python-visible context handle: pairs the in-memory `RustContext` with a
/// `ContextStore` and the Tokio runtime used to drive the store's async calls.
#[pyclass]
1221
struct Context {
1322
// In-memory context; built with `RustContext::new(uri)` in `create`.
inner: RustContext,
23+
// Store handle; opened with `ContextStore::open(uri)` in `create` and written to by `add`.
store: ContextStore,
24+
// Runtime used to `block_on` store futures; forks share it through the `Arc`.
runtime: Arc<Runtime>,
25+
// Value from `new_run_id()`; copied into each record's `run_id` and used as the record id prefix.
run_id: String,
1426
}
1527

1628
#[pymethods]
1729
impl Context {
1830
#[classmethod]
19-
fn create(_cls: &Bound<'_, PyType>, uri: &str) -> Self {
20-
Self {
31+
fn create(_cls: &Bound<'_, PyType>, uri: &str) -> PyResult<Self> {
32+
let runtime = Arc::new(Runtime::new().map_err(to_py_err)?);
33+
let store = runtime
34+
.block_on(ContextStore::open(uri))
35+
.map_err(to_py_err)?;
36+
let run_id = new_run_id();
37+
Ok(Self {
2138
inner: RustContext::new(uri),
22-
}
39+
store,
40+
runtime,
41+
run_id,
42+
})
2343
}
2444

2545
fn uri(&self) -> &str {
@@ -41,8 +61,44 @@ impl Context {
4161
content: &Bound<'_, PyAny>,
4262
data_type: Option<&str>,
4363
) -> PyResult<()> {
44-
let content_str = content.str()?.to_string();
45-
self.inner.add(role, &content_str, data_type);
64+
let (content_type, text_payload, binary_payload, inner_content) =
65+
match content.extract::<&[u8]>() {
66+
Ok(bytes) => (
67+
data_type
68+
.unwrap_or(DEFAULT_BINARY_CONTENT_TYPE)
69+
.to_string(),
70+
None,
71+
Some(bytes.to_vec()),
72+
BINARY_PLACEHOLDER.to_string(),
73+
),
74+
Err(_) => {
75+
let content_str = content.str()?.to_string();
76+
(
77+
data_type.unwrap_or(CONTENT_TYPE_TEXT).to_string(),
78+
Some(content_str.clone()),
79+
None,
80+
content_str,
81+
)
82+
}
83+
};
84+
85+
let record_id = format!("{}-{}", self.run_id, self.inner.entries() + 1);
86+
let record = ContextRecord {
87+
id: record_id,
88+
run_id: self.run_id.clone(),
89+
created_at: Utc::now(),
90+
role: role.to_string(),
91+
state_metadata: None,
92+
content_type,
93+
text_payload,
94+
binary_payload,
95+
embedding: None,
96+
};
97+
98+
self.runtime
99+
.block_on(self.store.add(std::slice::from_ref(&record)))
100+
.map_err(to_py_err)?;
101+
self.inner.add(role, &inner_content, data_type);
46102
Ok(())
47103
}
48104

@@ -54,6 +110,9 @@ impl Context {
54110
fn fork(&self, branch_name: &str) -> Self {
55111
Self {
56112
inner: self.inner.fork(branch_name),
113+
store: self.store.clone(),
114+
runtime: Arc::clone(&self.runtime),
115+
run_id: new_run_id(),
57116
}
58117
}
59118

@@ -62,6 +121,14 @@ impl Context {
62121
}
63122
}
64123

124+
fn new_run_id() -> String {
125+
format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id())
126+
}
127+
128+
fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr {
129+
PyRuntimeError::new_err(err.to_string())
130+
}
131+
65132
#[pymodule]
66133
fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
67134
m.add_function(wrap_pyfunction!(version, m)?)?;

rust/lance-context/src/store.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ impl ContextStore {
203203
}
204204
embedding_builder.append(true);
205205
} else {
206+
// FixedSizeListBuilder requires padding values for null slots.
207+
let values_builder = embedding_builder.values();
208+
for _ in 0..DEFAULT_EMBEDDING_DIM {
209+
values_builder.append_null();
210+
}
206211
embedding_builder.append(false);
207212
}
208213
}

0 commit comments

Comments
 (0)