From d8f6fc25a98b556f11e60c2ce5247dcfab33f95b Mon Sep 17 00:00:00 2001 From: Matt Brown Date: Tue, 7 Oct 2025 11:06:54 +0100 Subject: [PATCH 1/4] Add `get_model_state` to get validated doc. This adds `get_model_state` which returns the entire doc state as the pydantic model passed in as `Model=` during instantiation. It's effectively the other side of `apply_update`. If the doc has no Model defined it will raise a RuntimeError. If the doc is invalid it will raise a pydantic ValidationError. Also, update `apply_update` to use `model_validate` instead of Model(**value). This is more of a "style" thing. See https://github.com/pydantic/pydantic/discussions/9676 Also add ruff to the test dependencies and format. --- pyproject.toml | 1 + python/pycrdt/_base.py | 1 - python/pycrdt/_doc.py | 40 ++++++++++++++++++++++++++++++++------- python/pycrdt/_xml.py | 4 +--- tests/test_doc.py | 2 +- tests/test_model.py | 9 ++++++++- tests/test_transaction.py | 1 - 7 files changed, 44 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 92146e2..445c918 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ test = [ "mypy", "coverage[toml] >=7", "exceptiongroup; python_version<'3.11'", + "ruff>=0.13.3", ] docs = [ "mkdocs", diff --git a/python/pycrdt/_base.py b/python/pycrdt/_base.py index 3af3bb4..1af3ff2 100644 --- a/python/pycrdt/_base.py +++ b/python/pycrdt/_base.py @@ -47,7 +47,6 @@ class BaseDoc: _txn_lock: threading.Lock _txn_async_lock: anyio.Lock _allow_multithreading: bool - _Model: Any _subscriptions: list[Subscription] _origins: dict[int, Any] _task_group: TaskGroup | None diff --git a/python/pycrdt/_doc.py b/python/pycrdt/_doc.py index a72fc94..527bb02 100644 --- a/python/pycrdt/_doc.py +++ b/python/pycrdt/_doc.py @@ -2,7 +2,19 @@ from functools import partial from inspect import iscoroutinefunction -from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, Type, TypeVar, Union, cast, overload +from typing import ( + 
Any, + Awaitable, + Callable, + Generic, + Iterable, + Literal, + Type, + TypeVar, + Union, + cast, + overload, +) from anyio import BrokenResourceError, create_memory_object_stream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -15,7 +27,9 @@ from ._transaction import NewTransaction, ReadTransaction, Transaction T = TypeVar("T", bound=BaseType) -TransactionOrSubdocsEvent = TypeVar("TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent) +TransactionOrSubdocsEvent = TypeVar( + "TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent +) class Doc(BaseDoc, Generic[T]): @@ -35,7 +49,7 @@ def __init__( client_id: int | None = None, skip_gc: bool | None = None, doc: _Doc | None = None, - Model=None, + Model: Any | None = None, allow_multithreading: bool = False, ) -> None: """ @@ -47,8 +61,9 @@ def __init__( allow_multithreading: Whether to allow the document to be used in different threads. """ super().__init__( - client_id=client_id, skip_gc=skip_gc, doc=doc, Model=Model, allow_multithreading=allow_multithreading + client_id=client_id, skip_gc=skip_gc, doc=doc, allow_multithreading=allow_multithreading ) + self._Model = Model for k, v in init.items(): self[k] = v if Model is not None: @@ -150,6 +165,14 @@ def get_state(self) -> bytes: assert txn._txn is not None return self._doc.get_state(txn._txn) + def get_model_state(self) -> Any: + if self._Model is None: + raise RuntimeError( + "no Model defined for doc. 
Instantiate Doc with Doc(Model=PydanticModel)" + ) + d = {k: self[k].to_py() for k in self._Model.model_fields} + return self._Model.model_validate(d) + def get_update(self, state: bytes | None = None) -> bytes: """ Args: @@ -174,7 +197,7 @@ def apply_update(self, update: bytes) -> None: twin_doc.apply_update(update) d = {k: twin_doc[k].to_py() for k in self._Model.model_fields} try: - self._Model(**d) + self._Model.model_validate(d) except Exception as e: self._twin_doc = Doc(dict(self)) raise e @@ -292,7 +315,8 @@ def _roots(self) -> dict[str, T]: def observe( self, - callback: Callable[[TransactionEvent], None] | Callable[[TransactionEvent], Awaitable[None]], + callback: Callable[[TransactionEvent], None] + | Callable[[TransactionEvent], Awaitable[None]], ) -> Subscription: """ Subscribes a callback to be called with the document change event. @@ -405,7 +429,9 @@ async def main(): observe = self.observe_subdocs if subdocs else self.observe if not self._send_streams[subdocs]: if async_transactions: - self._event_subscription[subdocs] = observe(partial(self._async_send_event, subdocs)) + self._event_subscription[subdocs] = observe( + partial(self._async_send_event, subdocs) + ) else: self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs)) send_stream, receive_stream = create_memory_object_stream[ diff --git a/python/pycrdt/_xml.py b/python/pycrdt/_xml.py index 4a87a9b..b9125a0 100644 --- a/python/pycrdt/_xml.py +++ b/python/pycrdt/_xml.py @@ -269,9 +269,7 @@ def insert_embed(self, index: int, value: Any, attrs: dict[str, Any] | None = No self._do_and_integrate("insert", value, txn._txn, index, _attrs) else: # primitive type - self.integrated.insert_embed( - txn._txn, index, value, _attrs - ) + self.integrated.insert_embed(txn._txn, index, value, _attrs) def format(self, start: int, stop: int, attrs: dict[str, Any]) -> None: """ diff --git a/tests/test_doc.py b/tests/test_doc.py index 6eb189d..e22eb84 100644 --- a/tests/test_doc.py +++ 
b/tests/test_doc.py @@ -234,7 +234,7 @@ def test_get_update_exception(): def test_apply_update_exception(): doc = Doc() with pytest.raises(ValueError) as excinfo: - doc.apply_update(b"\xFF\xFF\xFF\xFF") + doc.apply_update(b"\xff\xff\xff\xff") assert "Cannot decode update" in str(excinfo.value) diff --git a/tests/test_model.py b/tests/test_model.py index b317d68..e3c3459 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import UTC, datetime from typing import Tuple import pytest @@ -52,3 +52,10 @@ class Delivery(BaseModel): assert str(local_doc["timestamp"]) == "2020-02-02T03:04:05Z" assert list(local_doc["dimensions"]) == ["10", "30"] + + decoded = local_doc.get_model_state() + assert decoded.timestamp == datetime(2020, 2, 2, 3, 4, 5, tzinfo=UTC) + assert decoded.dimensions == ( + 10, + 30, + ) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 7f72511..157e88e 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,5 +1,4 @@ import gc -import platform import sys import time from functools import partial From 6443b94a82970d695683f8536c9fa3611affbbee Mon Sep 17 00:00:00 2001 From: Matt Brown Date: Tue, 7 Oct 2025 11:20:38 +0100 Subject: [PATCH 2/4] Use `timezone.utc` instead of `datetime.UTC` --- tests/test_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index e3c3459..acc8b42 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,4 @@ -from datetime import UTC, datetime +from datetime import datetime, timezone from typing import Tuple import pytest @@ -54,7 +54,7 @@ class Delivery(BaseModel): assert list(local_doc["dimensions"]) == ["10", "30"] decoded = local_doc.get_model_state() - assert decoded.timestamp == datetime(2020, 2, 2, 3, 4, 5, tzinfo=UTC) + assert decoded.timestamp == datetime(2020, 2, 2, 3, 4, 5, tzinfo=timezone.utc) assert decoded.dimensions == ( 10, 
30, From 663d1694781df18d3a68b4d3a6c8569a0145f112 Mon Sep 17 00:00:00 2001 From: Matt Brown Date: Tue, 7 Oct 2025 11:26:26 +0100 Subject: [PATCH 3/4] Add test where model not defined. --- tests/test_model.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index acc8b42..500e860 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -59,3 +59,16 @@ class Delivery(BaseModel): 10, 30, ) + + +def test_model_no_model_defined(): + local_doc = Doc( + { + "timestamp": Text(), + "dimensions": Array(), + }, + ) + with pytest.raises(RuntimeError) as exc_info: + local_doc.get_model_state() + + assert str(exc_info.value).startswith("no Model defined for doc") From 1aa8dfca55ea8c08156ae3f867715df9c576397c Mon Sep 17 00:00:00 2001 From: Matt Brown Date: Wed, 8 Oct 2025 14:20:53 +0100 Subject: [PATCH 4/4] Implement `to_py` for Doc. This iterates through the roots of the Doc, converting them into their native Python types (using the underlying type's to_py() fn). XmlFragment roots are returned as `XmlFragment` objects rather than converted. TBD whether this can also be used to replace `_roots`. --- python/pycrdt/_doc.py | 6 +- python/pycrdt/_pycrdt.pyi | 3 + src/doc.rs | 156 ++++++++++++++++++++++++++++++-------- 3 files changed, 131 insertions(+), 34 deletions(-) diff --git a/python/pycrdt/_doc.py b/python/pycrdt/_doc.py index 527bb02..3906365 100644 --- a/python/pycrdt/_doc.py +++ b/python/pycrdt/_doc.py @@ -170,8 +170,10 @@ def get_model_state(self) -> Any: raise RuntimeError( "no Model defined for doc. 
Instantiate Doc with Doc(Model=PydanticModel)" ) - d = {k: self[k].to_py() for k in self._Model.model_fields} - return self._Model.model_validate(d) + with self.transaction() as txn: + assert txn._txn is not None + all_roots = self._doc.to_py(txn._txn) + return self._Model.model_validate(all_roots) def get_update(self, state: bytes | None = None) -> bytes: """ diff --git a/python/pycrdt/_pycrdt.pyi b/python/pycrdt/_pycrdt.pyi index 0f5afc0..cf0d4d0 100644 --- a/python/pycrdt/_pycrdt.pyi +++ b/python/pycrdt/_pycrdt.pyi @@ -61,6 +61,9 @@ class Doc: def roots(self, txn: Transaction) -> dict[str, Text | Array | Map]: """Get top-level (root) shared types available in current document.""" + def to_py(self, txn: Transaction) -> dict[str, Any]: + """Get top-level (root) shared types as native Python objects.""" + def observe(self, callback: Callable[[TransactionEvent], None]) -> Subscription: """Subscribes a callback to be called with the shared document change event. Returns a subscription that can be used to unsubscribe.""" diff --git a/src/doc.rs b/src/doc.rs index fe75f04..4d23d85 100644 --- a/src/doc.rs +++ b/src/doc.rs @@ -1,20 +1,21 @@ -use pyo3::prelude::*; -use pyo3::IntoPyObjectExt; -use pyo3::exceptions::{PyRuntimeError, PyValueError}; -use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList}; -use yrs::{ - Doc as _Doc, Options, ReadTxn, StateVector, SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update, WriteTxn -}; -use yrs::updates::encoder::{Encode, Encoder}; -use yrs::updates::decoder::Decode; -use crate::text::Text; use crate::array::Array; use crate::map::Map; -use crate::transaction::Transaction; use crate::subscription::Subscription; +use crate::text::Text; +use crate::transaction::Transaction; use crate::type_conversions::ToPython; use crate::xml::XmlFragment; - +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList}; +use 
pyo3::IntoPyObjectExt; +use yrs::updates::decoder::Decode; +use yrs::updates::encoder::{Encode, Encoder}; +use yrs::{ + Array as YArray, Doc as _Doc, GetString, Map as YMap, Options, ReadTxn, StateVector, + SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update, + WriteTxn, +}; #[pyclass] #[derive(Clone)] @@ -41,7 +42,8 @@ impl Doc { let mut encoder = yrs::updates::encoder::EncoderV1::new(); { let txn = original.doc.transact(); - txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder).unwrap(); + txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder) + .unwrap(); } let update = yrs::Update::decode_v1(&encoder.to_vec()).unwrap(); { @@ -53,11 +55,19 @@ impl Doc { let txn_orig = original.doc.transact(); for (name, root) in txn_orig.root_refs() { match root { - yrs::Out::YText(_) => { let _ = new_doc.get_or_insert_text(name); }, - yrs::Out::YArray(_) => { let _ = new_doc.get_or_insert_array(name); }, - yrs::Out::YMap(_) => { let _ = new_doc.get_or_insert_map(name); }, - yrs::Out::YXmlFragment(_) => { let _ = new_doc.get_or_insert_xml_fragment(name); }, - _ => {}, // ignore unknown types + yrs::Out::YText(_) => { + let _ = new_doc.get_or_insert_text(name); + } + yrs::Out::YArray(_) => { + let _ = new_doc.get_or_insert_array(name); + } + yrs::Out::YMap(_) => { + let _ = new_doc.get_or_insert_map(name); + } + yrs::Out::YXmlFragment(_) => { + let _ = new_doc.get_or_insert_xml_fragment(name); + } + _ => {} // ignore unknown types } } drop(txn_orig); @@ -71,14 +81,16 @@ impl Doc { fn new(client_id: &Bound<'_, PyAny>, skip_gc: &Bound<'_, PyAny>) -> PyResult { let mut options = Options::default(); if !client_id.is_none() { - let _client_id: u64 = client_id.downcast::() + let _client_id: u64 = client_id + .downcast::() .map_err(|_| PyValueError::new_err("client_id must be an integer"))? 
.extract() .map_err(|_| PyValueError::new_err("client_id must be a valid u64"))?; options.client_id = _client_id; } if !skip_gc.is_none() { - let _skip_gc: bool = skip_gc.downcast::() + let _skip_gc: bool = skip_gc + .downcast::() .map_err(|_| PyValueError::new_err("skip_gc must be a boolean"))? .extract() .map_err(|_| PyValueError::new_err("skip_gc must be a valid bool"))?; @@ -90,7 +102,11 @@ impl Doc { #[staticmethod] #[pyo3(name = "from_snapshot")] - pub fn from_snapshot(py: Python<'_>, snapshot: PyRef<'_, crate::snapshot::Snapshot>, doc: PyRef<'_, Doc>) -> PyResult> { + pub fn from_snapshot( + py: Python<'_>, + snapshot: PyRef<'_, crate::snapshot::Snapshot>, + doc: PyRef<'_, Doc>, + ) -> PyResult> { let restored = Doc::_from_snapshot_impl(&doc, &snapshot); Py::new(py, restored) } @@ -103,7 +119,12 @@ impl Doc { self.doc.client_id() } - fn get_or_insert_text(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult> { + fn get_or_insert_text( + &mut self, + py: Python<'_>, + txn: &mut Transaction, + name: &str, + ) -> PyResult> { let mut _t = txn.transaction(); let t = _t.as_mut().unwrap().as_mut(); let text = t.get_or_insert_text(name); @@ -111,15 +132,25 @@ impl Doc { Ok(pytext) } - fn get_or_insert_array(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult> { + fn get_or_insert_array( + &mut self, + py: Python<'_>, + txn: &mut Transaction, + name: &str, + ) -> PyResult> { let mut _t = txn.transaction(); let t = _t.as_mut().unwrap().as_mut(); let shared = t.get_or_insert_array(name); - let pyshared: Py = Py::new(py, Array::from(shared))?; + let pyshared: Py = Py::new(py, Array::from(shared))?; Ok(pyshared) } - fn get_or_insert_map(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult> { + fn get_or_insert_map( + &mut self, + py: Python<'_>, + txn: &mut Transaction, + name: &str, + ) -> PyResult> { let mut _t = txn.transaction(); let t = _t.as_mut().unwrap().as_mut(); let shared = 
t.get_or_insert_map(name); @@ -141,7 +172,11 @@ impl Doc { Err(PyRuntimeError::new_err("Already in a transaction")) } - fn create_transaction_with_origin(&self, py: Python<'_>, origin: i128) -> PyResult> { + fn create_transaction_with_origin( + &self, + py: Python<'_>, + origin: i128, + ) -> PyResult> { if let Ok(txn) = self.doc.try_transact_mut_with(origin) { let t: Py = Py::new(py, Transaction::from(txn))?; return Ok(t); @@ -160,7 +195,9 @@ impl Doc { let mut _t = txn.transaction(); let t = _t.as_mut().unwrap().as_mut(); let state: &[u8] = state.extract()?; - let Ok(state_vector) = StateVector::decode_v1(&state) else { return Err(PyValueError::new_err("Cannot decode state")) }; + let Ok(state_vector) = StateVector::decode_v1(&state) else { + return Err(PyValueError::new_err("Cannot decode state")); + }; let update = t.encode_diff_v1(&state_vector); let bytes: Py = Python::attach(|py| PyBytes::new(py, &update).into()); Ok(bytes) @@ -186,8 +223,53 @@ impl Doc { result.into() } + fn to_py(&self, py: Python<'_>, txn: &mut Transaction) -> PyResult> { + let mut _t = txn.transaction(); + let t = _t.as_mut().unwrap().as_mut(); + let result = PyDict::new(py); + + let roots_info: Vec<_> = t + .root_refs() + .map(|(name, root)| (name.to_string(), root)) + .collect(); + + for (name, root) in roots_info { + match root { + yrs::Out::YText(_) => { + let text = t.get_or_insert_text(name.as_str()); + let value = text.get_string(t); + result.set_item(name, value)?; + } + yrs::Out::YArray(_) => { + let array = t.get_or_insert_array(name.as_str()); + let list = PyList::empty(py); + for item in array.iter(t) { + list.append(item.into_py(py))?; + } + result.set_item(name, list)?; + } + yrs::Out::YMap(_) => { + let map = t.get_or_insert_map(name.as_str()); + let dict = PyDict::new(py); + for (key, value) in map.iter(t) { + dict.set_item(key, value.into_py(py))?; + } + result.set_item(name, dict)?; + } + yrs::Out::YXmlFragment(_) => { + let xml = 
t.get_or_insert_xml_fragment(name.as_str()); + let xml_py = Py::new(py, XmlFragment::from(xml))?; + result.set_item(name, xml_py)?; + } + _ => {} // ignore other types + } + } + Ok(result.into()) + } + pub fn observe(&mut self, py: Python<'_>, f: Py) -> PyResult> { - let sub = self.doc + let sub = self + .doc .observe_transaction_cleanup(move |txn, event| { if !event.delete_set.is_empty() || event.before_state != event.after_state { Python::attach(|py| { @@ -204,7 +286,8 @@ impl Doc { } pub fn observe_subdocs(&mut self, py: Python<'_>, f: Py) -> PyResult> { - let sub = self.doc + let sub = self + .doc .observe_subdocs(move |_, event| { Python::attach(|py| { let event = SubdocsEvent::new(py, event); @@ -326,11 +409,20 @@ pub struct SubdocsEvent { impl SubdocsEvent { fn new<'py>(py: Python<'py>, event: &_SubdocsEvent) -> Self { - let added: Vec = event.added().map(|d| d.guid().clone().to_string()).collect(); + let added: Vec = event + .added() + .map(|d| d.guid().clone().to_string()) + .collect(); let added = PyList::new(py, added).unwrap().into_py_any(py).unwrap(); - let removed: Vec = event.removed().map(|d| d.guid().clone().to_string()).collect(); + let removed: Vec = event + .removed() + .map(|d| d.guid().clone().to_string()) + .collect(); let removed = PyList::new(py, removed).unwrap().into_py_any(py).unwrap(); - let loaded: Vec = event.loaded().map(|d| d.guid().clone().to_string()).collect(); + let loaded: Vec = event + .loaded() + .map(|d| d.guid().clone().to_string()) + .collect(); let loaded = PyList::new(py, loaded).unwrap().into_py_any(py).unwrap(); SubdocsEvent { added,