Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ cython_debug/
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files

# BSP / IDE workspace files
.bazelbsp/
.bsp/
.ijwb/
.cursorignore
.cursorindexingignore

Expand Down
3 changes: 3 additions & 0 deletions python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ name = "_internal"
crate-type = ["cdylib"]

[dependencies]
chrono = { version = "0.4", default-features = false, features = ["clock"] }
lance-context = { path = "../rust/lance-context" }
pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39", "py-clone"] }

tokio = { version = "1", features = ["rt-multi-thread"] }
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ lint.select = ["F", "E", "W", "I", "G", "TCH", "PERF", "B019"]

[tool.pyright]
pythonVersion = "3.13"
include = ["python/lance_context/__init__.py"]
include = ["python/lance_context/__init__.py", "python/lance_context/api.py"]
reportMissingTypeStubs = "warning"
reportImportCycles = "error"
reportUnusedImport = "error"
Expand Down
5 changes: 1 addition & 4 deletions python/python/lance_context/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations

from ._internal import Context # pyright: ignore[reportMissingImports]
from ._internal import version as _version # pyright: ignore[reportMissingImports]
from .api import Context, __version__ # pyright: ignore[reportMissingImports]

__all__ = ["Context", "__version__"]

__version__ = _version()
151 changes: 151 additions & 0 deletions python/python/lance_context/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

from io import BytesIO
from typing import Any

from ._internal import Context as _Context # pyright: ignore[reportMissingImports]
from ._internal import version as _version # pyright: ignore[reportMissingImports]

__all__ = ["Context", "__version__"]

__version__ = _version()

_ARROW_STREAM_MIME = "application/vnd.apache.arrow.stream"


def _is_module(value: Any, prefix: str) -> bool:
return type(value).__module__.startswith(prefix)


def _get_pyarrow():
try:
import pyarrow as pa # pyright: ignore[reportMissingImports]
except ImportError as exc: # pragma: no cover - optional dependency
raise ImportError(
"pyarrow is required to serialize pandas/polars dataframes"
) from exc
return pa


def _coerce_arrow_table(value: Any):
pa = _get_pyarrow()
if isinstance(value, pa.Table):
return value
if isinstance(value, pa.RecordBatch):
return pa.Table.from_batches([value])
if _is_module(value, "polars."):
table = value.to_arrow()
elif _is_module(value, "pandas."):
table = pa.Table.from_pandas(value)
elif hasattr(value, "to_arrow"):
table = value.to_arrow()
else:
return None

if isinstance(table, pa.RecordBatch):
return pa.Table.from_batches([table])
if not isinstance(table, pa.Table):
raise TypeError("to_arrow() did not return a pyarrow Table or RecordBatch")
return table


def _serialize_dataframe(value: Any):
table = _coerce_arrow_table(value)
if table is None:
return None
pa = _get_pyarrow()
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
return sink.getvalue().to_pybytes(), _ARROW_STREAM_MIME


def _serialize_image(value: Any):
if not _is_module(value, "PIL."):
return None
try:
from PIL import Image # pyright: ignore[reportMissingImports]
except ImportError as exc: # pragma: no cover - optional dependency
raise ImportError("Pillow is required to serialize images") from exc
if not isinstance(value, Image.Image):
return None

image_format = value.format or "PNG"
mime = None
if hasattr(value, "get_format_mimetype"):
mime = value.get_format_mimetype()
if not mime:
mime = Image.MIME.get(image_format.upper())
if not mime:
mime = "application/octet-stream"

buffer = BytesIO()
value.save(buffer, format=image_format)
return buffer.getvalue(), mime


def _normalize_content(value: Any, content_type: str | None):
serialized = _serialize_dataframe(value)
if serialized is not None:
payload, inferred = serialized
return payload, content_type or inferred
serialized = _serialize_image(value)
if serialized is not None:
payload, inferred = serialized
return payload, content_type or inferred
return value, content_type


class Context:
def __init__(self, uri: str) -> None:
self._inner = _Context.create(uri)

@classmethod
def create(cls, uri: str) -> Context:
return cls(uri)

def uri(self) -> str:
return self._inner.uri()

def branch(self) -> str:
return self._inner.branch()

def entries(self) -> int:
return self._inner.entries()

def add(
self,
role: str,
content: Any,
content_type: str | None = None,
data_type: str | None = None,
) -> None:
if content_type is not None and data_type is not None:
raise ValueError("Specify only one of content_type or data_type")
if content_type is None:
content_type = data_type
payload, resolved_type = _normalize_content(content, content_type)
self._inner.add(role, payload, resolved_type)

