Skip to content

Commit 5f14605

Browse files
committed
feat: add context search API
1 parent 5857b06 commit 5f14605

9 files changed

Lines changed: 498 additions & 6 deletions

File tree

python/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/python/lance_context/api.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from datetime import datetime
34
from io import BytesIO
45
from typing import Any
56

@@ -96,6 +97,34 @@ def _normalize_content(value: Any, content_type: str | None):
9697
return value, content_type
9798

9899

100+
def _coerce_vector(query: Any) -> list[float]:
101+
if hasattr(query, "tolist"):
102+
query = query.tolist()
103+
elif hasattr(query, "__array__"):
104+
query = query.__array__().tolist()
105+
if isinstance(query, (list, tuple)):
106+
return [float(item) for item in query]
107+
raise TypeError("search query must be a sequence of floats")
108+
109+
110+
def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]:
111+
created_at = raw.get("created_at")
112+
if isinstance(created_at, str):
113+
created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
114+
return {
115+
"id": raw.get("id"),
116+
"run_id": raw.get("run_id"),
117+
"role": raw.get("role"),
118+
"content_type": raw.get("content_type"),
119+
"text": raw.get("text_payload"),
120+
"binary": raw.get("binary_payload"),
121+
"embedding": raw.get("embedding"),
122+
"distance": raw.get("distance"),
123+
"created_at": created_at,
124+
"state_metadata": raw.get("state_metadata"),
125+
}
126+
127+
99128
class Context:
100129
def __init__(self, uri: str) -> None:
101130
self._inner = _Context.create(uri)
@@ -140,6 +169,11 @@ def fork(self, branch_name: str) -> Context:
140169
def checkout(self, version_id: int | str) -> None:
141170
self._inner.checkout(int(version_id))
142171

172+
def search(self, query: Any, limit: int | None = None) -> list[dict[str, Any]]:
    """Search the wrapped context with a query vector.

    ``query`` is converted to a list of floats with ``_coerce_vector`` and
    handed, together with ``limit`` (passed through unchanged), to the inner
    context's ``search``. Every raw hit it returns is reshaped with
    ``_normalize_search_hit``.

    Returns:
        One normalized dict per hit, in the order the inner context
        returned them.
    """
    raw_hits = self._inner.search(_coerce_vector(query), limit)
    return list(map(_normalize_search_hit, raw_hits))
176+
143177
def __repr__(self) -> str:
144178
return (
145179
f"Context(uri={self._inner.uri()!r}, "

python/src/lib.rs

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use std::sync::Arc;
22

3-
use chrono::Utc;
3+
use chrono::{SecondsFormat, Utc};
44
use pyo3::exceptions::PyRuntimeError;
55
use pyo3::prelude::*;
6-
use pyo3::types::PyType;
6+
use pyo3::IntoPy;
7+
use pyo3::types::{PyBytes, PyDict, PyType};
78
use tokio::runtime::Runtime;
89

910
use lance_context::serde::CONTENT_TYPE_TEXT;
10-
use lance_context::{Context as RustContext, ContextRecord, ContextStore};
11+
use lance_context::{Context as RustContext, ContextRecord, ContextStore, SearchResult};
1112

1213
const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream";
1314
const BINARY_PLACEHOLDER: &str = "[binary]";
@@ -127,12 +128,74 @@ impl Context {
127128
self.run_id = new_run_id();
128129
Ok(())
129130
}
131+
132+
#[pyo3(signature = (query, limit = None))]
133+
fn search(
134+
&self,
135+
py: Python<'_>,
136+
query: Vec<f32>,
137+
limit: Option<usize>,
138+
) -> PyResult<Vec<PyObject>> {
139+
let hits = self
140+
.runtime
141+
.block_on(self.store.search(&query, limit))
142+
.map_err(to_py_err)?;
143+
hits.into_iter()
144+
.map(|hit| search_hit_to_py(py, hit))
145+
.collect()
146+
}
130147
}
131148

132149
fn new_run_id() -> String {
133150
format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id())
134151
}
135152

153+
fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult<PyObject> {
154+
let SearchResult { record, distance } = hit;
155+
let ContextRecord {
156+
id,
157+
run_id,
158+
created_at,
159+
role,
160+
state_metadata,
161+
content_type,
162+
text_payload,
163+
binary_payload,
164+
embedding,
165+
} = record;
166+
167+
let dict = PyDict::new(py);
168+
dict.set_item("id", id)?;
169+
dict.set_item("run_id", run_id)?;
170+
dict.set_item(
171+
"created_at",
172+
created_at.to_rfc3339_opts(SecondsFormat::Micros, true),
173+
)?;
174+
dict.set_item("role", role)?;
175+
176+
let state_obj: PyObject = match state_metadata {
177+
Some(metadata) => {
178+
let state_dict = PyDict::new(py);
179+
state_dict.set_item("step", metadata.step)?;
180+
state_dict.set_item("active_plan_id", metadata.active_plan_id)?;
181+
state_dict.set_item("tokens_used", metadata.tokens_used)?;
182+
state_dict.set_item("custom", metadata.custom)?;
183+
state_dict.into_py(py)
184+
}
185+
None => py.None().into_py(py),
186+
};
187+
dict.set_item("state_metadata", state_obj)?;
188+
dict.set_item("content_type", content_type)?;
189+
dict.set_item("text_payload", text_payload)?;
190+
match binary_payload {
191+
Some(payload) => dict.set_item("binary_payload", PyBytes::new(py, &payload))?,
192+
None => dict.set_item("binary_payload", py.None())?,
193+
}
194+
dict.set_item("embedding", embedding)?;
195+
dict.set_item("distance", distance)?;
196+
Ok(dict.into_py(py))
197+
}
198+
136199
fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr {
137200
PyRuntimeError::new_err(err.to_string())
138201
}

