diff --git a/README.md b/README.md index b010b4e30..939b7c3a3 100644 --- a/README.md +++ b/README.md @@ -663,6 +663,7 @@ The following keyword arguments are supported on all field types. * `nullable: bool` * `default: Any` * `server_default: Any` +* `on_update: Any` * `index: bool` * `unique: bool` * `choices: typing.Sequence` diff --git a/docs/fields/common-parameters.md b/docs/fields/common-parameters.md index 77da0dde8..0ad32eeb4 100644 --- a/docs/fields/common-parameters.md +++ b/docs/fields/common-parameters.md @@ -124,6 +124,70 @@ Sample usage: !!!info `server_default` is passed straight to sqlalchemy table definition so you can read more in [server default][server default] sqlalchemy documentation +## on_update + +`on_update`: `Any` = `None` -> defaults to None. + +A value (or Callable) that is written to the column on every update when the +user did not provide that field explicitly. + +It applies to all write paths that update existing rows: + +* `Model.update()` (and therefore `Model.upsert()`) +* `QuerySet.update()` +* `QuerySet.bulk_update()` + +A field is considered explicitly provided (and its `on_update` is **skipped**) +when any of the following is true: + +* it was passed as a keyword argument to `update()` / `upsert()` +* it was mutated on the instance via attribute assignment before `update()` / + `bulk_update()` is called +* it was passed through `QuerySet.update(field=value)` + +You can pass a static value or a Callable; a Callable is invoked every time +the update fires (e.g. `datetime.now` produces the current timestamp on every +update). + +Used in sql only. + +Sample usage: + +```python +from datetime import datetime + +class Task(ormar.Model): + ormar_config = base_ormar_config.copy(tablename="tasks") + + id: int = ormar.Integer(primary_key=True) + # callable - invoked on every update + updated_at: datetime = ormar.DateTime( + default=datetime.now, on_update=datetime.now + ) + # static value + is_dirty: bool = ormar.Boolean(default=False, on_update=True) + # static value - field keeps the value the caller passed if any + revision: int = ormar.Integer(default=0, on_update=1) + + +task = await Task.objects.create() +assert task.is_dirty is False + +await task.update() +assert task.is_dirty is True # on_update applied +assert task.revision == 1 # on_update applied +assert task.updated_at > created_time # on_update callable re-evaluated + +# explicit values override on_update +await task.update(revision=5) +assert task.revision == 5 # user value wins, on_update skipped +``` + +!!!note + `on_update` is orthogonal to the `onupdate` parameter on `ForeignKey`, + which maps to the SQL `ON UPDATE` referential action and is unrelated. + + ## name `name`: `str` = `None` -> defaults to None diff --git a/docs/index.md b/docs/index.md index 65ad38457..97959527c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -668,6 +668,7 @@ The following keyword arguments are supported on all field types. * `nullable: bool` * `default: Any` * `server_default: Any` + * `on_update: Any` * `index: bool` * `unique: bool` * `name: str` diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 15bd01950..2f2d61248 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -90,6 +90,7 @@ def __init__(self, **kwargs: Any) -> None: self.ormar_default: Any = kwargs.pop("default", None) self.server_default: Any = kwargs.pop("server_default", None) + self.on_update: Any = kwargs.pop("on_update", None) self.comment: str = kwargs.pop("comment", None) @@ -241,6 +242,24 @@ def has_default(self, use_server: bool = True) -> bool: self.server_default is not None and use_server ) + def has_on_update(self) -> bool: + """ + Checks if the field has an on_update value or callable configured. + + :return: result of the check if on_update value is set + :rtype: bool + """ + return self.on_update is not None + + def get_on_update(self) -> Any: + """ + Resolves the on_update value, calling it if it is a callable. + + :return: resolved on_update value + :rtype: Any + """ + return self.on_update() if callable(self.on_update) else self.on_update + def is_auto_primary_key(self) -> bool: """ Checks if field is first a primary key and if it, diff --git a/ormar/models/helpers/models.py b/ormar/models/helpers/models.py index acd894a92..46886d0b5 100644 --- a/ormar/models/helpers/models.py +++ b/ormar/models/helpers/models.py @@ -54,6 +54,9 @@ def populate_default_options_values( # noqa: CCR001 new_model._bytes_fields = { name for name, field in model_fields.items() if field.__type__ is bytes } + new_model._onupdate_fields = { + name for name, field in model_fields.items() if field.has_on_update() + } new_model.__relation_map__ = None new_model.__ormar_fields_validators__ = None diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 2214c69a2..d04285e3a 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -75,6 +75,7 @@ def add_cached_properties(new_model: type["Model"]) -> None: new_model._related_fields = None new_model._json_fields = set() new_model._bytes_fields = set() + new_model._onupdate_fields = set() def add_property_fields(new_model: type["Model"], attrs: dict) -> None: # noqa: CCR001 diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 8d8650ef8..d1aa6c1eb 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -28,6 +28,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): _skip_ellipsis: Callable _json_fields: set[str] _bytes_fields: set[str] + _onupdate_fields: set[str] __pydantic_core_schema__: CoreSchema __ormar_fields_validators__: Optional[ dict[str, Union[SchemaValidator, PluggableSchemaValidator]] @@ -251,6 +252,34 @@ def populate_default_values(cls, new_kwargs: dict) -> dict: new_kwargs.pop(field_name, None) return new_kwargs + @classmethod + def populate_onupdate_value( + cls, new_kwargs: dict, explicit_fields: Optional[set[str]] = None + ) -> dict: + """ + Populates on_update values for fields that have on_update configured + and were not provided explicitly. + + If ``explicit_fields`` is omitted the keys of ``new_kwargs`` are + treated as explicitly provided. Callers pass the set explicitly when + ``new_kwargs`` is a full ``model_dump`` (e.g. ``bulk_update``) so that + fields the user actually mutated can be distinguished from values that + are just present by virtue of the dump. + + :param new_kwargs: dictionary of values about to be used for the update + :type new_kwargs: dict + :param explicit_fields: fields that should be considered user-provided + :type explicit_fields: Optional[set[str]] + :return: new_kwargs with on_update values populated + :rtype: dict + """ + if explicit_fields is None: + explicit_fields = set(new_kwargs.keys()) + for field_name in cls._onupdate_fields - explicit_fields: + field = cls.ormar_config.model_fields[field_name] + new_kwargs[field_name] = field.get_on_update() + return new_kwargs + @classmethod def validate_enums(cls, new_kwargs: dict) -> dict: """ diff --git a/ormar/models/model.py b/ormar/models/model.py index d62cefdc5..132a05655 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -119,6 +119,7 @@ async def save(self: T) -> T: ): await self.load() + self.__setattr_fields__.clear() await self.signals.post_save.send(sender=self.__class__, instance=self) return self @@ -243,8 +244,12 @@ async def update(self: T, _columns: Optional[list[str]] = None, **kwargs: Any) - :return: updated Model :rtype: Model """ - if kwargs: - self.update_from_dict(kwargs) + explicit_fields = self.__setattr_fields__ | kwargs.keys() + values = self.populate_onupdate_value( + dict(kwargs), explicit_fields=explicit_fields + ) + if values: + self.update_from_dict(values) if not self.pk: raise ModelPersistenceError( @@ -257,13 +262,18 @@ async def update(self: T, _columns: Optional[list[str]] = None, **kwargs: Any) - self_fields = self._extract_model_db_fields() self_fields.pop(self.get_column_name_from_alias(self.ormar_config.pkname)) if _columns: - self_fields = {k: v for k, v in self_fields.items() if k in _columns} + self_fields = { + k: v + for k, v in self_fields.items() + if k in _columns or k in self._onupdate_fields + } if self_fields: self_fields = self.translate_columns_to_aliases(self_fields) expr = self.ormar_config.table.update().values(**self_fields) expr = expr.where(self.pk_column == getattr(self, self.ormar_config.pkname)) await self._execute_query(expr) self.set_save_status(True) + self.__setattr_fields__.clear() await self.signals.post_update.send(sender=self.__class__, instance=self) return self @@ -309,6 +319,7 @@ async def load(self: T) -> T: kwargs = self.translate_aliases_to_columns(kwargs) self.update_from_dict(kwargs) self.set_save_status(True) + self.__setattr_fields__.clear() return self async def load_all( @@ -357,4 +368,5 @@ async def load_all( instance = await queryset.select_related(relations).get(pk=self.pk) self._orm.clear() self.update_from_dict(instance.model_dump()) + self.__setattr_fields__.clear() return self diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 07b87e0dc..0f94e714d 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -70,12 +70,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass "__cached_hash__", "__pydantic_extra__", "__pydantic_fields_set__", + "__setattr_fields__", ) if TYPE_CHECKING: # pragma no cover pk: Any __relation_map__: Optional[list[str]] __cached_hash__: Optional[int] + __setattr_fields__: set[str] _orm_relationship_manager: AliasManager _orm: RelationsManager _orm_id: int @@ -86,6 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass _quick_access_fields: set _json_fields: set _bytes_fields: set + _onupdate_fields: set ormar_config: OrmarConfig # noinspection PyMissingConstructor @@ -230,6 +233,9 @@ def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 # let pydantic handle errors for unknown fields super().__setattr__(name, value) + if self._onupdate_fields and name in self.ormar_config.model_fields: + self.__setattr_fields__.add(name) + # In this case, the hash could have changed, so update it if name == self.ormar_config.pkname or self.pk is None: object.__setattr__(self, "__cached_hash__", None) @@ -415,6 +421,7 @@ def _initialize_internal_attributes(self) -> None: # object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) object.__setattr__(self, "_pk_column", None) + object.__setattr__(self, "__setattr_fields__", set()) object.__setattr__( self, "_orm", diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index d302027c8..1bbe217c4 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -786,6 +786,7 @@ async def update(self, each: bool = False, **kwargs: Any) -> int: self.model.extract_related_names() ) updates = {k: v for k, v in kwargs.items() if k in self_fields} + updates = self.model.populate_onupdate_value(updates) updates = self.model.validate_enums(updates) updates = self.model.translate_columns_to_aliases(updates) @@ -1210,15 +1211,30 @@ async def bulk_update( # noqa: CCR001 if pk_name not in columns: columns.append(pk_name) + onupdate_fields = self.model._onupdate_fields + for field_name in onupdate_fields: + if field_name not in columns: + columns.append(field_name) + columns = [self.model.get_column_alias(k) for k in columns] for i, obj in enumerate(objects): + explicit_fields = obj.__setattr_fields__ new_kwargs = obj.model_dump() if new_kwargs.get(pk_name) is None: raise ModelPersistenceError( "You cannot update unsaved objects. " f"{self.model.__name__} has to have {pk_name} filled." ) + new_kwargs = obj.populate_onupdate_value( + new_kwargs, explicit_fields=explicit_fields + ) + obj.update_from_dict( + { + field_name: new_kwargs[field_name] + for field_name in onupdate_fields - explicit_fields + } + ) new_kwargs = obj.prepare_model_to_update(new_kwargs) ready_objects.append( {"new_" + k: v for k, v in new_kwargs.items() if k in columns} @@ -1249,6 +1265,7 @@ async def bulk_update( # noqa: CCR001 for obj in objects: obj.set_save_status(True) + obj.__setattr_fields__.clear() await cast( type["Model"], self.model_cls diff --git a/tests/test_model_methods/test_populate_onupdate_values.py b/tests/test_model_methods/test_populate_onupdate_values.py new file mode 100644 index 000000000..bb5c11339 --- /dev/null +++ b/tests/test_model_methods/test_populate_onupdate_values.py @@ -0,0 +1,256 @@ +import base64 +import enum +import uuid +from datetime import date, datetime, time +from decimal import Decimal +from typing import Optional + +import pytest +from sqlalchemy import func + +import ormar +from tests.lifespan import init_tests +from tests.settings import create_config + +base_ormar_config = create_config() + +PAST = datetime(2000, 1, 1, 0, 0, 0) + + +def _past() -> datetime: + return PAST + + +class Task(ormar.Model): + ormar_config = base_ormar_config.copy(tablename="tasks") + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String( + max_length=255, + on_update=lambda: "hello", + ) + points: int = ormar.Integer(default=0, minimum=0, on_update=1) + year: int = ormar.Integer(default=1, on_update=2) + updated_at: Optional[datetime] = ormar.DateTime( + default=_past, server_default=func.now(), on_update=datetime.now + ) + + +class Size(enum.Enum): + SMALL = "small" + LARGE = "large" + + +UPDATED_UUID = uuid.UUID("11111111-1111-1111-1111-111111111111") + + +class AllTypes(ormar.Model): + """Covers every ormar field type so on_update is exercised against each.""" + + ormar_config = base_ormar_config.copy(tablename="all_types") + + id: int = ormar.Integer(primary_key=True) + flag: bool = ormar.Boolean(default=False, on_update=True) + count: int = ormar.Integer(default=0, on_update=42) + big: int = ormar.BigInteger(default=0, on_update=9_000_000_000) + small: int = ormar.SmallInteger(default=0, on_update=7) + ratio: float = ormar.Float(default=0.0, on_update=3.14) + note: str = ormar.String(max_length=100, default="initial", on_update="updated_s") + body: str = ormar.Text(default="", on_update="updated_t") + amount: Decimal = ormar.Decimal( + max_digits=10, + decimal_places=2, + default=Decimal("0.00"), + on_update=Decimal("9.99"), + ) + when: datetime = ormar.DateTime(default=_past, on_update=datetime(2030, 1, 1)) + day: date = ormar.Date(default=date(2000, 1, 1), on_update=date(2030, 1, 1)) + clock: time = ormar.Time(default=time(0, 0, 0), on_update=time(12, 30, 45)) + meta: dict = ormar.JSON(default={"v": 0}, on_update={"v": 1}) + token: uuid.UUID = ormar.UUID( + default=uuid.UUID("00000000-0000-0000-0000-000000000000"), + on_update=UPDATED_UUID, + ) + blob: bytes = ormar.LargeBinary( + max_length=100, default=b"initial_blob", on_update=b"updated_blob" + ) + blob64: str = ormar.LargeBinary( + max_length=100, + represent_as_base64_str=True, + default=b"initial_b64", + on_update=b"updated_b64", + ) + size: Size = ormar.Enum(enum_class=Size, default=Size.SMALL, on_update=Size.LARGE) + secret: str = ormar.String( + max_length=100, + default="initial_secret", + on_update="updated_secret", + encrypt_secret="key_for_tests", + encrypt_backend=ormar.EncryptBackends.FERNET, + ) + + +create_test_database = init_tests(base_ormar_config) + + +@pytest.mark.asyncio +async def test_onupdate_use_setattr_to_update(): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + assert t1.updated_at == PAST + + t2 = await Task.objects.get(name="123") + t2.name = "explicit" + t2.year = 2024 + await t2.update() + assert t2.name == "explicit" + assert t2.points == 1 + assert t2.year == 2024 + assert t2.updated_at > PAST + + +@pytest.mark.asyncio +async def test_onupdate_use_update_func_kwargs(): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + assert t1.updated_at == PAST + + t2 = await Task.objects.get(name="123") + await t2.update(name="from_kwargs") + assert t2.name == "from_kwargs" + assert t2.points == 1 + assert t2.year == 2 + assert t2.updated_at > PAST + + +@pytest.mark.asyncio +async def test_onupdate_use_update_func_columns(): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + assert t1.updated_at == PAST + + t2 = await Task.objects.get(name="123") + await t2.update(_columns=["year"], year=2024) + assert t2.name == "hello" + assert t2.points == 1 + assert t2.year == 2024 + assert t2.updated_at > PAST + + +@pytest.mark.asyncio +async def test_onupdate_queryset_update(): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + assert t1.updated_at == PAST + + await Task.objects.filter(name="123").update(name="qs_update") + t2 = await Task.objects.get(name="qs_update") + assert t2.name == "qs_update" + assert t2.points == 1 + assert t2.year == 2 + assert t2.updated_at > PAST + + +@pytest.mark.asyncio +async def test_onupdate_upsert(): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + t1 = await Task.objects.create(name="upsert_initial") + assert t1.updated_at == PAST + t1.name = "upsert_modified" + await t1.upsert() + assert t1.name == "upsert_modified" + assert t1.points == 1 + assert t1.year == 2 + assert t1.updated_at > PAST + + +@pytest.mark.asyncio +async def test_onupdate_bulk_update(): + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + t1 = await Task.objects.create(name="123") + assert t1.name == "123" + assert t1.points == 0 + assert t1.year == 1 + assert t1.updated_at == PAST + + t2 = await Task.objects.get(name="123") + t2.name = "bulk_update" + await Task.objects.bulk_update([t2]) + t3 = await Task.objects.get(name="bulk_update") + assert t3.name == "bulk_update" + assert t3.points == 1 + assert t3.year == 2 + assert t3.updated_at > PAST + + t4 = await Task.objects.get(name="bulk_update") + t4.year = 2024 + await Task.objects.bulk_update([t4], columns=["year"]) + t5 = await Task.objects.get(year=2024) + assert t5.year == 2024 + assert t5.points == 1 + assert t5.name == "hello" + assert t5.updated_at > PAST + + +@pytest.mark.asyncio +async def test_onupdate_all_field_types(): + """Exercises on_update against every ormar field type.""" + async with base_ormar_config.database: + async with base_ormar_config.database.transaction(force_rollback=True): + created = await AllTypes.objects.create(id=1) + assert created.flag is False + assert created.count == 0 + assert created.big == 0 + assert created.small == 0 + assert created.ratio == 0.0 + assert created.note == "initial" + assert created.body == "" + assert created.amount == Decimal("0.00") + assert created.when == PAST + assert created.day == date(2000, 1, 1) + assert created.clock == time(0, 0, 0) + assert created.meta == {"v": 0} + assert created.token == uuid.UUID("00000000-0000-0000-0000-000000000000") + assert created.blob == b"initial_blob" + assert created.size is Size.SMALL + assert created.secret == "initial_secret" + + fetched = await AllTypes.objects.get(id=1) + await fetched.update() + + reloaded = await AllTypes.objects.get(id=1) + assert reloaded.flag is True + assert reloaded.count == 42 + assert reloaded.big == 9_000_000_000 + assert reloaded.small == 7 + assert reloaded.ratio == 3.14 + assert reloaded.note == "updated_s" + assert reloaded.body == "updated_t" + assert reloaded.amount == Decimal("9.99") + assert reloaded.when == datetime(2030, 1, 1) + assert reloaded.day == date(2030, 1, 1) + assert reloaded.clock == time(12, 30, 45) + assert reloaded.meta == {"v": 1} + assert reloaded.token == UPDATED_UUID + assert reloaded.blob == b"updated_blob" + assert reloaded.blob64 == base64.b64encode(b"updated_b64").decode() + assert reloaded.size is Size.LARGE + assert reloaded.secret == "updated_secret"