def snapshot(self, label: str | None = None) -> str:
return self._inner.snapshot(label)

def fork(self, branch_name: str) -> Context:
inner = self._inner.fork(branch_name)
return self._from_inner(inner)

def checkout(self, snapshot_id: str) -> None:
self._inner.checkout(snapshot_id)

def __repr__(self) -> str:
return (
f"Context(uri={self._inner.uri()!r}, "
f"branch={self._inner.branch()!r}, "
f"entries={self._inner.entries()})"
)

@classmethod
def _from_inner(cls, inner: _Context) -> Context:
obj = cls.__new__(cls)
obj._inner = inner
return obj
79 changes: 73 additions & 6 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
use std::sync::Arc;

use chrono::Utc;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyType;
use tokio::runtime::Runtime;

use lance_context::serde::CONTENT_TYPE_TEXT;
use lance_context::{Context as RustContext, ContextRecord, ContextStore};

use lance_context::Context as RustContext;
const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream";
const BINARY_PLACEHOLDER: &str = "[binary]";

#[pyfunction]
fn version() -> &'static str {
Expand All @@ -11,15 +20,26 @@ fn version() -> &'static str {
#[pyclass]
struct Context {
inner: RustContext,
store: ContextStore,
runtime: Arc<Runtime>,
run_id: String,
}

#[pymethods]
impl Context {
#[classmethod]
fn create(_cls: &Bound<'_, PyType>, uri: &str) -> Self {
Self {
fn create(_cls: &Bound<'_, PyType>, uri: &str) -> PyResult<Self> {
let runtime = Arc::new(Runtime::new().map_err(to_py_err)?);
let store = runtime
.block_on(ContextStore::open(uri))
.map_err(to_py_err)?;
let run_id = new_run_id();
Ok(Self {
inner: RustContext::new(uri),
}
store,
runtime,
run_id,
})
}

fn uri(&self) -> &str {
Expand All @@ -41,8 +61,44 @@ impl Context {
content: &Bound<'_, PyAny>,
data_type: Option<&str>,
) -> PyResult<()> {
let content_str = content.str()?.to_string();
self.inner.add(role, &content_str, data_type);
let (content_type, text_payload, binary_payload, inner_content) =
match content.extract::<&[u8]>() {
Ok(bytes) => (
data_type
.unwrap_or(DEFAULT_BINARY_CONTENT_TYPE)
.to_string(),
None,
Some(bytes.to_vec()),
BINARY_PLACEHOLDER.to_string(),
),
Err(_) => {
let content_str = content.str()?.to_string();
(
data_type.unwrap_or(CONTENT_TYPE_TEXT).to_string(),
Some(content_str.clone()),
None,
content_str,
)
}
};

let record_id = format!("{}-{}", self.run_id, self.inner.entries() + 1);
let record = ContextRecord {
id: record_id,
run_id: self.run_id.clone(),
created_at: Utc::now(),
role: role.to_string(),
state_metadata: None,
content_type,
text_payload,
binary_payload,
embedding: None,
};

self.runtime
.block_on(self.store.add(std::slice::from_ref(&record)))
.map_err(to_py_err)?;
self.inner.add(role, &inner_content, data_type);
Ok(())
}

Expand All @@ -54,6 +110,9 @@ impl Context {
fn fork(&self, branch_name: &str) -> Self {
Self {
inner: self.inner.fork(branch_name),
store: self.store.clone(),
runtime: Arc::clone(&self.runtime),
run_id: new_run_id(),
}
}

Expand All @@ -62,6 +121,14 @@ impl Context {
}
}

fn new_run_id() -> String {
format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id())
}

fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr {
PyRuntimeError::new_err(err.to_string())
}

#[pymodule]
fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(version, m)?)?;
Expand Down
1 change: 1 addition & 0 deletions rust/lance-context/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/lance-context/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ categories = ["database", "data-structures", "science"]

[dependencies]
arrow-array = "56.2.0"
arrow-ipc = "56.2.0"
arrow-schema = "56.2.0"
chrono = { version = "0.4", default-features = false, features = ["clock"] }
lance = "1.0.0"
Expand Down
1 change: 1 addition & 0 deletions rust/lance-context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

mod context;
mod record;
pub mod serde;
mod store;

pub use context::{Context, ContextEntry, Snapshot};
Expand Down
Loading
Loading