Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ The following keyword arguments are supported on all field types.
* `nullable: bool`
* `default: Any`
* `server_default: Any`
* `on_update: Any`
* `index: bool`
* `unique: bool`
* `choices: typing.Sequence`
Expand Down
64 changes: 64 additions & 0 deletions docs/fields/common-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,70 @@ Sample usage:
!!!info
`server_default` is passed straight to sqlalchemy table definition so you can read more in [server default][server default] sqlalchemy documentation

## on_update

`on_update`: `Any` = `None` -> defaults to None.

A value (or Callable) that is written to the column on every update when the
user did not provide that field explicitly.

It applies to all write paths that update existing rows:

* `Model.update()` (and therefore `Model.upsert()`)
* `QuerySet.update()`
* `QuerySet.bulk_update()`

A field is considered explicitly provided (and its `on_update` is **skipped**)
when any of the following is true:

* it was passed as a keyword argument to `update()` / `upsert()`
* it was mutated on the instance via attribute assignment before `update()` /
`bulk_update()` is called
* it was passed through `QuerySet.update(field=value)`

You can pass a static value or a Callable; a Callable is invoked every time
the update fires (e.g. `datetime.now` produces the current timestamp on every
update).

Used in sql only.

Sample usage:

```python
from datetime import datetime

class Task(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="tasks")

id: int = ormar.Integer(primary_key=True)
# callable - invoked on every update
updated_at: datetime = ormar.DateTime(
default=datetime.now, on_update=datetime.now
)
# static value
is_dirty: bool = ormar.Boolean(default=False, on_update=True)
# static value - field keeps the value the caller passed if any
revision: int = ormar.Integer(default=0, on_update=1)


task = await Task.objects.create()
assert task.is_dirty is False

await task.update()
assert task.is_dirty is True # on_update applied
assert task.revision == 1 # on_update applied
assert task.updated_at > created_time # on_update callable re-evaluated

# explicit values override on_update
await task.update(revision=5)
assert task.revision == 5 # user value wins, on_update skipped
```

!!!note
`on_update` is orthogonal to the `onupdate` parameter on `ForeignKey`,
which maps to the SQL `ON UPDATE` referential action and is unrelated.


## name

`name`: `str` = `None` -> defaults to None
Expand Down
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ The following keyword arguments are supported on all field types.
* `nullable: bool`
* `default: Any`
* `server_default: Any`
* `on_update: Any`
* `index: bool`
* `unique: bool`
* `name: str`
Expand Down
19 changes: 19 additions & 0 deletions ormar/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(self, **kwargs: Any) -> None:

self.ormar_default: Any = kwargs.pop("default", None)
self.server_default: Any = kwargs.pop("server_default", None)
self.on_update: Any = kwargs.pop("on_update", None)

self.comment: str = kwargs.pop("comment", None)

Expand Down Expand Up @@ -241,6 +242,24 @@ def has_default(self, use_server: bool = True) -> bool:
self.server_default is not None and use_server
)

def has_on_update(self) -> bool:
"""
Checks if the field has an on_update value or callable configured.

:return: result of the check if on_update value is set
:rtype: bool
"""
return self.on_update is not None

def get_on_update(self) -> Any:
"""
Resolves the on_update value, calling it if it is a callable.

:return: resolved on_update value
:rtype: Any
"""
return self.on_update() if callable(self.on_update) else self.on_update

def is_auto_primary_key(self) -> bool:
"""
Checks if field is first a primary key and if it,
Expand Down
3 changes: 3 additions & 0 deletions ormar/models/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def populate_default_options_values( # noqa: CCR001
new_model._bytes_fields = {
name for name, field in model_fields.items() if field.__type__ is bytes
}
new_model._onupdate_fields = {
name for name, field in model_fields.items() if field.has_on_update()
}

new_model.__relation_map__ = None
new_model.__ormar_fields_validators__ = None
Expand Down
1 change: 1 addition & 0 deletions ormar/models/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def add_cached_properties(new_model: type["Model"]) -> None:
new_model._related_fields = None
new_model._json_fields = set()
new_model._bytes_fields = set()
new_model._onupdate_fields = set()


def add_property_fields(new_model: type["Model"], attrs: dict) -> None: # noqa: CCR001
Expand Down
29 changes: 29 additions & 0 deletions ormar/models/mixins/save_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
_skip_ellipsis: Callable
_json_fields: set[str]
_bytes_fields: set[str]
_onupdate_fields: set[str]
__pydantic_core_schema__: CoreSchema
__ormar_fields_validators__: Optional[
dict[str, Union[SchemaValidator, PluggableSchemaValidator]]
Expand Down Expand Up @@ -251,6 +252,34 @@ def populate_default_values(cls, new_kwargs: dict) -> dict:
new_kwargs.pop(field_name, None)
return new_kwargs

@classmethod
def populate_onupdate_value(
cls, new_kwargs: dict, explicit_fields: Optional[set[str]] = None
) -> dict:
"""
Populates on_update values for fields that have on_update configured
and were not provided explicitly.

If ``explicit_fields`` is omitted the keys of ``new_kwargs`` are
treated as explicitly provided. Callers pass the set explicitly when
``new_kwargs`` is a full ``model_dump`` (e.g. ``bulk_update``) so that
fields the user actually mutated can be distinguished from values that
are just present by virtue of the dump.

