diff --git a/.gitignore b/.gitignore index b7faf40..90cdcf1 100644 --- a/.gitignore +++ b/.gitignore @@ -198,6 +198,11 @@ cython_debug/ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files + +# BSP / IDE workspace files +.bazelbsp/ +.bsp/ +.ijwb/ .cursorignore .cursorindexingignore diff --git a/python/Cargo.lock b/python/Cargo.lock index 394d431..9145dd0 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -3394,6 +3394,7 @@ name = "lance-context" version = "0.1.0" dependencies = [ "arrow-array", + "arrow-ipc", "arrow-schema", "chrono", "lance", @@ -3407,8 +3408,10 @@ dependencies = [ name = "lance-context-python" version = "0.1.0" dependencies = [ + "chrono", "lance-context", "pyo3", + "tokio", ] [[package]] diff --git a/python/Cargo.toml b/python/Cargo.toml index 4fac45b..a785741 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -10,6 +10,7 @@ name = "_internal" crate-type = ["cdylib"] [dependencies] +chrono = { version = "0.4", default-features = false, features = ["clock"] } lance-context = { path = "../rust/lance-context" } pyo3 = { version = "0.25", features = ["extension-module", "abi3-py39", "py-clone"] } - +tokio = { version = "1", features = ["rt-multi-thread"] } diff --git a/python/pyproject.toml b/python/pyproject.toml index ec3ffd0..427e8f6 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -49,7 +49,7 @@ lint.select = ["F", "E", "W", "I", "G", "TCH", "PERF", "B019"] [tool.pyright] pythonVersion = "3.13" -include = ["python/lance_context/__init__.py"] +include = ["python/lance_context/__init__.py", "python/lance_context/api.py"] reportMissingTypeStubs = "warning" reportImportCycles = "error" reportUnusedImport = "error" diff --git a/python/python/lance_context/__init__.py b/python/python/lance_context/__init__.py index 27c69c5..3fafd00 100644 
--- a/python/python/lance_context/__init__.py +++ b/python/python/lance_context/__init__.py @@ -1,8 +1,5 @@ from __future__ import annotations -from ._internal import Context # pyright: ignore[reportMissingImports] -from ._internal import version as _version # pyright: ignore[reportMissingImports] +from .api import Context, __version__ # pyright: ignore[reportMissingImports] __all__ = ["Context", "__version__"] - -__version__ = _version() diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py new file mode 100644 index 0000000..8f3a728 --- /dev/null +++ b/python/python/lance_context/api.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from io import BytesIO +from typing import Any + +from ._internal import Context as _Context # pyright: ignore[reportMissingImports] +from ._internal import version as _version # pyright: ignore[reportMissingImports] + +__all__ = ["Context", "__version__"] + +__version__ = _version() + +_ARROW_STREAM_MIME = "application/vnd.apache.arrow.stream" + + +def _is_module(value: Any, prefix: str) -> bool: + return type(value).__module__.startswith(prefix) + + +def _get_pyarrow(): + try: + import pyarrow as pa # pyright: ignore[reportMissingImports] + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "pyarrow is required to serialize pandas/polars dataframes" + ) from exc + return pa + + +def _coerce_arrow_table(value: Any): + pa = _get_pyarrow() + if isinstance(value, pa.Table): + return value + if isinstance(value, pa.RecordBatch): + return pa.Table.from_batches([value]) + if _is_module(value, "polars."): + table = value.to_arrow() + elif _is_module(value, "pandas."): + table = pa.Table.from_pandas(value) + elif hasattr(value, "to_arrow"): + table = value.to_arrow() + else: + return None + + if isinstance(table, pa.RecordBatch): + return pa.Table.from_batches([table]) + if not isinstance(table, pa.Table): + raise TypeError("to_arrow() did not return a pyarrow 
Table or RecordBatch") + return table + + +def _serialize_dataframe(value: Any): + table = _coerce_arrow_table(value) + if table is None: + return None + pa = _get_pyarrow() + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, table.schema) as writer: + writer.write_table(table) + return sink.getvalue().to_pybytes(), _ARROW_STREAM_MIME + + +def _serialize_image(value: Any): + if not _is_module(value, "PIL."): + return None + try: + from PIL import Image # pyright: ignore[reportMissingImports] + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError("Pillow is required to serialize images") from exc + if not isinstance(value, Image.Image): + return None + + image_format = value.format or "PNG" + mime = None + if hasattr(value, "get_format_mimetype"): + mime = value.get_format_mimetype() + if not mime: + mime = Image.MIME.get(image_format.upper()) + if not mime: + mime = "application/octet-stream" + + buffer = BytesIO() + value.save(buffer, format=image_format) + return buffer.getvalue(), mime + + +def _normalize_content(value: Any, content_type: str | None): + serialized = _serialize_dataframe(value) + if serialized is not None: + payload, inferred = serialized + return payload, content_type or inferred + serialized = _serialize_image(value) + if serialized is not None: + payload, inferred = serialized + return payload, content_type or inferred + return value, content_type + + +class Context: + def __init__(self, uri: str) -> None: + self._inner = _Context.create(uri) + + @classmethod + def create(cls, uri: str) -> Context: + return cls(uri) + + def uri(self) -> str: + return self._inner.uri() + + def branch(self) -> str: + return self._inner.branch() + + def entries(self) -> int: + return self._inner.entries() + + def add( + self, + role: str, + content: Any, + content_type: str | None = None, + data_type: str | None = None, + ) -> None: + if content_type is not None and data_type is not None: + raise ValueError("Specify 
only one of content_type or data_type") + if content_type is None: + content_type = data_type + payload, resolved_type = _normalize_content(content, content_type) + self._inner.add(role, payload, resolved_type) + + def snapshot(self, label: str | None = None) -> str: + return self._inner.snapshot(label) + + def fork(self, branch_name: str) -> Context: + inner = self._inner.fork(branch_name) + return self._from_inner(inner) + + def checkout(self, snapshot_id: str) -> None: + self._inner.checkout(snapshot_id) + + def __repr__(self) -> str: + return ( + f"Context(uri={self._inner.uri()!r}, " + f"branch={self._inner.branch()!r}, " + f"entries={self._inner.entries()})" + ) + + @classmethod + def _from_inner(cls, inner: _Context) -> Context: + obj = cls.__new__(cls) + obj._inner = inner + return obj diff --git a/python/src/lib.rs b/python/src/lib.rs index fb002c7..5738551 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,7 +1,16 @@ +use std::sync::Arc; + +use chrono::Utc; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::PyType; +use tokio::runtime::Runtime; + +use lance_context::serde::CONTENT_TYPE_TEXT; +use lance_context::{Context as RustContext, ContextRecord, ContextStore}; -use lance_context::Context as RustContext; +const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; +const BINARY_PLACEHOLDER: &str = "[binary]"; #[pyfunction] fn version() -> &'static str { @@ -11,15 +20,26 @@ fn version() -> &'static str { #[pyclass] struct Context { inner: RustContext, + store: ContextStore, + runtime: Arc<Runtime>, + run_id: String, } #[pymethods] impl Context { #[classmethod] - fn create(_cls: &Bound<'_, PyType>, uri: &str) -> Self { - Self { + fn create(_cls: &Bound<'_, PyType>, uri: &str) -> PyResult<Self> { + let runtime = Arc::new(Runtime::new().map_err(to_py_err)?); + let store = runtime + .block_on(ContextStore::open(uri)) + .map_err(to_py_err)?; + let run_id = new_run_id(); + Ok(Self { + inner: RustContext::new(uri), - } + store, + 
runtime, + run_id, + }) } fn uri(&self) -> &str { @@ -41,8 +61,44 @@ impl Context { content: &Bound<'_, PyAny>, data_type: Option<&str>, ) -> PyResult<()> { - let content_str = content.str()?.to_string(); - self.inner.add(role, &content_str, data_type); + let (content_type, text_payload, binary_payload, inner_content) = + match content.extract::<&[u8]>() { + Ok(bytes) => ( + data_type + .unwrap_or(DEFAULT_BINARY_CONTENT_TYPE) + .to_string(), + None, + Some(bytes.to_vec()), + BINARY_PLACEHOLDER.to_string(), + ), + Err(_) => { + let content_str = content.str()?.to_string(); + ( + data_type.unwrap_or(CONTENT_TYPE_TEXT).to_string(), + Some(content_str.clone()), + None, + content_str, + ) + } + }; + + let record_id = format!("{}-{}", self.run_id, self.inner.entries() + 1); + let record = ContextRecord { + id: record_id, + run_id: self.run_id.clone(), + created_at: Utc::now(), + role: role.to_string(), + state_metadata: None, + content_type, + text_payload, + binary_payload, + embedding: None, + }; + + self.runtime + .block_on(self.store.add(std::slice::from_ref(&record))) + .map_err(to_py_err)?; + self.inner.add(role, &inner_content, data_type); Ok(()) } @@ -54,6 +110,9 @@ impl Context { fn fork(&self, branch_name: &str) -> Self { Self { inner: self.inner.fork(branch_name), + store: self.store.clone(), + runtime: Arc::clone(&self.runtime), + run_id: new_run_id(), } } @@ -62,6 +121,14 @@ impl Context { } } +fn new_run_id() -> String { + format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id()) +} + +fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr { + PyRuntimeError::new_err(err.to_string()) +} + #[pymodule] fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(version, m)?)?; diff --git a/rust/lance-context/Cargo.lock b/rust/lance-context/Cargo.lock index 2914054..8deac67 100644 --- a/rust/lance-context/Cargo.lock +++ b/rust/lance-context/Cargo.lock @@ -3385,6 +3385,7 @@ name = "lance-context" version = "0.1.0" dependencies = [ 
"arrow-array", + "arrow-ipc", "arrow-schema", "chrono", "lance", diff --git a/rust/lance-context/Cargo.toml b/rust/lance-context/Cargo.toml index 91ed5d1..016d020 100644 --- a/rust/lance-context/Cargo.toml +++ b/rust/lance-context/Cargo.toml @@ -12,6 +12,7 @@ categories = ["database", "data-structures", "science"] [dependencies] arrow-array = "56.2.0" +arrow-ipc = "56.2.0" arrow-schema = "56.2.0" chrono = { version = "0.4", default-features = false, features = ["clock"] } lance = "1.0.0" diff --git a/rust/lance-context/src/lib.rs b/rust/lance-context/src/lib.rs index c6e3165..5e8460a 100644 --- a/rust/lance-context/src/lib.rs +++ b/rust/lance-context/src/lib.rs @@ -2,6 +2,7 @@ mod context; mod record; +pub mod serde; mod store; pub use context::{Context, ContextEntry, Snapshot}; diff --git a/rust/lance-context/src/serde.rs b/rust/lance-context/src/serde.rs new file mode 100644 index 0000000..d0c97b1 --- /dev/null +++ b/rust/lance-context/src/serde.rs @@ -0,0 +1,141 @@ +use arrow_array::RecordBatch; +use arrow_ipc::writer::StreamWriter; +use arrow_schema::ArrowError; +use serde::{Deserialize, Serialize}; + +pub const CONTENT_TYPE_TEXT: &str = "text/plain"; +pub const CONTENT_TYPE_ARROW_STREAM: &str = "application/vnd.apache.arrow.stream"; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SerializedContent { + pub content_type: String, + pub text_payload: Option<String>, + pub binary_payload: Option<Vec<u8>>, +} + +impl SerializedContent { + pub fn text(value: impl Into<String>, content_type: Option<&str>) -> Self { + Self { + content_type: content_type.unwrap_or(CONTENT_TYPE_TEXT).to_string(), + text_payload: Some(value.into()), + binary_payload: None, + } + } + + pub fn image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> Self { + Self { + content_type: mime.into(), + text_payload: None, + binary_payload: Some(bytes.into()), + } + } + + pub fn dataframe_batches(batches: &[RecordBatch]) -> Result<Self, ArrowError> { + let ipc_bytes = record_batches_to_ipc(batches)?; + 
Ok(Self::dataframe_ipc_bytes(ipc_bytes)) + } + + pub fn dataframe_ipc_bytes(bytes: impl Into<Vec<u8>>) -> Self { + Self { + content_type: CONTENT_TYPE_ARROW_STREAM.to_string(), + text_payload: None, + binary_payload: Some(bytes.into()), + } + } +} + +pub fn serialize_image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> SerializedContent { + SerializedContent::image(bytes, mime) +} + +pub fn serialize_dataframe(batches: &[RecordBatch]) -> Result<SerializedContent, ArrowError> { + SerializedContent::dataframe_batches(batches) +} + +pub fn serialize_dataframe_ipc(bytes: impl Into<Vec<u8>>) -> SerializedContent { + SerializedContent::dataframe_ipc_bytes(bytes) +} + +fn record_batches_to_ipc(batches: &[RecordBatch]) -> Result<Vec<u8>, ArrowError> { + if batches.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "no record batches provided".to_string(), + )); + } + + let schema = batches[0].schema(); + let mut buffer = Vec::new(); + { + let mut writer = StreamWriter::try_new(&mut buffer, &schema)?; + for batch in batches { + if batch.schema() != schema { + return Err(ArrowError::SchemaError( + "record batch schema mismatch".to_string(), + )); + } + writer.write(batch)?; + } + writer.finish()?; + } + Ok(buffer) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Int32Array, RecordBatch, StringArray}; + use arrow_ipc::reader::StreamReader; + use arrow_schema::{DataType, Field, Schema}; + use std::io::Cursor; + use std::sync::Arc; + + fn make_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let id_array = Arc::new(Int32Array::from(vec![1, 2])); + let name_array = Arc::new(StringArray::from(vec!["alpha", "beta"])); + RecordBatch::try_new(schema, vec![id_array, name_array]).unwrap() + } + + #[test] + fn image_serialization_sets_payloads() { + let content = serialize_image(vec![1, 2, 3], "image/png"); + assert_eq!(content.content_type, "image/png"); + assert_eq!(content.text_payload, None); + 
assert_eq!(content.binary_payload, Some(vec![1, 2, 3])); + } + + #[test] + fn dataframe_serialization_writes_ipc_stream() { + let batch = make_batch(); + let content = serialize_dataframe(std::slice::from_ref(&batch)).unwrap(); + assert_eq!(content.content_type, CONTENT_TYPE_ARROW_STREAM); + let bytes = content.binary_payload.expect("expected IPC payload"); + + let reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap(); + let batches: Vec<RecordBatch> = reader.map(|item| item.unwrap()).collect(); + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].schema(), batch.schema()); + assert_eq!(batches[0].num_rows(), batch.num_rows()); + } + + #[test] + fn dataframe_serialization_rejects_empty_batches() { + let err = serialize_dataframe(&[]).unwrap_err(); + assert!(matches!(err, ArrowError::InvalidArgumentError(_))); + } + + #[test] + fn dataframe_serialization_rejects_mismatched_schema() { + let batch = make_batch(); + let other_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let other_batch = + RecordBatch::try_new(other_schema, vec![Arc::new(Int32Array::from(vec![1, 2]))]) + .unwrap(); + + let err = serialize_dataframe(&[batch, other_batch]).unwrap_err(); + assert!(matches!(err, ArrowError::SchemaError(_))); + } +} diff --git a/rust/lance-context/src/store.rs b/rust/lance-context/src/store.rs index 20f9ea4..049cfd6 100644 --- a/rust/lance-context/src/store.rs +++ b/rust/lance-context/src/store.rs @@ -203,6 +203,11 @@ impl ContextStore { } embedding_builder.append(true); } else { + // FixedSizeListBuilder requires padding values for null slots. + let values_builder = embedding_builder.values(); + for _ in 0..DEFAULT_EMBEDDING_DIM { + values_builder.append_null(); + } embedding_builder.append(false); } }