diff --git a/.codex/skills/ci-pr-helper/SKILL.md b/.codex/skills/ci-pr-helper/SKILL.md
index e0f76f9..c87f8ac 100644
--- a/.codex/skills/ci-pr-helper/SKILL.md
+++ b/.codex/skills/ci-pr-helper/SKILL.md
@@ -57,4 +57,7 @@ When ready:
gh pr create --title "
" --body ""
```
+- Some environments may emit an `Unsupported subcommand 'pr'` warning before running
+  `gh`; treat it as spurious only if the command then succeeds — verify the PR was created.
+
- If `gh` is missing or fails, print the command instead so the user can run it locally.
diff --git a/python/Cargo.lock b/python/Cargo.lock
index 9145dd0..c310169 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -3397,6 +3397,7 @@ dependencies = [
"arrow-ipc",
"arrow-schema",
"chrono",
+ "futures",
"lance",
"lance-graph",
"lance-namespace",
diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py
index 50b18ad..f563087 100644
--- a/python/python/lance_context/api.py
+++ b/python/python/lance_context/api.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+from datetime import datetime
from io import BytesIO
from typing import Any
@@ -96,6 +97,34 @@ def _normalize_content(value: Any, content_type: str | None):
return value, content_type
+def _coerce_vector(query: Any) -> list[float]:
+ if hasattr(query, "tolist"):
+ query = query.tolist()
+ elif hasattr(query, "__array__"):
+ query = query.__array__().tolist()
+ if isinstance(query, (list, tuple)):
+ return [float(item) for item in query]
+ raise TypeError("search query must be a sequence of floats")
+
+
+def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]:
+ created_at = raw.get("created_at")
+ if isinstance(created_at, str):
+ created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
+ return {
+ "id": raw.get("id"),
+ "run_id": raw.get("run_id"),
+ "role": raw.get("role"),
+ "content_type": raw.get("content_type"),
+ "text": raw.get("text_payload"),
+ "binary": raw.get("binary_payload"),
+ "embedding": raw.get("embedding"),
+ "distance": raw.get("distance"),
+ "created_at": created_at,
+ "state_metadata": raw.get("state_metadata"),
+ }
+
+
class Context:
def __init__(self, uri: str) -> None:
self._inner = _Context.create(uri)
@@ -140,6 +169,11 @@ def fork(self, branch_name: str) -> Context:
def checkout(self, version_id: int | str) -> None:
self._inner.checkout(int(version_id))
+ def search(self, query: Any, limit: int | None = None) -> list[dict[str, Any]]:
+ vector = _coerce_vector(query)
+ results = self._inner.search(vector, limit)
+ return [_normalize_search_hit(item) for item in results]
+
def __repr__(self) -> str:
return (
f"Context(uri={self._inner.uri()!r}, "
diff --git a/python/src/lib.rs b/python/src/lib.rs
index f0f89cf..befdd70 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -1,13 +1,14 @@
use std::sync::Arc;
-use chrono::Utc;
+use chrono::{SecondsFormat, Utc};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
-use pyo3::types::PyType;
+use pyo3::types::{PyBytes, PyDict, PyType};
+use pyo3::IntoPyObject;
use tokio::runtime::Runtime;
use lance_context::serde::CONTENT_TYPE_TEXT;
-use lance_context::{Context as RustContext, ContextRecord, ContextStore};
+use lance_context::{Context as RustContext, ContextRecord, ContextStore, SearchResult};
const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream";
const BINARY_PLACEHOLDER: &str = "[binary]";
@@ -68,9 +69,7 @@ impl Context {
let (content_type, text_payload, binary_payload, inner_content) =
match content.extract::<&[u8]>() {
Ok(bytes) => (
- data_type
- .unwrap_or(DEFAULT_BINARY_CONTENT_TYPE)
- .to_string(),
+ data_type.unwrap_or(DEFAULT_BINARY_CONTENT_TYPE).to_string(),
None,
Some(bytes.to_vec()),
BINARY_PLACEHOLDER.to_string(),
@@ -127,10 +126,76 @@ impl Context {
self.run_id = new_run_id();
Ok(())
}
+
+ #[pyo3(signature = (query, limit = None))]
+ fn search(
+ &self,
+ py: Python<'_>,
+        query: Vec<f32>,
+        limit: Option<usize>,
+    ) -> PyResult<Vec<PyObject>> {
+ let hits = self
+ .runtime
+ .block_on(self.store.search(&query, limit))
+ .map_err(to_py_err)?;
+ hits.into_iter()
+ .map(|hit| search_hit_to_py(py, hit))
+ .collect()
+ }
}
fn new_run_id() -> String {
- format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id())
+ format!(
+ "run-{}-{}",
+ Utc::now().timestamp_micros(),
+ std::process::id()
+ )
+}
+
+fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult<PyObject> {
+ let SearchResult { record, distance } = hit;
+ let ContextRecord {
+ id,
+ run_id,
+ created_at,
+ role,
+ state_metadata,
+ content_type,
+ text_payload,
+ binary_payload,
+ embedding,
+ } = record;
+
+ let dict = PyDict::new(py);
+ dict.set_item("id", id)?;
+ dict.set_item("run_id", run_id)?;
+ dict.set_item(
+ "created_at",
+ created_at.to_rfc3339_opts(SecondsFormat::Micros, true),
+ )?;
+ dict.set_item("role", role)?;
+
+ let state_obj: PyObject = match state_metadata {
+ Some(metadata) => {
+ let state_dict = PyDict::new(py);
+ state_dict.set_item("step", metadata.step)?;
+ state_dict.set_item("active_plan_id", metadata.active_plan_id)?;
+ state_dict.set_item("tokens_used", metadata.tokens_used)?;
+ state_dict.set_item("custom", metadata.custom)?;
+ state_dict.into_pyobject(py)?.unbind().into()
+ }
+ None => py.None().into_pyobject(py)?.unbind().into(),
+ };
+ dict.set_item("state_metadata", state_obj)?;
+ dict.set_item("content_type", content_type)?;
+ dict.set_item("text_payload", text_payload)?;
+ match binary_payload {
+ Some(payload) => dict.set_item("binary_payload", PyBytes::new(py, &payload))?,
+ None => dict.set_item("binary_payload", py.None())?,
+ }
+ dict.set_item("embedding", embedding)?;
+ dict.set_item("distance", distance)?;
+ Ok(dict.into_pyobject(py)?.unbind().into())
}
fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr {
diff --git a/python/tests/test_search.py b/python/tests/test_search.py
new file mode 100644
index 0000000..6181c5a
--- /dev/null
+++ b/python/tests/test_search.py
@@ -0,0 +1,68 @@
+from datetime import datetime
+
+import pytest
+
+from lance_context.api import Context, _coerce_vector, _normalize_search_hit
+
+
+class DummyInner:
+ def __init__(self) -> None:
+ self.calls: list[tuple[list[float], int | None]] = []
+
+ def search(self, vector: list[float], limit: int | None):
+ self.calls.append((vector, limit))
+ return [
+ {
+ "id": "rec-1",
+ "run_id": "run-1",
+ "role": "user",
+ "content_type": "text/plain",
+ "text_payload": "hello",
+ "binary_payload": None,
+ "embedding": [0.1, 0.2],
+ "distance": 0.12,
+ "created_at": "2024-01-01T12:00:00Z",
+ "state_metadata": {"step": 1},
+ }
+ ]
+
+
+def test_coerce_vector_from_list():
+ assert _coerce_vector([1, 2.5]) == [1.0, 2.5]
+
+
+def test_coerce_vector_rejects_invalid():
+ with pytest.raises(TypeError):
+ _coerce_vector("invalid")
+
+
+def test_normalize_search_hit_converts_timestamp():
+ result = _normalize_search_hit(
+ {
+ "id": "rec-2",
+ "created_at": "2024-01-01T00:00:00Z",
+ "content_type": "text/plain",
+ "text_payload": None,
+ "binary_payload": None,
+ "embedding": None,
+ "distance": 0.5,
+ "run_id": "run-2",
+ "role": "assistant",
+ "state_metadata": None,
+ }
+ )
+ assert isinstance(result["created_at"], datetime)
+
+
+def test_context_search_formats_results():
+ ctx = Context.__new__(Context)
+ dummy = DummyInner()
+ ctx._inner = dummy # type: ignore[attr-defined]
+
+ hits = ctx.search([0.5, 0.4], limit=3)
+
+ assert dummy.calls == [([0.5, 0.4], 3)]
+ assert hits[0]["id"] == "rec-1"
+ assert hits[0]["text"] == "hello"
+ assert hits[0]["binary"] is None
+ assert isinstance(hits[0]["created_at"], datetime)
diff --git a/rust/lance-context/Cargo.lock b/rust/lance-context/Cargo.lock
index 8deac67..ace5cb2 100644
--- a/rust/lance-context/Cargo.lock
+++ b/rust/lance-context/Cargo.lock
@@ -3388,11 +3388,14 @@ dependencies = [
"arrow-ipc",
"arrow-schema",
"chrono",
+ "futures",
"lance",
"lance-graph",
"lance-namespace",
"lancedb",
"serde",
+ "tempfile",
+ "tokio",
]
[[package]]
diff --git a/rust/lance-context/Cargo.toml b/rust/lance-context/Cargo.toml
index 016d020..67a5752 100644
--- a/rust/lance-context/Cargo.toml
+++ b/rust/lance-context/Cargo.toml
@@ -20,3 +20,8 @@ lancedb = "0.23.1"
lance-namespace = "1.0.1"
lance-graph = "0.4.0"
serde = { version = "1", features = ["derive"] }
+futures = "0.3"
+
+[dev-dependencies]
+tempfile = "3"
+tokio = { version = "1", features = ["rt-multi-thread"] }
diff --git a/rust/lance-context/src/lib.rs b/rust/lance-context/src/lib.rs
index 5e8460a..cf6f90f 100644
--- a/rust/lance-context/src/lib.rs
+++ b/rust/lance-context/src/lib.rs
@@ -6,5 +6,5 @@ pub mod serde;
mod store;
pub use context::{Context, ContextEntry, Snapshot};
-pub use record::{ContextRecord, StateMetadata};
+pub use record::{ContextRecord, SearchResult, StateMetadata};
pub use store::ContextStore;
diff --git a/rust/lance-context/src/record.rs b/rust/lance-context/src/record.rs
index 1ff6933..4e3aa49 100644
--- a/rust/lance-context/src/record.rs
+++ b/rust/lance-context/src/record.rs
@@ -22,3 +22,10 @@ pub struct ContextRecord {
pub binary_payload: Option<Vec<u8>>,
pub embedding: Option<Vec<f32>>,
}
+
+/// Result returned from a vector similarity search.
+#[derive(Debug, Clone)]
+pub struct SearchResult {
+ pub record: ContextRecord,
+ pub distance: f32,
+}
diff --git a/rust/lance-context/src/store.rs b/rust/lance-context/src/store.rs
index 466e597..f750442 100644
--- a/rust/lance-context/src/store.rs
+++ b/rust/lance-context/src/store.rs
@@ -5,15 +5,22 @@ use arrow_array::builder::{
StringBuilder, StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder,
};
use arrow_array::types::Int8Type;
-use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator};
+use arrow_array::{
+ Array, ArrayRef, DictionaryArray, FixedSizeListArray, Float32Array, Int32Array,
+ LargeBinaryArray, LargeStringArray, RecordBatch, RecordBatchIterator, StringArray, StructArray,
+ TimestampMicrosecondArray,
+};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit};
+use chrono::DateTime;
+use futures::TryStreamExt;
use lance::dataset::{Dataset, WriteMode, WriteParams};
use lance::{Error as LanceError, Result as LanceResult};
-use crate::record::ContextRecord;
+use crate::record::{ContextRecord, SearchResult, StateMetadata};
/// Embedding length used for the semantic index column.
const DEFAULT_EMBEDDING_DIM: i32 = 1536;
+const DEFAULT_SEARCH_LIMIT: usize = 10;
/// Persistent Lance-backed context store.
#[derive(Clone)]
@@ -73,6 +80,40 @@ impl ContextStore {
Ok(())
}
+ /// Perform a nearest-neighbor search over stored embeddings.
+ pub async fn search(
+ &self,
+ query: &[f32],
+        limit: Option<usize>,
+    ) -> LanceResult<Vec<SearchResult>> {
+ if query.len() != DEFAULT_EMBEDDING_DIM as usize {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "query length {} does not match embedding dimension {}",
+ query.len(),
+ DEFAULT_EMBEDDING_DIM
+ ))
+ .into());
+ }
+
+ let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
+ if top_k == 0 {
+ return Ok(Vec::new());
+ }
+
+ let query_array = Float32Array::from(query.to_vec());
+
+ let mut scanner = self.dataset.scan();
+ scanner.nearest("embedding", &query_array, top_k)?;
+ scanner.limit(Some(top_k as i64), Some(0))?;
+
+ let mut stream = scanner.try_into_stream().await?;
+ let mut results = Vec::new();
+ while let Some(batch) = stream.try_next().await? {
+ results.extend(batch_to_search_results(&batch)?);
+ }
+ Ok(results)
+ }
+
/// Lance schema for the context store.
pub fn schema() -> Schema {
Schema::new(vec![
@@ -253,3 +294,273 @@ impl ContextStore {
Ok(batch)
}
}
+
+fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>> {
+    let id_array = column_as::<StringArray>(batch, "id")?;
+    let run_id_array = column_as::<StringArray>(batch, "run_id")?;
+    let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
+    let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
+    let state_array = column_as::<StructArray>(batch, "state_metadata")?;
+    let content_type_array = column_as::<StringArray>(batch, "content_type")?;
+    let text_array = column_as::<LargeStringArray>(batch, "text_payload")?;
+    let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
+    let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;
+
+ let distance_column = batch.column_by_name("_distance").ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "search results missing _distance column".to_string(),
+ ))
+ })?;
+ let distance_array = distance_column
+ .as_ref()
+ .as_any()
+        .downcast_ref::<Float32Array>()
+ .ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "_distance column has unexpected data type".to_string(),
+ ))
+ })?;
+
+ let step_array = state_array
+ .column(0)
+ .as_ref()
+ .as_any()
+        .downcast_ref::<Int32Array>()
+ .ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "step column has unexpected data type".to_string(),
+ ))
+ })?;
+ let active_plan_array = state_array
+ .column(1)
+ .as_ref()
+ .as_any()
+        .downcast_ref::<StringArray>()
+ .ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "active_plan_id column has unexpected data type".to_string(),
+ ))
+ })?;
+ let tokens_used_array = state_array
+ .column(2)
+ .as_ref()
+ .as_any()
+        .downcast_ref::<Int32Array>()
+ .ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "tokens_used column has unexpected data type".to_string(),
+ ))
+ })?;
+ let custom_array = state_array
+ .column(3)
+ .as_ref()
+ .as_any()
+        .downcast_ref::<StringArray>()
+ .ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "custom column has unexpected data type".to_string(),
+ ))
+ })?;
+
+ let mut results = Vec::with_capacity(batch.num_rows());
+ for row in 0..batch.num_rows() {
+ let created_at =
+ DateTime::from_timestamp_micros(created_at_array.value(row)).ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(format!(
+ "invalid timestamp value {}",
+ created_at_array.value(row)
+ )))
+ })?;
+
+ let state_metadata = if state_array.is_null(row) {
+ None
+ } else {
+ Some(StateMetadata {
+ step: if step_array.is_null(row) {
+ None
+ } else {
+ Some(step_array.value(row))
+ },
+ active_plan_id: if active_plan_array.is_null(row) {
+ None
+ } else {
+ Some(active_plan_array.value(row).to_string())
+ },
+ tokens_used: if tokens_used_array.is_null(row) {
+ None
+ } else {
+ Some(tokens_used_array.value(row))
+ },
+ custom: if custom_array.is_null(row) {
+ None
+ } else {
+ Some(custom_array.value(row).to_string())
+ },
+ })
+ };
+
+ let text_payload = if text_array.is_null(row) {
+ None
+ } else {
+ Some(text_array.value(row).to_string())
+ };
+
+ let binary_payload = if binary_array.is_null(row) {
+ None
+ } else {
+ Some(binary_array.value(row).to_vec())
+ };
+
+ let embedding = if embedding_array.is_null(row) {
+ None
+ } else {
+ Some(embedding_from_list(embedding_array, row)?)
+ };
+
+ let role = if role_array.is_null(row) {
+ return Err(LanceError::from(ArrowError::InvalidArgumentError(
+ "role column contains null values".to_string(),
+ )));
+ } else {
+ let role_values = role_array
+ .values()
+ .as_any()
+                .downcast_ref::<StringArray>()
+ .ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "role dictionary values are not strings".to_string(),
+ ))
+ })?;
+ let key = role_array.keys().value(row) as usize;
+ role_values.value(key).to_string()
+ };
+
+ let record = ContextRecord {
+ id: id_array.value(row).to_string(),
+ run_id: run_id_array.value(row).to_string(),
+ created_at,
+ role,
+ state_metadata,
+ content_type: content_type_array.value(row).to_string(),
+ text_payload,
+ binary_payload,
+ embedding,
+ };
+
+ results.push(SearchResult {
+ record,
+ distance: distance_array.value(row),
+ });
+ }
+
+ Ok(results)
+}
+
+fn embedding_from_list(list: &FixedSizeListArray, row: usize) -> LanceResult<Vec<f32>> {
+ let values = list.value(row);
+ let float_array = values
+ .as_ref()
+ .as_any()
+        .downcast_ref::<Float32Array>()
+ .ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(
+ "embedding column does not contain float32 values".to_string(),
+ ))
+ })?;
+ let mut embedding = Vec::with_capacity(float_array.len());
+ for idx in 0..float_array.len() {
+ embedding.push(float_array.value(idx));
+ }
+ Ok(embedding)
+}
+
+fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
+where
+ A: Array + 'static,
+{
+ let column = batch.column_by_name(name).ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(format!(
+ "column '{name}' not found"
+ )))
+ })?;
+    column.as_ref().as_any().downcast_ref::<A>().ok_or_else(|| {
+ LanceError::from(ArrowError::InvalidArgumentError(format!(
+ "column '{name}' has unexpected data type"
+ )))
+ })
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::serde::CONTENT_TYPE_TEXT;
+ use chrono::Utc;
+ use tempfile::TempDir;
+
+    fn make_embedding(pivot: f32) -> Vec<f32> {
+ let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
+ if !values.is_empty() {
+ values[0] = pivot;
+ }
+ values
+ }
+
+ fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord {
+ ContextRecord {
+ id: id.to_string(),
+ run_id: format!("run-{id}"),
+ created_at: Utc::now(),
+ role: "user".to_string(),
+ state_metadata: Some(StateMetadata {
+ step: Some(1),
+ active_plan_id: Some("plan".to_string()),
+ tokens_used: Some(10),
+ custom: None,
+ }),
+ content_type: CONTENT_TYPE_TEXT.to_string(),
+ text_payload: Some(format!("payload-{id}")),
+ binary_payload: None,
+ embedding: Some(make_embedding(embedding_pivot)),
+ }
+ }
+
+ #[test]
+ fn search_orders_by_distance() {
+ let dir = TempDir::new().unwrap();
+ let uri = dir.path().to_string_lossy().to_string();
+ let runtime = tokio::runtime::Runtime::new().unwrap();
+ runtime.block_on(async {
+ let mut store = ContextStore::open(&uri).await.unwrap();
+ let first = text_record("a", 0.0);
+ let second = text_record("b", 1.0);
+ store.add(&[first.clone(), second.clone()]).await.unwrap();
+
+ let query = make_embedding(1.0);
+ let results = store.search(&query, Some(2)).await.unwrap();
+
+ assert_eq!(results.len(), 2);
+ assert_eq!(results[0].record.id, second.id);
+ assert!(
+ results[0].distance <= results[1].distance,
+ "results not ordered by distance: {:?}",
+ results
+ );
+ });
+ }
+
+ #[test]
+ fn search_validates_query_length() {
+ let dir = TempDir::new().unwrap();
+ let uri = dir.path().to_string_lossy().to_string();
+ let runtime = tokio::runtime::Runtime::new().unwrap();
+ runtime.block_on(async {
+ let store = ContextStore::open(&uri).await.unwrap();
+ let err = store.search(&[0.0_f32], None).await.unwrap_err();
+ let message = err.to_string();
+ assert!(
+ message.contains("embedding dimension"),
+ "unexpected error message: {message}"
+ );
+ });
+ }
+}