Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 45 additions & 15 deletions crates/lance-context-core/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ impl ContextStore {
Ok(())
}

/// List all records in the dataset.
///
/// `limit` caps the number of returned records and `offset` skips that
/// many records first; when both are `None` the scanner is left
/// unconfigured and every record is returned.
pub async fn list(
    &self,
    limit: Option<usize>,
    offset: Option<usize>,
) -> LanceResult<Vec<ContextRecord>> {
    let mut scanner = self.dataset.scan();
    // Only configure paging when the caller asked for it.
    match (limit, offset) {
        (None, None) => {}
        (lim, off) => {
            scanner.limit(lim.map(|l| l as i64), off.map(|o| o as i64))?;
        }
    }

    let mut stream = scanner.try_into_stream().await?;
    let mut records = Vec::new();
    while let Some(batch) = stream.try_next().await? {
        records.extend(batch_to_records(&batch)?);
    }
    Ok(records)
}

/// Perform a nearest-neighbor search over stored embeddings.
pub async fn search(
&self,
Expand Down Expand Up @@ -348,15 +369,7 @@ impl ContextStore {
}

fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>> {
let id_array = column_as::<StringArray>(batch, "id")?;
let run_id_array = column_as::<StringArray>(batch, "run_id")?;
let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
let state_array = column_as::<StructArray>(batch, "state_metadata")?;
let content_type_array = column_as::<StringArray>(batch, "content_type")?;
let text_array = column_as::<LargeStringArray>(batch, "text_payload")?;
let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;
let records = batch_to_records(batch)?;

let distance_column = batch.column_by_name("_distance").ok_or_else(|| {
LanceError::from(ArrowError::InvalidArgumentError(
Expand All @@ -373,6 +386,28 @@ fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>
))
})?;

Ok(records
.into_iter()
.enumerate()
.map(|(i, record)| SearchResult {
record,
distance: distance_array.value(i),
})
.collect())
}

/// Convert a record batch to context records.
fn batch_to_records(batch: &RecordBatch) -> LanceResult<Vec<ContextRecord>> {
let id_array = column_as::<StringArray>(batch, "id")?;
let run_id_array = column_as::<StringArray>(batch, "run_id")?;
let created_at_array = column_as::<TimestampMicrosecondArray>(batch, "created_at")?;
let role_array = column_as::<DictionaryArray<Int8Type>>(batch, "role")?;
let state_array = column_as::<StructArray>(batch, "state_metadata")?;
let content_type_array = column_as::<StringArray>(batch, "content_type")?;
let text_array = column_as::<LargeStringArray>(batch, "text_payload")?;
let binary_array = column_as::<LargeBinaryArray>(batch, "binary_payload")?;
let embedding_array = column_as::<FixedSizeListArray>(batch, "embedding")?;

let step_array = state_array
.column(0)
.as_ref()
Expand Down Expand Up @@ -487,7 +522,7 @@ fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>
role_values.value(key).to_string()
};

let record = ContextRecord {
results.push(ContextRecord {
id: id_array.value(row).to_string(),
run_id: run_id_array.value(row).to_string(),
created_at,
Expand All @@ -497,11 +532,6 @@ fn batch_to_search_results(batch: &RecordBatch) -> LanceResult<Vec<SearchResult>
text_payload,
binary_payload,
embedding,
};

results.push(SearchResult {
record,
distance: distance_array.value(row),
});
}

Expand Down
27 changes: 25 additions & 2 deletions python/python/lance_context/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def _coerce_vector(query: Any) -> list[float]:
raise TypeError("search query must be a sequence of floats")


def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]:
def _normalize_record(raw: dict[str, Any]) -> dict[str, Any]:
"""Normalize a raw record dict from the Rust layer."""
created_at = raw.get("created_at")
if isinstance(created_at, str):
created_at = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
Expand All @@ -119,12 +120,18 @@ def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]:
"text": raw.get("text_payload"),
"binary": raw.get("binary_payload"),
"embedding": raw.get("embedding"),
"distance": raw.get("distance"),
"created_at": created_at,
"state_metadata": raw.get("state_metadata"),
}


def _normalize_search_hit(raw: dict[str, Any]) -> dict[str, Any]:
    """Normalize a search hit: the base record plus its ``distance``."""
    return {**_normalize_record(raw), "distance": raw.get("distance")}


