diff --git a/tests/test_update.py b/tests/test_update.py index 106f6c1b4..8cdfb455f 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -165,6 +165,78 @@ async def test_update_auto_now(db): assert obj1.updated_at.date() == updated_at.date() +@pytest.mark.asyncio +async def test_update_auto_now_with_update_fields(db): + tournament = await Tournament.create(name="1") + event = await Event.create(name="original", tournament=tournament) + original_modified = event.modified + + # Set modified to the past so we can detect if it gets updated + past = timezone.now() - timedelta(days=1) + await Event.filter(pk=event.pk).update(modified=past) + + event = await Event.get(pk=event.pk) + assert event.modified.date() == past.date() + + # Update only name with update_fields; auto_now field should also be updated + event.name = "updated" + await event.save(update_fields=["name"]) + + event = await Event.get(pk=event.pk) + assert event.name == "updated" + assert event.modified.date() == timezone.now().date() + + +@pytest.mark.asyncio +async def test_bulk_update_auto_now(db): + tournament = await Tournament.create(name="1") + event1 = await Event.create(name="original1", tournament=tournament) + event2 = await Event.create(name="original2", tournament=tournament) + + # Set modified to the past so we can detect if it gets updated + past = timezone.now() - timedelta(days=1) + await Event.filter(pk__in=[event1.pk, event2.pk]).update(modified=past) + + event1 = await Event.get(pk=event1.pk) + event2 = await Event.get(pk=event2.pk) + assert event1.modified.date() == past.date() + assert event2.modified.date() == past.date() + + # bulk_update only name; auto_now field should also be updated + event1.name = "updated1" + event2.name = "updated2" + await Event.filter(pk__in=[event1.pk, event2.pk]).bulk_update( + [event1, event2], fields=["name"] + ) + + event1 = await Event.get(pk=event1.pk) + event2 = await Event.get(pk=event2.pk) + assert event1.name == "updated1" + assert event2.name == "updated2" + assert event1.modified.date() == timezone.now().date() + assert event2.modified.date() == timezone.now().date() + + +@pytest.mark.asyncio +async def test_queryset_update_auto_now(db): + tournament = await Tournament.create(name="1") + event = await Event.create(name="original", tournament=tournament) + + # Set modified to the past (explicit modified= won't be overridden by auto_now) + past = timezone.now() - timedelta(days=1) + await Event.filter(pk=event.pk).update(modified=past) + + event = await Event.get(pk=event.pk) + assert event.modified.date() == past.date() + + # queryset.update() should auto-include auto_now fields + await Event.filter(pk=event.pk).update(name="updated") + + event = await Event.get(pk=event.pk) + assert event.name == "updated" + assert event.modified.date() == timezone.now().date() + + @pytest.mark.asyncio async def test_update_relation(db): tournament_first = await Tournament.create(name="1") diff --git a/tortoise/models.py b/tortoise/models.py index 71fa3c43d..a7e15f0ec 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -1149,6 +1149,12 @@ async def save( raise IncompleteInstanceError( f"{self.__class__.__name__} is a partial model, can only be saved with the relevant update_field provided" ) + if update_fields: + update_fields = list(update_fields) + for field_name, field_obj in self._meta.fields_map.items(): + if field_name not in update_fields and getattr(field_obj, "auto_now", False): + update_fields.append(field_name) + await self._pre_save(db, update_fields) if force_create: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 0cfd2fc52..0cc796c4f 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1285,6 +1285,12 @@ def __init__( orderings: list[tuple[str, str]], ) -> None: super().__init__(model) + # Inject auto_now fields into update_kwargs if not already specified + from tortoise import timezone + + for field_name, field_obj in model._meta.fields_map.items(): + if field_name not in update_kwargs and getattr(field_obj, "auto_now", False): + update_kwargs[field_name] = timezone.now() self.update_kwargs = update_kwargs self._q_objects = q_objects self._annotations = annotations @@ -1933,7 +1939,11 @@ def __init__( limit=limit, orderings=orderings, ) - self.fields = fields + fields_list = list(fields) + for field_name, field_obj in model._meta.fields_map.items(): + if field_name not in fields_list and getattr(field_obj, "auto_now", False): + fields_list.append(field_name) + self.fields = fields_list self._objects = objects self._batch_size = batch_size self._queries: list[QueryBuilder] = []