:param new_kwargs: dictionary of values about to be used for the update
:type new_kwargs: dict
:param explicit_fields: fields that should be considered user-provided
:type explicit_fields: Optional[set[str]]
:return: new_kwargs with on_update values populated
:rtype: dict
"""
if explicit_fields is None:
explicit_fields = set(new_kwargs.keys())
for field_name in cls._onupdate_fields - explicit_fields:
field = cls.ormar_config.model_fields[field_name]
new_kwargs[field_name] = field.get_on_update()
return new_kwargs

@classmethod
def validate_enums(cls, new_kwargs: dict) -> dict:
"""
Expand Down
18 changes: 15 additions & 3 deletions ormar/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ async def save(self: T) -> T:
):
await self.load()

self.__setattr_fields__.clear()
await self.signals.post_save.send(sender=self.__class__, instance=self)
return self

Expand Down Expand Up @@ -243,8 +244,12 @@ async def update(self: T, _columns: Optional[list[str]] = None, **kwargs: Any) -
:return: updated Model
:rtype: Model
"""
if kwargs:
self.update_from_dict(kwargs)
explicit_fields = self.__setattr_fields__ | kwargs.keys()
values = self.populate_onupdate_value(
dict(kwargs), explicit_fields=explicit_fields
)
if values:
self.update_from_dict(values)

if not self.pk:
raise ModelPersistenceError(
Expand All @@ -257,13 +262,18 @@ async def update(self: T, _columns: Optional[list[str]] = None, **kwargs: Any) -
self_fields = self._extract_model_db_fields()
self_fields.pop(self.get_column_name_from_alias(self.ormar_config.pkname))
if _columns:
self_fields = {k: v for k, v in self_fields.items() if k in _columns}
self_fields = {
k: v
for k, v in self_fields.items()
if k in _columns or k in self._onupdate_fields
}
if self_fields:
self_fields = self.translate_columns_to_aliases(self_fields)
expr = self.ormar_config.table.update().values(**self_fields)
expr = expr.where(self.pk_column == getattr(self, self.ormar_config.pkname))
await self._execute_query(expr)
self.set_save_status(True)
self.__setattr_fields__.clear()
await self.signals.post_update.send(sender=self.__class__, instance=self)
return self

Expand Down Expand Up @@ -309,6 +319,7 @@ async def load(self: T) -> T:
kwargs = self.translate_aliases_to_columns(kwargs)
self.update_from_dict(kwargs)
self.set_save_status(True)
self.__setattr_fields__.clear()
return self

async def load_all(
Expand Down Expand Up @@ -357,4 +368,5 @@ async def load_all(
instance = await queryset.select_related(relations).get(pk=self.pk)
self._orm.clear()
self.update_from_dict(instance.model_dump())
self.__setattr_fields__.clear()
return self
7 changes: 7 additions & 0 deletions ormar/models/newbasemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
"__cached_hash__",
"__pydantic_extra__",
"__pydantic_fields_set__",
"__setattr_fields__",
)

if TYPE_CHECKING: # pragma no cover
pk: Any
__relation_map__: Optional[list[str]]
__cached_hash__: Optional[int]
__setattr_fields__: set[str]
_orm_relationship_manager: AliasManager
_orm: RelationsManager
_orm_id: int
Expand All @@ -86,6 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
_quick_access_fields: set
_json_fields: set
_bytes_fields: set
_onupdate_fields: set
ormar_config: OrmarConfig

# noinspection PyMissingConstructor
Expand Down Expand Up @@ -230,6 +233,9 @@ def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
# let pydantic handle errors for unknown fields
super().__setattr__(name, value)

if self._onupdate_fields and name in self.ormar_config.model_fields:
self.__setattr_fields__.add(name)

# In this case, the hash could have changed, so update it
if name == self.ormar_config.pkname or self.pk is None:
object.__setattr__(self, "__cached_hash__", None)
Expand Down Expand Up @@ -415,6 +421,7 @@ def _initialize_internal_attributes(self) -> None:
# object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False)
object.__setattr__(self, "_pk_column", None)
object.__setattr__(self, "__setattr_fields__", set())
object.__setattr__(
self,
"_orm",
Expand Down
17 changes: 17 additions & 0 deletions ormar/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ async def update(self, each: bool = False, **kwargs: Any) -> int:
self.model.extract_related_names()
)
updates = {k: v for k, v in kwargs.items() if k in self_fields}
updates = self.model.populate_onupdate_value(updates)
updates = self.model.validate_enums(updates)
updates = self.model.translate_columns_to_aliases(updates)

Expand Down Expand Up @@ -1210,15 +1211,30 @@ async def bulk_update( # noqa: CCR001
if pk_name not in columns:
columns.append(pk_name)

onupdate_fields = self.model._onupdate_fields
for field_name in onupdate_fields:
if field_name not in columns:
columns.append(field_name)

columns = [self.model.get_column_alias(k) for k in columns]

for i, obj in enumerate(objects):
explicit_fields = obj.__setattr_fields__
new_kwargs = obj.model_dump()
if new_kwargs.get(pk_name) is None:
raise ModelPersistenceError(
"You cannot update unsaved objects. "
f"{self.model.__name__} has to have {pk_name} filled."
)
new_kwargs = obj.populate_onupdate_value(
new_kwargs, explicit_fields=explicit_fields
)
obj.update_from_dict(
{
field_name: new_kwargs[field_name]
for field_name in onupdate_fields - explicit_fields
}
)
new_kwargs = obj.prepare_model_to_update(new_kwargs)
ready_objects.append(
{"new_" + k: v for k, v in new_kwargs.items() if k in columns}
Expand Down Expand Up @@ -1249,6 +1265,7 @@ async def bulk_update( # noqa: CCR001

for obj in objects:
obj.set_save_status(True)
obj.__setattr_fields__.clear()

await cast(
type["Model"], self.model_cls
Expand Down
Loading
Loading