python/tests/test_search.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from datetime import datetime
2+
3+
import pytest
4+
5+
from lance_context.api import Context, _coerce_vector, _normalize_search_hit
6+
7+
8+
class DummyInner:
    """Stand-in for the inner context that records every ``search`` call."""

    def __init__(self) -> None:
        # One (vector, limit) tuple per call, in call order.
        self.calls: list[tuple[list[float], int | None]] = []

    def search(self, vector: list[float], limit: int | None):
        """Record the arguments and return a single canned raw hit."""
        self.calls.append((vector, limit))
        canned_hit = {
            "id": "rec-1",
            "run_id": "run-1",
            "role": "user",
            "content_type": "text/plain",
            "text_payload": "hello",
            "binary_payload": None,
            "embedding": [0.1, 0.2],
            "distance": 0.12,
            "created_at": "2024-01-01T12:00:00Z",
            "state_metadata": {"step": 1},
        }
        return [canned_hit]
28+
29+
30+
def test_coerce_vector_from_list():
    """A list mixing ints and floats comes back as a list of floats."""
    coerced = _coerce_vector([1, 2.5])
    assert coerced == [1.0, 2.5]
32+
33+
34+
def test_coerce_vector_rejects_invalid():
    """A plain string is not a valid query vector."""
    pytest.raises(TypeError, _coerce_vector, "invalid")
37+
38+
39+
def test_normalize_search_hit_converts_timestamp():
    """The ISO timestamp is parsed as UTC and payload keys are renamed."""
    result = _normalize_search_hit(
        {
            "id": "rec-2",
            "created_at": "2024-01-01T00:00:00Z",
            "content_type": "text/plain",
            "text_payload": None,
            "binary_payload": None,
            "embedding": None,
            "distance": 0.5,
            "run_id": "run-2",
            "role": "assistant",
            "state_metadata": None,
        }
    )
    assert isinstance(result["created_at"], datetime)
    # Pin the parsed instant and its offset, not just the type, so a
    # dropped timezone or a mis-parsed value is caught.
    assert result["created_at"].isoformat() == "2024-01-01T00:00:00+00:00"
    # The payload keys are exposed under their shortened names only.
    assert "text" in result and "text_payload" not in result
    assert "binary" in result and "binary_payload" not in result
    assert result["id"] == "rec-2"
    assert result["distance"] == 0.5
55+
56+
57+
def test_context_search_formats_results():
    """``Context.search`` forwards its arguments and normalizes each hit."""
    recorder = DummyInner()
    # Bypass __init__ so no real native context is created.
    ctx = Context.__new__(Context)
    ctx._inner = recorder  # type: ignore[attr-defined]

    hits = ctx.search([0.5, 0.4], limit=3)
    first_hit = hits[0]

    assert recorder.calls == [([0.5, 0.4], 3)]
    assert first_hit["id"] == "rec-1"
    assert first_hit["text"] == "hello"
    assert first_hit["binary"] is None
    assert isinstance(first_hit["created_at"], datetime)

rust/lance-context/Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/lance-context/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,8 @@ lancedb = "0.23.1"
2020
lance-namespace = "1.0.1"
2121
lance-graph = "0.4.0"
2222
serde = { version = "1", features = ["derive"] }
23+
futures = "0.3"
24+
25+
[dev-dependencies]
26+
tempfile = "3"
27+
tokio = { version = "1", features = ["rt-multi-thread"] }

rust/lance-context/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ pub mod serde;
66
mod store;
77

88
pub use context::{Context, ContextEntry, Snapshot};
9-
pub use record::{ContextRecord, StateMetadata};
9+
pub use record::{ContextRecord, SearchResult, StateMetadata};
1010
pub use store::ContextStore;

rust/lance-context/src/record.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,10 @@ pub struct ContextRecord {
2222
pub binary_payload: Option<Vec<u8>>,
2323
pub embedding: Option<Vec<f32>>,
2424
}
25+
26+
/// Result returned from a vector similarity search.
27+
#[derive(Debug, Clone)]
28+
pub struct SearchResult {
29+
pub record: ContextRecord,
30+
pub distance: f32,
31+
}

0 commit comments

Comments
 (0)