diff --git a/python/src/lib.rs b/python/src/lib.rs index 30d8e44..96ebb77 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -71,6 +71,7 @@ impl Context { #[pyo3(signature = (uri, *, storage_options=None))] fn create( _cls: &Bound<'_, PyType>, + py: Python<'_>, uri: &str, storage_options: Option<&Bound<'_, PyDict>>, ) -> PyResult { @@ -80,9 +81,10 @@ impl Context { storage_options: storage_options_from_dict(storage_options)?, }; - let store = runtime - .block_on(ContextStore::open_with_options(uri, options)) - .map_err(to_py_err)?; + let store_res = py.allow_threads(|| { + runtime.block_on(ContextStore::open_with_options(uri, options)) + }); + let store = store_res.map_err(to_py_err)?; let run_id = new_run_id(); Ok(Self { inner: RustContext::new(uri), @@ -111,6 +113,7 @@ impl Context { #[pyo3(signature = (role, content, data_type = None))] fn add( &mut self, + py: Python<'_>, role: &str, content: &Bound<'_, PyAny>, data_type: Option<&str>, @@ -147,9 +150,11 @@ impl Context { embedding: None, }; - self.runtime - .block_on(self.store.add(std::slice::from_ref(&record))) - .map_err(to_py_err)?; + let add_res = py.allow_threads(|| { + self.runtime + .block_on(self.store.add(std::slice::from_ref(&record))) + }); + add_res.map_err(to_py_err)?; self.inner.add(role, &inner_content, data_type); Ok(()) } @@ -168,10 +173,9 @@ impl Context { } } - fn checkout(&mut self, version_id: u64) -> PyResult<()> { - self.runtime - .block_on(self.store.checkout(version_id)) - .map_err(to_py_err)?; + fn checkout(&mut self, py: Python<'_>, version_id: u64) -> PyResult<()> { + let res = py.allow_threads(|| self.runtime.block_on(self.store.checkout(version_id))); + res.map_err(to_py_err)?; self.run_id = new_run_id(); Ok(()) } @@ -183,10 +187,11 @@ impl Context { query: Vec, limit: Option, ) -> PyResult> { - let hits = self - .runtime - .block_on(self.store.search(&query, limit)) - .map_err(to_py_err)?; + let hits_res = py.allow_threads(|| { + self.runtime + .block_on(self.store.search(&query, limit)) + }); + let hits = hits_res.map_err(to_py_err)?; hits.into_iter() .map(|hit| search_hit_to_py(py, hit)) .collect()