diff --git a/.codex/skills/ci-pr-helper/SKILL.md b/.codex/skills/ci-pr-helper/SKILL.md index e0f76f9..c87f8ac 100644 --- a/.codex/skills/ci-pr-helper/SKILL.md +++ b/.codex/skills/ci-pr-helper/SKILL.md @@ -57,4 +57,7 @@ When ready: gh pr create --title "" --body "<body>" ``` +- Some environments emit a spurious `Unsupported subcommand 'pr'` warning before running + `gh`; ignore that message and continue with the command. + - If `gh` is missing or fails, print the command instead so the user can run it locally. diff --git a/python/Cargo.lock b/python/Cargo.lock index 9145dd0..c310169 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -3397,6 +3397,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "chrono", + "futures", "lance", "lance-graph", "lance-namespace", diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 50b18ad..f563087 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime from io import BytesIO from typing import Any @@ -96,6 +97,34 @@ def _normalize_content(value: Any, content_type: str | None): return value, content_type +def _coerce_vector(query: Any) -> list[float]: + if hasattr(query, "tolist"): + query = query.tolist() + elif hasattr(query, "__array__"): + query = query.__array__().tolist() + if isinstance(query, (list, tuple)): + return [float(item) for item in query] + raise TypeError("search query must be a sequence of floats") + + +def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]: + created_at = raw.get("created_at") + if isinstance(created_at, str): + created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00")) + return { + "id": raw.get("id"), + "run_id": raw.get("run_id"), + "role": raw.get("role"), + "content_type": raw.get("content_type"), + "text": raw.get("text_payload"), + "binary": raw.get("binary_payload"), + "embedding": raw.get("embedding"), + "distance": raw.get("distance"), + "created_at": created_at, + "state_metadata": raw.get("state_metadata"), + } + + class Context: def __init__(self, uri: str) -> None: self._inner = _Context.create(uri) @@ -140,6 +169,11 @@ def fork(self, branch_name: str) -> Context: def checkout(self, version_id: int | str) -> None: self._inner.checkout(int(version_id)) + def search(self, query: Any, limit: int | None = None) -> list[dict[str, Any]]: + vector = _coerce_vector(query) + results = self._inner.search(vector, limit) + return [_normalize_search_hit(item) for item in results] + def __repr__(self) -> str: return ( f"Context(uri={self._inner.uri()!r}, " diff --git a/python/src/lib.rs b/python/src/lib.rs index f0f89cf..befdd70 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,13 +1,14 @@ use std::sync::Arc; -use chrono::Utc; +use chrono::{SecondsFormat, Utc}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -use pyo3::types::PyType; +use pyo3::types::{PyBytes, PyDict, PyType}; +use pyo3::IntoPyObject; use tokio::runtime::Runtime; use lance_context::serde::CONTENT_TYPE_TEXT; -use lance_context::{Context as RustContext, ContextRecord, ContextStore}; +use lance_context::{Context as RustContext, ContextRecord, ContextStore, SearchResult}; const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; const BINARY_PLACEHOLDER: &str = "[binary]"; @@ -68,9 +69,7 @@ impl Context { let (content_type, text_payload, binary_payload, inner_content) = match content.extract::<&[u8]>() { Ok(bytes) => ( - data_type - .unwrap_or(DEFAULT_BINARY_CONTENT_TYPE) - .to_string(), + data_type.unwrap_or(DEFAULT_BINARY_CONTENT_TYPE).to_string(), None, Some(bytes.to_vec()), BINARY_PLACEHOLDER.to_string(), @@ -127,10 +126,76 @@ impl Context { self.run_id = new_run_id(); Ok(()) } + + #[pyo3(signature = (query, limit = None))] + fn search( + &self, + py: Python<'_>, + query: Vec<f32>, + limit: Option<usize>, + ) -> PyResult<Vec<PyObject>> { + let hits = self + .runtime + .block_on(self.store.search(&query, limit)) + .map_err(to_py_err)?; + hits.into_iter() + .map(|hit| search_hit_to_py(py, hit)) + .collect() + } } fn new_run_id() -> String { - format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id()) + format!( + "run-{}-{}", + Utc::now().timestamp_micros(), + std::process::id() + ) +} + +fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult<PyObject> { + let SearchResult { record, distance } = hit; + let ContextRecord { + id, + run_id, + created_at, + role, + state_metadata, + content_type, + text_payload, + binary_payload, + embedding, + } = record; + + let dict = PyDict::new(py); + dict.set_item("id", id)?; + dict.set_item("run_id", run_id)?; + dict.set_item( + "created_at", + created_at.to_rfc3339_opts(SecondsFormat::Micros, true), + )?; + dict.set_item("role", role)?; + + let state_obj: PyObject = match state_metadata { + Some(metadata) => { + let state_dict = PyDict::new(py); + state_dict.set_item("step", metadata.step)?; + state_dict.set_item("active_plan_id", metadata.active_plan_id)?; + state_dict.set_item("tokens_used", metadata.tokens_used)?; + state_dict.set_item("custom", metadata.custom)?; + state_dict.into_pyobject(py)?.unbind().into() + } + None => py.None().into_pyobject(py)?.unbind().into(), + }; + dict.set_item("state_metadata", state_obj)?; + dict.set_item("content_type", content_type)?; + dict.set_item("text_payload", text_payload)?; + match binary_payload { + Some(payload) => dict.set_item("binary_payload", PyBytes::new(py, &payload))?, + None => dict.set_item("binary_payload", py.None())?, + } + dict.set_item("embedding", embedding)?; + dict.set_item("distance", distance)?; + Ok(dict.into_pyobject(py)?.unbind().into()) } fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr { diff --git a/python/tests/test_search.py b/python/tests/test_search.py new file mode 100644 index 0000000..6181c5a --- /dev/null +++ b/python/tests/test_search.py @@ -0,0 +1,68 @@ +from datetime import datetime + +import pytest + +from lance_context.api import Context, _coerce_vector, _normalize_search_hit + + +class DummyInner: + def __init__(self) -> None: + self.calls: list[tuple[list[float], int | None]] = [] + + def search(self, vector: list[float], limit: int | None): + self.calls.append((vector, limit)) + return [ + { + "id": "rec-1", + "run_id": "run-1", + "role": "user", + "content_type": "text/plain", + "text_payload": "hello", + "binary_payload": None, + "embedding": [0.1, 0.2], + "distance": 0.12, + "created_at": "2024-01-01T12:00:00Z", + "state_metadata": {"step": 1}, + } + ] + + +def test_coerce_vector_from_list(): + assert _coerce_vector([1, 2.5]) == [1.0, 2.5] + + +def test_coerce_vector_rejects_invalid(): + with pytest.raises(TypeError): + _coerce_vector("invalid") + + +def test_normalize_search_hit_converts_timestamp(): + result = _normalize_search_hit( + { + "id": "rec-2", + "created_at": "2024-01-01T00:00:00Z", + "content_type": "text/plain", + "text_payload": None, + "binary_payload": None, + "embedding": None, + "distance": 0.5, + "run_id": "run-2", + "role": "assistant", + "state_metadata": None, + } + ) + assert isinstance(result["created_at"], datetime) + + +def test_context_search_formats_results(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + hits = ctx.search([0.5, 0.4], limit=3) + + assert dummy.calls == [([0.5, 0.4], 3)] + assert hits[0]["id"] == "rec-1" + assert hits[0]["text"] == "hello" + assert hits[0]["binary"] is None + assert isinstance(hits[0]["created_at"], datetime) diff --git a/rust/lance-context/Cargo.lock b/rust/lance-context/Cargo.lock index 8deac67..ace5cb2 100644 --- a/rust/lance-context/Cargo.lock +++ b/rust/lance-context/Cargo.lock @@ -3388,11 +3388,14 @@ dependencies = [ "arrow-ipc", "arrow-schema", "chrono", + "futures", "lance", "lance-graph", "lance-namespace", "lancedb", "serde", + "tempfile", + "tokio", ] [[package]] diff --git a/rust/lance-context/Cargo.toml b/rust/lance-context/Cargo.toml index 016d020..67a5752 100644 --- a/rust/lance-context/Cargo.toml +++ b/rust/lance-context/Cargo.toml @@ -20,3 +20,8 @@ lancedb = "0.23.1" lance-namespace = "1.0.1" lance-graph = "0.4.0" serde = { version = "1", features = ["derive"] } +futures = "0.3" + +[dev-dependencies] +tempfile = "3" +tokio = { version = "1", features = ["rt-multi-thread"] } diff --git a/rust/lance-context/src/lib.rs b/rust/lance-context/src/lib.rs index 5e8460a..cf6f90f 100644 --- a/rust/lance-context/src/lib.rs +++ b/rust/lance-context/src/lib.rs @@ -6,5 +6,5 @@ pub mod serde; mod store; pub use context::{Context, ContextEntry, Snapshot}; -pub use record::{ContextRecord, StateMetadata}; +pub use record::{ContextRecord, SearchResult, StateMetadata}; pub use store::ContextStore; diff --git a/rust/lance-context/src/record.rs b/rust/lance-context/src/record.rs index 1ff6933..4e3aa49 100644 --- a/rust/lance-context/src/record.rs +++ b/rust/lance-context/src/record.rs @@ -22,3 +22,10 @@ pub struct ContextRecord { pub binary_payload: Option<Vec<u8>>, pub embedding: Option<Vec<f32>>, } + +/// Result returned from a vector similarity search. +#[derive(Debug, Clone)] +pub struct SearchResult { + pub record: ContextRecord, + pub distance: f32, +} diff --git a/rust/lance-context/src/store.rs b/rust/lance-context/src/store.rs index 466e597..f750442 100644 --- a/rust/lance-context/src/store.rs +++ b/rust/lance-context/src/store.rs @@ -5,15 +5,22 @@ use arrow_array::builder::{ StringBuilder, StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder, }; use arrow_array::types::Int8Type; -use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator}; +use arrow_array::{ + Array, ArrayRef, DictionaryArray, FixedSizeListArray, Float32Array, Int32Array, + LargeBinaryArray, LargeStringArray, RecordBatch, RecordBatchIterator, StringArray, StructArray, + TimestampMicrosecondArray, +}; use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit}; +use chrono::DateTime; +use futures::TryStreamExt; use lance::dataset::{Dataset, WriteMode, WriteParams}; use lance::{Error as LanceError, Result as LanceResult}; -use crate::record::ContextRecord; +use crate::record::{ContextRecord, SearchResult, StateMetadata}; /// Embedding length used for the semantic index column. const DEFAULT_EMBEDDING_DIM: i32 = 1536; +const DEFAULT_SEARCH_LIMIT: usize = 10; /// Persistent Lance-backed context store. #[derive(Clone)] @@ -73,6 +80,40 @@ impl ContextStore { Ok(()) } + /// Perform a nearest-neighbor search over stored embeddings. + pub async fn search( + &self, + query: &[f32], + limit: Option<usize>, + ) -> LanceResult<Vec<SearchResult>> { + if query.len() != DEFAULT_EMBEDDING_DIM as usize { + return Err(ArrowError::InvalidArgumentError(format!( + "query length {} does not match embedding dimension {}", + query.len(), + DEFAULT_EMBEDDING_DIM + )) + .into()); + } + + let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT); + if top_k == 0 { + return Ok(Vec::new()); + } + + let query_array = Float32Array::from(query.to_vec()); + + let mut scanner = self.dataset.scan(); + scanner.nearest("embedding", &query_array, top_k)?; + scanner.limit(Some(top_k as i64), Some(0))?; + + let mut stream = scanner.try_into_stream().await?; + let mut results = Vec::new(); + while let Some(batch) = stream.try_next().await? { + results.extend(batch_to_search_results(&batch)?); + } + Ok(results) + } + /// Lance schema for the context store. pub fn schema() -> Schema { Schema::new(vec![ @@ -253,3 +294,273 @@ impl ContextStore { Ok(batch) } } + +fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>> { + let id_array = column_as::<StringArray>(batch, "id")?; + let run_id_array = column_as::<StringArray>(batch, "run_id")?; + let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?; + let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?; + let state_array = column_as::<StructArray>(batch, "state_metadata")?; + let content_type_array = column_as::<StringArray>(batch, "content_type")?; + let text_array = column_as::<LargeStringArray>(batch, "text_payload")?; + let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?; + let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?; + + let distance_column = batch.column_by_name("_distance").ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "search results missing _distance column".to_string(), + )) + })?; + let distance_array = distance_column + .as_ref() + .as_any() + .downcast_ref::<Float32Array>() + .ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "_distance column has unexpected data type".to_string(), + )) + })?; + + let step_array = state_array + .column(0) + .as_ref() + .as_any() + .downcast_ref::<Int32Array>() + .ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "step column has unexpected data type".to_string(), + )) + })?; + let active_plan_array = state_array + .column(1) + .as_ref() + .as_any() + .downcast_ref::<StringArray>() + .ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "active_plan_id column has unexpected data type".to_string(), + )) + })?; + let tokens_used_array = state_array + .column(2) + .as_ref() + .as_any() + .downcast_ref::<Int32Array>() + .ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "tokens_used column has unexpected data type".to_string(), + )) + })?; + let custom_array = state_array + .column(3) + .as_ref() + .as_any() + .downcast_ref::<StringArray>() + .ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "custom column has unexpected data type".to_string(), + )) + })?; + + let mut results = Vec::with_capacity(batch.num_rows()); + for row in 0..batch.num_rows() { + let created_at = + DateTime::from_timestamp_micros(created_at_array.value(row)).ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError(format!( + "invalid timestamp value {}", + created_at_array.value(row) + ))) + })?; + + let state_metadata = if state_array.is_null(row) { + None + } else { + Some(StateMetadata { + step: if step_array.is_null(row) { + None + } else { + Some(step_array.value(row)) + }, + active_plan_id: if active_plan_array.is_null(row) { + None + } else { + Some(active_plan_array.value(row).to_string()) + }, + tokens_used: if tokens_used_array.is_null(row) { + None + } else { + Some(tokens_used_array.value(row)) + }, + custom: if custom_array.is_null(row) { + None + } else { + Some(custom_array.value(row).to_string()) + }, + }) + }; + + let text_payload = if text_array.is_null(row) { + None + } else { + Some(text_array.value(row).to_string()) + }; + + let binary_payload = if binary_array.is_null(row) { + None + } else { + Some(binary_array.value(row).to_vec()) + }; + + let embedding = if embedding_array.is_null(row) { + None + } else { + Some(embedding_from_list(embedding_array, row)?) + }; + + let role = if role_array.is_null(row) { + return Err(LanceError::from(ArrowError::InvalidArgumentError( + "role column contains null values".to_string(), + ))); + } else { + let role_values = role_array + .values() + .as_any() + .downcast_ref::<StringArray>() + .ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "role dictionary values are not strings".to_string(), + )) + })?; + let key = role_array.keys().value(row) as usize; + role_values.value(key).to_string() + }; + + let record = ContextRecord { + id: id_array.value(row).to_string(), + run_id: run_id_array.value(row).to_string(), + created_at, + role, + state_metadata, + content_type: content_type_array.value(row).to_string(), + text_payload, + binary_payload, + embedding, + }; + + results.push(SearchResult { + record, + distance: distance_array.value(row), + }); + } + + Ok(results) +} + +fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec<f32>> { + let values = list.value(row); + let float_array = values + .as_ref() + .as_any() + .downcast_ref::<Float32Array>() + .ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError( + "embedding column does not contain float32 values".to_string(), + )) + })?; + let mut embedding = Vec::with_capacity(float_array.len()); + for idx in 0..float_array.len() { + embedding.push(float_array.value(idx)); + } + Ok(embedding) +} + +fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A> +where + A: Array + 'static, +{ + let column = batch.column_by_name(name).ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError(format!( + "column '{name}' not found" + ))) + })?; + column.as_ref().as_any().downcast_ref::<A>().ok_or_else(|| { + LanceError::from(ArrowError::InvalidArgumentError(format!( + "column '{name}' has unexpected data type" + ))) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::serde::CONTENT_TYPE_TEXT; + use chrono::Utc; + use tempfile::TempDir; + + fn make_embedding(pivot: f32) -> Vec<f32> { + let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize]; + if !values.is_empty() { + values[0] = pivot; + } + values + } + + fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord { + ContextRecord { + id: id.to_string(), + run_id: format!("run-{id}"), + created_at: Utc::now(), + role: "user".to_string(), + state_metadata: Some(StateMetadata { + step: Some(1), + active_plan_id: Some("plan".to_string()), + tokens_used: Some(10), + custom: None, + }), + content_type: CONTENT_TYPE_TEXT.to_string(), + text_payload: Some(format!("payload-{id}")), + binary_payload: None, + embedding: Some(make_embedding(embedding_pivot)), + } + } + + #[test] + fn search_orders_by_distance() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let mut store = ContextStore::open(&uri).await.unwrap(); + let first = text_record("a", 0.0); + let second = text_record("b", 1.0); + store.add(&[first.clone(), second.clone()]).await.unwrap(); + + let query = make_embedding(1.0); + let results = store.search(&query, Some(2)).await.unwrap(); + + assert_eq!(results.len(), 2); + assert_eq!(results[0].record.id, second.id); + assert!( + results[0].distance <= results[1].distance, + "results not ordered by distance: {:?}", + results + ); + }); + } + + #[test] + fn search_validates_query_length() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let store = ContextStore::open(&uri).await.unwrap(); + let err = store.search(&[0.0_f32], None).await.unwrap_err(); + let message = err.to_string(); + assert!( + message.contains("embedding dimension"), + "unexpected error message: {message}" + ); + }); + } +}