Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .codex/skills/ci-pr-helper/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,7 @@ When ready:
gh pr create --title "<title>" --body "<body>"
```

- Some environments emit a spurious `Unsupported subcommand 'pr'` warning before running
`gh`; ignore that message and continue with the command.

- If `gh` is missing or fails, print the command instead so the user can run it locally.
1 change: 1 addition & 0 deletions python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions python/python/lance_context/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import datetime
from io import BytesIO
from typing import Any

Expand Down Expand Up @@ -96,6 +97,34 @@ def _normalize_content(value: Any, content_type: str | None):
return value, content_type


def _coerce_vector(query: Any) -> list[float]:
if hasattr(query, "tolist"):
query = query.tolist()
elif hasattr(query, "__array__"):
query = query.__array__().tolist()
if isinstance(query, (list, tuple)):
return [float(item) for item in query]
raise TypeError("search query must be a sequence of floats")


def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]:
created_at = raw.get("created_at")
if isinstance(created_at, str):
created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
return {
"id": raw.get("id"),
"run_id": raw.get("run_id"),
"role": raw.get("role"),
"content_type": raw.get("content_type"),
"text": raw.get("text_payload"),
"binary": raw.get("binary_payload"),
"embedding": raw.get("embedding"),
"distance": raw.get("distance"),
"created_at": created_at,
"state_metadata": raw.get("state_metadata"),
}


class Context:
def __init__(self, uri: str) -> None:
self._inner = _Context.create(uri)
Expand Down Expand Up @@ -140,6 +169,11 @@ def fork(self, branch_name: str) -> Context:
def checkout(self, version_id: int | str) -> None:
self._inner.checkout(int(version_id))

def search(self, query: Any, limit: int | None = None) -> list[dict[str, Any]]:
    """Search the wrapped context and return normalized hits.

    ``query`` is coerced to a list of floats with ``_coerce_vector`` and
    forwarded, together with ``limit``, to the wrapped context's ``search``.
    Each raw hit is then reshaped by ``_normalize_search_hit``.
    """
    raw_hits = self._inner.search(_coerce_vector(query), limit)
    return list(map(_normalize_search_hit, raw_hits))

def __repr__(self) -> str:
return (
f"Context(uri={self._inner.uri()!r}, "
Expand Down
79 changes: 72 additions & 7 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::sync::Arc;

use chrono::Utc;
use chrono::{SecondsFormat, Utc};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyType;
use pyo3::types::{PyBytes, PyDict, PyType};
use pyo3::IntoPyObject;
use tokio::runtime::Runtime;

use lance_context::serde::CONTENT_TYPE_TEXT;
use lance_context::{Context as RustContext, ContextRecord, ContextStore};
use lance_context::{Context as RustContext, ContextRecord, ContextStore, SearchResult};

const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream";
const BINARY_PLACEHOLDER: &str = "[binary]";
Expand Down Expand Up @@ -68,9 +69,7 @@ impl Context {
let (content_type, text_payload, binary_payload, inner_content) =
match content.extract::<&[u8]>() {
Ok(bytes) => (
data_type
.unwrap_or(DEFAULT_BINARY_CONTENT_TYPE)
.to_string(),
data_type.unwrap_or(DEFAULT_BINARY_CONTENT_TYPE).to_string(),
None,
Some(bytes.to_vec()),
BINARY_PLACEHOLDER.to_string(),
Expand Down Expand Up @@ -127,10 +126,76 @@ impl Context {
self.run_id = new_run_id();
Ok(())
}

/// Searches the store with `query` and returns one Python object per hit.
///
/// The store's search future is driven to completion on this context's
/// runtime. A failure is turned into a `PyErr` by `to_py_err`; each hit is
/// converted with `search_hit_to_py`, and the first conversion error aborts
/// the call.
#[pyo3(signature = (query, limit = None))]
fn search(
    &self,
    py: Python<'_>,
    query: Vec<f32>,
    limit: Option<usize>,
) -> PyResult<Vec<PyObject>> {
    let pending = self.store.search(&query, limit);
    let hits = self.runtime.block_on(pending).map_err(to_py_err)?;

    let mut converted = Vec::new();
    for hit in hits {
        converted.push(search_hit_to_py(py, hit)?);
    }
    Ok(converted)
}
}

/// Builds a run identifier of the form `run-<micros>-<pid>`, using the current
/// UTC timestamp in microseconds and the id of the current process.
fn new_run_id() -> String {
// NOTE(review): the single-line and multi-line `format!` calls below make the
// same call with the same arguments; they look like the removed and added
// sides of a formatting-only diff hunk — confirm only one survives in the file.
format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id())
format!(
"run-{}-{}",
Utc::now().timestamp_micros(),
std::process::id()
)
}

/// Converts one [`SearchResult`] into a Python `dict`.
///
/// Keys are written in this order: `id`, `run_id`, `created_at` (an RFC 3339
/// string with microsecond precision), `role`, `state_metadata` (a nested dict
/// holding `step`, `active_plan_id`, `tokens_used` and `custom`, or `None`),
/// `content_type`, `text_payload`, `binary_payload` (`bytes` or `None`),
/// `embedding` and `distance`.
fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult<PyObject> {
    let SearchResult { record, distance } = hit;
    let out = PyDict::new(py);

    out.set_item("id", record.id)?;
    out.set_item("run_id", record.run_id)?;
    let created_at = record
        .created_at
        .to_rfc3339_opts(SecondsFormat::Micros, true);
    out.set_item("created_at", created_at)?;
    out.set_item("role", record.role)?;

    // Absent metadata is surfaced as Python `None` rather than an empty dict.
    let state_obj: PyObject = if let Some(metadata) = record.state_metadata {
        let nested = PyDict::new(py);
        nested.set_item("step", metadata.step)?;
        nested.set_item("active_plan_id", metadata.active_plan_id)?;
        nested.set_item("tokens_used", metadata.tokens_used)?;
        nested.set_item("custom", metadata.custom)?;
        nested.into_any().unbind()
    } else {
        py.None()
    };
    out.set_item("state_metadata", state_obj)?;
    out.set_item("content_type", record.content_type)?;
    out.set_item("text_payload", record.text_payload)?;

    if let Some(payload) = record.binary_payload {
        out.set_item("binary_payload", PyBytes::new(py, &payload))?;
    } else {
        out.set_item("binary_payload", py.None())?;
    }

    out.set_item("embedding", record.embedding)?;
    out.set_item("distance", distance)?;
    Ok(out.into_pyobject(py)?.unbind().into())
}

fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr {
Expand Down
68 changes: 68 additions & 0 deletions python/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from datetime import datetime

import pytest

from lance_context.api import Context, _coerce_vector, _normalize_search_hit


class DummyInner:
    """Test double for the wrapped context: records ``search`` calls and
    answers each one with a single canned hit."""

    def __init__(self) -> None:
        # One (vector, limit) tuple per search() call, in call order.
        self.calls: list[tuple[list[float], int | None]] = []

    def search(self, vector: list[float], limit: int | None):
        self.calls.append((vector, limit))
        canned_hit = dict(
            id="rec-1",
            run_id="run-1",
            role="user",
            content_type="text/plain",
            text_payload="hello",
            binary_payload=None,
            embedding=[0.1, 0.2],
            distance=0.12,
            created_at="2024-01-01T12:00:00Z",
            state_metadata={"step": 1},
        )
        return [canned_hit]


def test_coerce_vector_from_list():
    """A list mixing ints and floats comes back as a list of floats."""
    coerced = _coerce_vector([1, 2.5])
    assert coerced == [1.0, 2.5]


def test_coerce_vector_rejects_invalid():
    """A plain string is not accepted as a query vector."""
    not_a_vector = "invalid"
    with pytest.raises(TypeError):
        _coerce_vector(not_a_vector)


def test_normalize_search_hit_converts_timestamp():
    """A string ``created_at`` ending in ``Z`` is parsed into a ``datetime``."""
    raw_hit = {
        "id": "rec-2",
        "created_at": "2024-01-01T00:00:00Z",
        "content_type": "text/plain",
        "text_payload": None,
        "binary_payload": None,
        "embedding": None,
        "distance": 0.5,
        "run_id": "run-2",
        "role": "assistant",
        "state_metadata": None,
    }
    normalized = _normalize_search_hit(raw_hit)
    assert isinstance(normalized["created_at"], datetime)


def test_context_search_formats_results():
    """``Context.search`` forwards the vector and limit, then normalizes hits."""
    inner = DummyInner()
    # Bypass __init__ so the test does not construct a real wrapped context.
    ctx = Context.__new__(Context)
    ctx._inner = inner  # type: ignore[attr-defined]

    hits = ctx.search([0.5, 0.4], limit=3)
    first = hits[0]

    assert inner.calls == [([0.5, 0.4], 3)]
    assert first["id"] == "rec-1"
    assert first["text"] == "hello"
    assert first["binary"] is None
    assert isinstance(first["created_at"], datetime)
3 changes: 3 additions & 0 deletions rust/lance-context/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions rust/lance-context/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ lancedb = "0.23.1"
lance-namespace = "1.0.1"
lance-graph = "0.4.0"
serde = { version = "1", features = ["derive"] }
futures = "0.3"

[dev-dependencies]
tempfile = "3"
tokio = { version = "1", features = ["rt-multi-thread"] }
2 changes: 1 addition & 1 deletion rust/lance-context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ pub mod serde;
mod store;

pub use context::{Context, ContextEntry, Snapshot};
pub use record::{ContextRecord, StateMetadata};
pub use record::{ContextRecord, SearchResult, StateMetadata};
pub use store::ContextStore;
7 changes: 7 additions & 0 deletions rust/lance-context/src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ pub struct ContextRecord {
pub binary_payload: Option<Vec<u8>>,
pub embedding: Option<Vec<f32>>,
}

/// Result returned from a vector similarity search.
///
/// Pairs a stored [`ContextRecord`] with the distance reported for it.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// The record that was returned by the search.
pub record: ContextRecord,
/// Distance associated with `record` for the query.
// NOTE(review): the distance metric and whether smaller means closer are not
// visible in this file — confirm against the store's search implementation.
pub distance: f32,
}
Loading
Loading