|
1 | 1 | use std::sync::Arc; |
2 | 2 |
|
3 | | -use chrono::Utc; |
| 3 | +use chrono::{SecondsFormat, Utc}; |
4 | 4 | use pyo3::exceptions::PyRuntimeError; |
5 | 5 | use pyo3::prelude::*; |
6 | | -use pyo3::types::PyType; |
| 6 | +use pyo3::IntoPy; |
| 7 | +use pyo3::types::{PyBytes, PyDict, PyType}; |
7 | 8 | use tokio::runtime::Runtime; |
8 | 9 |
|
9 | 10 | use lance_context::serde::CONTENT_TYPE_TEXT; |
10 | | -use lance_context::{Context as RustContext, ContextRecord, ContextStore}; |
| 11 | +use lance_context::{Context as RustContext, ContextRecord, ContextStore, SearchResult}; |
11 | 12 |
|
12 | 13 | const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; |
13 | 14 | const BINARY_PLACEHOLDER: &str = "[binary]"; |
@@ -127,12 +128,74 @@ impl Context { |
127 | 128 | self.run_id = new_run_id(); |
128 | 129 | Ok(()) |
129 | 130 | } |
| 131 | + |
| 132 | + #[pyo3(signature = (query, limit = None))] |
| 133 | + fn search( |
| 134 | + &self, |
| 135 | + py: Python<'_>, |
| 136 | + query: Vec<f32>, |
| 137 | + limit: Option<usize>, |
| 138 | + ) -> PyResult<Vec<PyObject>> { |
| 139 | + let hits = self |
| 140 | + .runtime |
| 141 | + .block_on(self.store.search(&query, limit)) |
| 142 | + .map_err(to_py_err)?; |
| 143 | + hits.into_iter() |
| 144 | + .map(|hit| search_hit_to_py(py, hit)) |
| 145 | + .collect() |
| 146 | + } |
130 | 147 | } |
131 | 148 |
|
132 | 149 | fn new_run_id() -> String { |
133 | 150 | format!("run-{}-{}", Utc::now().timestamp_micros(), std::process::id()) |
134 | 151 | } |
135 | 152 |
|
| 153 | +fn search_hit_to_py(py: Python<'_>, hit: SearchResult) -> PyResult<PyObject> { |
| 154 | + let SearchResult { record, distance } = hit; |
| 155 | + let ContextRecord { |
| 156 | + id, |
| 157 | + run_id, |
| 158 | + created_at, |
| 159 | + role, |
| 160 | + state_metadata, |
| 161 | + content_type, |
| 162 | + text_payload, |
| 163 | + binary_payload, |
| 164 | + embedding, |
| 165 | + } = record; |
| 166 | + |
| 167 | + let dict = PyDict::new(py); |
| 168 | + dict.set_item("id", id)?; |
| 169 | + dict.set_item("run_id", run_id)?; |
| 170 | + dict.set_item( |
| 171 | + "created_at", |
| 172 | + created_at.to_rfc3339_opts(SecondsFormat::Micros, true), |
| 173 | + )?; |
| 174 | + dict.set_item("role", role)?; |
| 175 | + |
| 176 | + let state_obj: PyObject = match state_metadata { |
| 177 | + Some(metadata) => { |
| 178 | + let state_dict = PyDict::new(py); |
| 179 | + state_dict.set_item("step", metadata.step)?; |
| 180 | + state_dict.set_item("active_plan_id", metadata.active_plan_id)?; |
| 181 | + state_dict.set_item("tokens_used", metadata.tokens_used)?; |
| 182 | + state_dict.set_item("custom", metadata.custom)?; |
| 183 | + state_dict.into_py(py) |
| 184 | + } |
| 185 | + None => py.None().into_py(py), |
| 186 | + }; |
| 187 | + dict.set_item("state_metadata", state_obj)?; |
| 188 | + dict.set_item("content_type", content_type)?; |
| 189 | + dict.set_item("text_payload", text_payload)?; |
| 190 | + match binary_payload { |
| 191 | + Some(payload) => dict.set_item("binary_payload", PyBytes::new(py, &payload))?, |
| 192 | + None => dict.set_item("binary_payload", py.None())?, |
| 193 | + } |
| 194 | + dict.set_item("embedding", embedding)?; |
| 195 | + dict.set_item("distance", distance)?; |
| 196 | + Ok(dict.into_py(py)) |
| 197 | +} |
| 198 | + |
136 | 199 | fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr { |
137 | 200 | PyRuntimeError::new_err(err.to_string()) |
138 | 201 | } |
|
0 commit comments