Skip to content

Commit fa6205d

Browse files
authored
feat: add embedding parameter to Context.add() (#31)
Exposes the embedding field so users can store vectors for semantic search.
1 parent e12ce46 commit fa6205d

3 files changed

Lines changed: 55 additions & 3 deletions

File tree

python/python/lance_context/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,14 @@ def add(
236236
content: Any,
237237
content_type: str | None = None,
238238
data_type: str | None = None,
239+
embedding: list[float] | None = None,
239240
) -> None:
240241
if content_type is not None and data_type is not None:
241242
raise ValueError("Specify only one of content_type or data_type")
242243
if content_type is None:
243244
content_type = data_type
244245
payload, resolved_type = _normalize_content(content, content_type)
245-
self._inner.add(role, payload, resolved_type)
246+
self._inner.add(role, payload, resolved_type, embedding)
246247

247248
def snapshot(self, label: str | None = None) -> str:
248249
return self._inner.snapshot(label)

python/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,14 @@ impl Context {
153153
self.store.version()
154154
}
155155

156-
#[pyo3(signature = (role, content, data_type = None))]
156+
#[pyo3(signature = (role, content, data_type = None, embedding = None))]
157157
fn add(
158158
&mut self,
159159
py: Python<'_>,
160160
role: &str,
161161
content: &Bound<'_, PyAny>,
162162
data_type: Option<&str>,
163+
embedding: Option<Vec<f32>>,
163164
) -> PyResult<()> {
164165
let (content_type, text_payload, binary_payload, inner_content) =
165166
match content.extract::<&[u8]>() {
@@ -190,7 +191,7 @@ impl Context {
190191
content_type,
191192
text_payload,
192193
binary_payload,
193-
embedding: None,
194+
embedding,
194195
};
195196

196197
let add_res = py.allow_threads(|| {

python/tests/test_search.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime
2+
from typing import Any
23

34
import pytest
45
from lance_context.api import Context, _coerce_vector, _normalize_record, _normalize_search_hit
@@ -8,6 +9,10 @@ class DummyInner:
89
def __init__(self) -> None:
910
self.search_calls: list[tuple[list[float], int | None]] = []
1011
self.list_calls: list[tuple[int | None, int | None]] = []
12+
self.add_calls: list[tuple[str, Any, str | None, list[float] | None]] = []
13+
14+
def add(self, role: str, content: Any, data_type: str | None, embedding: list[float] | None):
15+
self.add_calls.append((role, content, data_type, embedding))
1116

1217
def search(self, vector: list[float], limit: int | None):
1318
self.search_calls.append((vector, limit))
@@ -140,3 +145,48 @@ def test_context_list_default_args():
140145
ctx.list()
141146

142147
assert dummy.list_calls == [(None, None)]
148+
149+
150+
def test_context_add_with_embedding():
151+
ctx = Context.__new__(Context)
152+
dummy = DummyInner()
153+
ctx._inner = dummy # type: ignore[attr-defined]
154+
155+
embedding = [0.1, 0.2, 0.3]
156+
ctx.add("user", "hello", embedding=embedding)
157+
158+
assert len(dummy.add_calls) == 1
159+
role, content, data_type, passed_embedding = dummy.add_calls[0]
160+
assert role == "user"
161+
assert content == "hello"
162+
assert data_type is None
163+
assert passed_embedding == [0.1, 0.2, 0.3]
164+
165+
166+
def test_context_add_without_embedding():
167+
ctx = Context.__new__(Context)
168+
dummy = DummyInner()
169+
ctx._inner = dummy # type: ignore[attr-defined]
170+
171+
ctx.add("assistant", "world")
172+
173+
assert len(dummy.add_calls) == 1
174+
role, content, data_type, passed_embedding = dummy.add_calls[0]
175+
assert role == "assistant"
176+
assert content == "world"
177+
assert passed_embedding is None
178+
179+
180+
def test_context_add_with_content_type_and_embedding():
181+
ctx = Context.__new__(Context)
182+
dummy = DummyInner()
183+
ctx._inner = dummy # type: ignore[attr-defined]
184+
185+
embedding = [0.5, 0.6]
186+
ctx.add("system", "prompt", content_type="text/markdown", embedding=embedding)
187+
188+
assert len(dummy.add_calls) == 1
189+
role, content, data_type, passed_embedding = dummy.add_calls[0]
190+
assert role == "system"
191+
assert data_type == "text/markdown"
192+
assert passed_embedding == [0.5, 0.6]

0 commit comments

Comments
 (0)