diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 8f3a728..50b18ad 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -19,7 +19,7 @@ def _is_module(value: Any, prefix: str) -> bool: def _get_pyarrow(): try: - import pyarrow as pa # pyright: ignore[reportMissingImports] + import pyarrow as pa # pyright: ignore[reportMissingImports,reportMissingTypeStubs] except ImportError as exc: # pragma: no cover - optional dependency raise ImportError( "pyarrow is required to serialize pandas/polars dataframes" @@ -113,6 +113,9 @@ def branch(self) -> str: def entries(self) -> int: return self._inner.entries() + def version(self) -> int: + return self._inner.version() + def add( self, role: str, @@ -134,8 +137,8 @@ def fork(self, branch_name: str) -> Context: inner = self._inner.fork(branch_name) return self._from_inner(inner) - def checkout(self, snapshot_id: str) -> None: - self._inner.checkout(snapshot_id) + def checkout(self, version_id: int | str) -> None: + self._inner.checkout(int(version_id)) def __repr__(self) -> str: return ( diff --git a/python/python/tests/test_context.py b/python/python/tests/test_context.py index 97227bb..a286123 100644 --- a/python/python/tests/test_context.py +++ b/python/python/tests/test_context.py @@ -10,6 +10,9 @@ def test_context_create_and_add(): ctx.add("user", "hello") ctx.add("assistant", 123, data_type="text/plain") assert ctx.entries() == 2 + version = ctx.version() + assert isinstance(version, int) + ctx.checkout(version) def test_context_snapshot_and_fork(): @@ -25,3 +28,4 @@ def test_context_snapshot_and_fork(): fork = ctx.fork("branch-a") assert fork.branch() == "branch-a" assert fork.entries() == ctx.entries() + assert fork.version() == ctx.version() diff --git a/python/src/lib.rs b/python/src/lib.rs index 5738551..f0f89cf 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -54,6 +54,10 @@ impl Context { self.inner.entries() } + fn version(&self) -> u64 { + self.store.version() + } + #[pyo3(signature = (role, content, data_type = None))] fn add( &mut self, @@ -116,8 +120,12 @@ impl Context { } } - fn checkout(&mut self, snapshot_id: &str) { - self.inner.checkout(snapshot_id); + fn checkout(&mut self, version_id: u64) -> PyResult<()> { + self.runtime + .block_on(self.store.checkout(version_id)) + .map_err(to_py_err)?; + self.run_id = new_run_id(); + Ok(()) } } diff --git a/rust/lance-context/src/store.rs b/rust/lance-context/src/store.rs index 049cfd6..466e597 100644 --- a/rust/lance-context/src/store.rs +++ b/rust/lance-context/src/store.rs @@ -61,6 +61,18 @@ impl ContextStore { Ok(self.dataset.manifest.version) } + /// Current dataset version. + pub fn version(&self) -> u64 { + self.dataset.manifest.version + } + + /// Checkout a specific dataset version. + pub async fn checkout(&mut self, version_id: u64) -> LanceResult<()> { + let dataset = self.dataset.checkout_version(version_id).await?; + self.dataset = dataset; + Ok(()) + } + /// Lance schema for the context store. pub fn schema() -> Schema { Schema::new(vec![