From 2968bd43cb331f060cffb5c1721870fab57a920b Mon Sep 17 00:00:00 2001 From: beinan Date: Tue, 20 Jan 2026 02:26:49 +0000 Subject: [PATCH] feat: support remote storage options --- README.md | 12 ++ crates/lance-context-core/src/lib.rs | 2 +- crates/lance-context-core/src/store.rs | 78 ++++++++++--- python/pyproject.toml | 2 +- python/python/lance_context/api.py | 56 ++++++++- python/src/lib.rs | 55 ++++++++- python/tests/test_persistence.py | 150 ++++++++++++++++++++++++- python/tests/test_search.py | 1 - 8 files changed, 327 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 7de4f2e..5f8d96a 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Key motivations inspired by the broader Lance roadmap[1](https://github.com - Unified schema for agent messages (`ContextRecord`) with optional embeddings and metadata. - Automatic versioning via Lance manifests with `checkout(version)` support. +- Remote persistence: point the store at `s3://` URIs with either AWS environment variables or explicit credentials/endpoint overrides. - Python API (`lance_context.api.Context`) aligned with the Rust implementation. - Integration tests that exercise real persistence, image serialization, and version rollbacks. @@ -65,6 +66,17 @@ ctx.add("assistant", "Let me fetch suggestions…") ctx.checkout(first_version) print("Entries after checkout:", ctx.entries()) + +# Store context in S3 (e.g., for MinIO/moto test endpoints) +ctx = Context.create( + "s3://my-bucket/context.lance", + aws_access_key_id="minioadmin", + aws_secret_access_key="minioadmin", + region="us-east-1", + endpoint_url="http://localhost:9000", + allow_http=True, +) +# AWS_* environment variables work too—pass overrides only when you need custom endpoints. ``` ### Rust diff --git a/crates/lance-context-core/src/lib.rs b/crates/lance-context-core/src/lib.rs index cf6f90f..799db6a 100644 --- a/crates/lance-context-core/src/lib.rs +++ b/crates/lance-context-core/src/lib.rs @@ -7,4 +7,4 @@ mod store; pub use context::{Context, ContextEntry, Snapshot}; pub use record::{ContextRecord, SearchResult, StateMetadata}; -pub use store::ContextStore; +pub use store::{ContextStore, ContextStoreOptions}; diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index f750442..381ba8e 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use arrow_array::builder::{ @@ -13,7 +14,8 @@ use arrow_array::{ use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit}; use chrono::DateTime; use futures::TryStreamExt; -use lance::dataset::{Dataset, WriteMode, WriteParams}; +use lance::dataset::{builder::DatasetBuilder, Dataset, WriteMode, WriteParams}; +use lance::io::ObjectStoreParams; use lance::{Error as LanceError, Result as LanceResult}; use crate::record::{ContextRecord, SearchResult, StateMetadata}; @@ -28,23 +30,32 @@ pub struct ContextStore { dataset: Dataset, } +/// Additional configuration when opening a [`ContextStore`]. +#[derive(Debug, Clone, Default)] +pub struct ContextStoreOptions { + pub storage_options: Option>, +} + +impl ContextStoreOptions { + #[must_use] + pub fn storage_options(&self) -> Option> { + self.storage_options.clone() + } +} + impl ContextStore { /// Open an existing context dataset or create a new one with the project schema. pub async fn open(uri: &str) -> LanceResult { - match Dataset::open(uri).await { + Self::open_with_options(uri, ContextStoreOptions::default()).await + } + + /// Open a dataset with explicit object store configuration (e.g. S3 credentials). + pub async fn open_with_options(uri: &str, options: ContextStoreOptions) -> LanceResult { + let storage_options = options.storage_options(); + match Self::load_with_options(uri, storage_options.clone()).await { Ok(dataset) => Ok(Self { dataset }), Err(LanceError::DatasetNotFound { .. }) => { - let schema = Arc::new(Self::schema()); - let empty_batch = RecordBatch::new_empty(schema.clone()); - let batches = RecordBatchIterator::new( - vec![Ok::(empty_batch)].into_iter(), - schema.clone(), - ); - let params = WriteParams { - mode: WriteMode::Create, - ..Default::default() - }; - let dataset = Dataset::write(batches, uri, Some(params)).await?; + let dataset = Self::create_with_options(uri, storage_options).await?; Ok(Self { dataset }) } Err(err) => Err(err), @@ -156,6 +167,47 @@ impl ContextStore { ]) } + async fn load_with_options( + uri: &str, + storage_options: Option>, + ) -> LanceResult { + if let Some(options) = storage_options { + DatasetBuilder::from_uri(uri) + .with_storage_options(options) + .load() + .await + } else { + Dataset::open(uri).await + } + } + + async fn create_with_options( + uri: &str, + storage_options: Option>, + ) -> LanceResult { + let schema = Arc::new(Self::schema()); + let empty_batch = RecordBatch::new_empty(schema.clone()); + let batches = RecordBatchIterator::new( + vec![Ok::(empty_batch)].into_iter(), + schema.clone(), + ); + + let mut params = WriteParams { + mode: WriteMode::Create, + ..Default::default() + }; + + if let Some(options) = storage_options { + let store_params = ObjectStoreParams { + storage_options: Some(options), + ..Default::default() + }; + params.store_params = Some(store_params); + } + + Dataset::write(batches, uri, Some(params)).await + } + fn records_to_batch(entries: &[ContextRecord]) -> LanceResult { let mut id_builder = StringBuilder::new(); let mut run_id_builder = StringBuilder::new(); diff --git a/python/pyproject.toml b/python/pyproject.toml index 427e8f6..a275786 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -41,7 +41,7 @@ requires = ["maturin>=1.4"] build-backend = "maturin" [project.optional-dependencies] -tests = ["pytest", "ruff"] +tests = ["pytest", "ruff", "moto[s3]", "boto3", "botocore"] dev = ["ruff", "pyright"] [tool.ruff] diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index f563087..9f33d5e 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -126,12 +126,60 @@ def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]: class Context: - def __init__(self, uri: str) -> None: - self._inner = _Context.create(uri) + def __init__( + self, + uri: str, + *, + storage_options: dict[str, Any] | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + region: str | None = None, + endpoint_url: str | None = None, + allow_http: bool = False, + ) -> None: + options = dict(storage_options or {}) + if aws_access_key_id is not None: + options["aws_access_key_id"] = aws_access_key_id + if aws_secret_access_key is not None: + options["aws_secret_access_key"] = aws_secret_access_key + if aws_session_token is not None: + options["aws_session_token"] = aws_session_token + if region is not None: + options["aws_region"] = region + if endpoint_url is not None: + options["aws_endpoint_url"] = endpoint_url + if allow_http: + options["aws_allow_http"] = True + + if options: + self._inner = _Context.create(uri, storage_options=options) + else: + self._inner = _Context.create(uri) @classmethod - def create(cls, uri: str) -> Context: - return cls(uri) + def create( + cls, + uri: str, + *, + storage_options: dict[str, Any] | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + aws_session_token: str | None = None, + region: str | None = None, + endpoint_url: str | None = None, + allow_http: bool = False, + ) -> Context: + return cls( + uri, + storage_options=storage_options, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region=region, + endpoint_url=endpoint_url, + allow_http=allow_http, + ) def uri(self) -> str: return self._inner.uri() diff --git a/python/src/lib.rs b/python/src/lib.rs index d4e37b9..30d8e44 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use chrono::{SecondsFormat, Utc}; @@ -8,7 +9,9 @@ use pyo3::IntoPyObject; use tokio::runtime::Runtime; use lance_context::serde::CONTENT_TYPE_TEXT; -use lance_context::{Context as RustContext, ContextRecord, ContextStore, SearchResult}; +use lance_context::{ + Context as RustContext, ContextRecord, ContextStore, ContextStoreOptions, SearchResult, +}; const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; const BINARY_PLACEHOLDER: &str = "[binary]"; @@ -26,13 +29,59 @@ struct Context { run_id: String, } +fn storage_options_from_dict<'py>( + dict: Option<&Bound<'py, PyDict>>, +) -> PyResult>> { + let Some(dict) = dict else { + return Ok(None); + }; + + let mut options = HashMap::new(); + for (key, value) in dict.iter() { + let key_str = key.extract::()?; + if value.is_none() { + continue; + } + let string_value = if let Ok(boolean) = value.extract::() { + if boolean { + "true".to_string() + } else { + "false".to_string() + } + } else if let Ok(number) = value.extract::() { + number.to_string() + } else if let Ok(float_val) = value.extract::() { + float_val.to_string() + } else { + value.str()?.to_string() + }; + options.insert(key_str, string_value); + } + + if options.is_empty() { + Ok(None) + } else { + Ok(Some(options)) + } +} + #[pymethods] impl Context { #[classmethod] - fn create(_cls: &Bound<'_, PyType>, uri: &str) -> PyResult { + #[pyo3(signature = (uri, *, storage_options=None))] + fn create( + _cls: &Bound<'_, PyType>, + uri: &str, + storage_options: Option<&Bound<'_, PyDict>>, + ) -> PyResult { let runtime = Arc::new(Runtime::new().map_err(to_py_err)?); + + let options = ContextStoreOptions { + storage_options: storage_options_from_dict(storage_options)?, + }; + let store = runtime - .block_on(ContextStore::open(uri)) + .block_on(ContextStore::open_with_options(uri, options)) .map_err(to_py_err)?; let run_id = new_run_id(); Ok(Self { diff --git a/python/tests/test_persistence.py b/python/tests/test_persistence.py index adab7c9..108f0b8 100644 --- a/python/tests/test_persistence.py +++ b/python/tests/test_persistence.py @@ -1,6 +1,10 @@ from __future__ import annotations +import socket +import subprocess import sys +import time +import uuid from io import BytesIO from pathlib import Path from typing import Any @@ -11,13 +15,120 @@ if str(PACKAGE_ROOT) not in sys.path: sys.path.insert(0, str(PACKAGE_ROOT)) -lance = pytest.importorskip("lance") - -from lance_context.api import Context +from lance_context.api import Context # noqa: E402 +lance = pytest.importorskip("lance") -def _read_rows(uri: str, version: int | None = None) -> list[dict[str, object]]: - dataset = lance.dataset(uri, version=version) if version is not None else lance.dataset(uri) +_S3_ACCESS_KEY = "test" +_S3_SECRET_KEY = "test" +_S3_REGION = "us-east-1" + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return int(sock.getsockname()[1]) + + +def _s3_storage_options(endpoint: str) -> dict[str, str]: + return { + "aws_access_key_id": _S3_ACCESS_KEY, + "aws_secret_access_key": _S3_SECRET_KEY, + "aws_region": _S3_REGION, + "aws_endpoint_url": endpoint, + "aws_allow_http": "true", + } + + +def _wait_for_moto_ready(client: Any, timeout: float = 5.0) -> None: + deadline = time.time() + timeout + last_error: Exception | None = None + while time.time() < deadline: + try: + client.list_buckets() + return + except Exception as exc: # pragma: no cover - best effort + last_error = exc + time.sleep(0.1) + raise RuntimeError("moto server did not become ready") from last_error + + +@pytest.fixture(scope="module") +def moto_endpoint() -> str: + pytest.importorskip("moto.server") + boto3 = pytest.importorskip("boto3") + from botocore.config import Config # type: ignore[import-not-found] + + port = _free_port() + cmd = [ + sys.executable, + "-m", + "moto.server", + "s3", + "-H", + "127.0.0.1", + "-p", + str(port), + ] + process = subprocess.Popen( + cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + endpoint = f"http://127.0.0.1:{port}" + + session = boto3.session.Session( + aws_access_key_id=_S3_ACCESS_KEY, + aws_secret_access_key=_S3_SECRET_KEY, + region_name=_S3_REGION, + ) + client = session.client( + "s3", + endpoint_url=endpoint, + config=Config(signature_version="s3v4"), + ) + + try: + _wait_for_moto_ready(client) + yield endpoint + finally: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=5) + + +@pytest.fixture +def s3_client(moto_endpoint: str): + boto3 = pytest.importorskip("boto3") + from botocore.config import Config # type: ignore[import-not-found] + + session = boto3.session.Session( + aws_access_key_id=_S3_ACCESS_KEY, + aws_secret_access_key=_S3_SECRET_KEY, + region_name=_S3_REGION, + ) + return session.client( + "s3", + endpoint_url=moto_endpoint, + config=Config(signature_version="s3v4"), + ) + + +def _read_rows( + uri: str, + version: int | None = None, + storage_options: dict[str, str] | None = None, +) -> list[dict[str, object]]: + kwargs: dict[str, Any] = {} + if version is not None: + kwargs["version"] = version + if storage_options is not None: + kwargs["storage_options"] = storage_options + dataset = lance.dataset(uri, **kwargs) table = dataset.to_table() return table.to_pylist() @@ -79,4 +190,31 @@ def test_time_travel_checkout(tmp_path: Path) -> None: assert rows_versioned[0]["text_payload"] == "first-entry" latest_rows = _read_rows(str(uri)) - assert [row["text_payload"] for row in latest_rows] == ["first-entry", "second-entry"] + assert [row["text_payload"] for row in latest_rows] == [ + "first-entry", + "second-entry", + ] + + +def test_s3_round_trip_remote_store(moto_endpoint: str, s3_client) -> None: + bucket = f"context-{uuid.uuid4().hex}" + s3_client.create_bucket(Bucket=bucket) + key = f"contexts/{uuid.uuid4().hex}/context.lance" + uri = f"s3://{bucket}/{key}" + + ctx = Context.create( + uri, + aws_access_key_id=_S3_ACCESS_KEY, + aws_secret_access_key=_S3_SECRET_KEY, + region=_S3_REGION, + endpoint_url=moto_endpoint, + allow_http=True, + ) + + ctx.add("user", "remote-hello") + ctx.add("assistant", "remote-response") + ctx.checkout(ctx.version()) + + rows = _read_rows(uri, storage_options=_s3_storage_options(moto_endpoint)) + assert [row["text_payload"] for row in rows] == ["remote-hello", "remote-response"] + assert ctx.entries() == 2 diff --git a/python/tests/test_search.py b/python/tests/test_search.py index 6181c5a..4e152a1 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -1,7 +1,6 @@ from datetime import datetime import pytest - from lance_context.api import Context, _coerce_vector, _normalize_search_hit