class Context:
def __init__(
self,
Expand Down Expand Up @@ -222,6 +229,22 @@ def search(self, query: Any, limit: int | None = None) -> list[dict[str, Any]]:
results = self._inner.search(vector, limit)
return [_normalize_search_hit(item) for item in results]

def list(
    self, limit: int | None = None, offset: int | None = None
) -> list[dict[str, Any]]:
    """Return stored entries.

    Args:
        limit: Maximum number of entries to return; ``None`` returns all.
        offset: Number of entries to skip before returning results.

    Returns:
        List of entry dicts with keys: id, run_id, role, content_type,
        text, binary, embedding, created_at, state_metadata.
    """
    # Fetch raw dicts from the Rust layer, then normalize each one.
    return list(map(_normalize_record, self._inner.list(limit, offset)))

def __repr__(self) -> str:
return (
f"Context(uri={self._inner.uri()!r}, "
Expand Down
25 changes: 24 additions & 1 deletion python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,23 @@ impl Context {
.map(|hit| search_hit_to_py(py, hit))
.collect()
}

/// Python-facing `list`: fetch records from the store on the embedded
/// runtime and convert each one to a Python dict.
#[pyo3(signature = (limit = None, offset = None))]
fn list(
    &self,
    py: Python<'_>,
    limit: Option<usize>,
    offset: Option<usize>,
) -> PyResult<Vec<PyObject>> {
    let fetched = self
        .runtime
        .block_on(self.store.list(limit, offset))
        .map_err(to_py_err)?;
    let mut out = Vec::with_capacity(fetched.len());
    for record in fetched {
        out.push(record_to_py(py, record)?);
    }
    Ok(out)
}
}

fn new_run_id() -> String {
Expand All @@ -203,6 +220,13 @@ fn new_run_id() -> String {

fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult<PyObject> {
let SearchResult { record, distance } = hit;
let dict = record_to_py(py, record)?;
let dict_ref = dict.downcast_bound::<PyDict>(py)?;
dict_ref.set_item("distance", distance)?;
Ok(dict)
}

fn record_to_py(py: Python<'_>, record: ContextRecord) -> PyResult<PyObject> {
let ContextRecord {
id,
run_id,
Expand Down Expand Up @@ -243,7 +267,6 @@ fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult<PyObject> {
None => dict.set_item("binary_payload", py.None())?,
}
dict.set_item("embedding", embedding)?;
dict.set_item("distance", distance)?;
Ok(dict.into_pyobject(py)?.unbind().into())
}

Expand Down
83 changes: 79 additions & 4 deletions python/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from datetime import datetime

import pytest
from lance_context.api import Context, _coerce_vector, _normalize_search_hit
from lance_context.api import Context, _coerce_vector, _normalize_record, _normalize_search_hit


class DummyInner:
def __init__(self) -> None:
self.calls: list[tuple[list[float], int | None]] = []
self.search_calls: list[tuple[list[float], int | None]] = []
self.list_calls: list[tuple[int | None, int | None]] = []

def search(self, vector: list[float], limit: int | None):
self.calls.append((vector, limit))
self.search_calls.append((vector, limit))
return [
{
"id": "rec-1",
Expand All @@ -25,6 +26,33 @@ def search(self, vector: list[float], limit: int | None):
}
]

def list(self, limit: int | None, offset: int | None):
    """Record the call and return two canned record dicts."""
    self.list_calls.append((limit, offset))
    # Fields shared by both canned records.
    common = {
        "run_id": "run-1",
        "content_type": "text/plain",
        "binary_payload": None,
    }
    return [
        {
            **common,
            "id": "rec-1",
            "role": "user",
            "text_payload": "hello",
            "embedding": [0.1, 0.2],
            "created_at": "2024-01-01T12:00:00Z",
            "state_metadata": {"step": 1},
        },
        {
            **common,
            "id": "rec-2",
            "role": "assistant",
            "text_payload": "world",
            "embedding": None,
            "created_at": "2024-01-02T12:00:00Z",
            "state_metadata": None,
        },
    ]


def test_coerce_vector_from_list():
assert _coerce_vector([1, 2.5]) == [1.0, 2.5]
Expand Down Expand Up @@ -60,8 +88,55 @@ def test_context_search_formats_results():

hits = ctx.search([0.5, 0.4], limit=3)

assert dummy.calls == [([0.5, 0.4], 3)]
assert dummy.search_calls == [([0.5, 0.4], 3)]
assert hits[0]["id"] == "rec-1"
assert hits[0]["text"] == "hello"
assert hits[0]["binary"] is None
assert isinstance(hits[0]["created_at"], datetime)


def test_normalize_record_without_distance():
    """A plain record normalizes without gaining a distance key."""
    raw = {
        "id": "rec-1",
        "created_at": "2024-01-01T00:00:00Z",
        "content_type": "text/plain",
        "text_payload": "hello",
        "binary_payload": None,
        "embedding": None,
        "run_id": "run-1",
        "role": "user",
        "state_metadata": None,
    }
    normalized = _normalize_record(raw)
    assert "distance" not in normalized
    assert normalized["text"] == "hello"
    assert isinstance(normalized["created_at"], datetime)


def test_context_list_returns_entries():
    """Context.list forwards paging args and normalizes each entry."""
    ctx = Context.__new__(Context)
    inner = DummyInner()
    ctx._inner = inner  # type: ignore[attr-defined]

    entries = ctx.list(limit=10, offset=5)

    assert inner.list_calls == [(10, 5)]
    assert len(entries) == 2
    first, second = entries
    assert (first["id"], first["text"], first["role"]) == ("rec-1", "hello", "user")
    assert "distance" not in first
    assert (second["id"], second["text"]) == ("rec-2", "world")
    assert isinstance(first["created_at"], datetime)


def test_context_list_default_args():
    """Calling list() with no args forwards (None, None) to the inner layer."""
    ctx = Context.__new__(Context)
    inner = DummyInner()
    ctx._inner = inner  # type: ignore[attr-defined]

    ctx.list()

    assert inner.list_calls == [(None, None)]
Loading