diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 49582e62612e..0c79e5c1337d 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -38,6 +38,7 @@ class BaseDatabaseFeatures: can_use_chunked_reads = True can_return_columns_from_insert = False can_return_rows_from_bulk_insert = False + can_return_rows_from_update = False has_bulk_insert = True uses_savepoints = True can_release_savepoints = False diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index c8b49609bdaa..bf79f7a6e327 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -243,6 +243,7 @@ def __init__(self, *args, **kwargs): "use_returning_into", True ) self.features.can_return_columns_from_insert = use_returning_into + self.features.can_return_rows_from_update = use_returning_into @property def is_pool(self): diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 7ca40a8000bd..e87f495e5cae 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -19,6 +19,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_select_for_update_of = True select_for_update_of_column = True can_return_columns_from_insert = True + can_return_rows_from_update = True supports_subqueries_in_group_by = False ignores_unnecessary_order_by_in_subqueries = False supports_tuple_comparison_against_subquery = False diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 419fad868670..5f63b6c713f5 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -11,6 +11,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): allows_group_by_selected_pks = True can_return_columns_from_insert = True can_return_rows_from_bulk_insert = True + can_return_rows_from_update = True has_real_datatype = True has_native_uuid_field = True has_native_duration_field = True diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index 8604adf40aab..143ee1e98bb5 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -171,3 +171,7 @@ def can_return_columns_from_insert(self): can_return_rows_from_bulk_insert = property( operator.attrgetter("can_return_columns_from_insert") ) + + can_return_rows_from_update = property( + operator.attrgetter("can_return_columns_from_insert") + ) diff --git a/django/db/models/base.py b/django/db/models/base.py index 3827b00346d4..518cfc44a252 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1094,12 +1094,33 @@ def _save_table( ] forced_update = update_fields or force_update pk_val = self._get_pk_val(meta) - updated = self._do_update( - base_qs, using, pk_val, values, update_fields, forced_update + returning_fields = [ + f + for f in meta.local_concrete_fields + if ( + f.generated + and f.referenced_fields.intersection(non_pks_non_generated) + ) + ] + for field, _model, value in values: + if (update_fields is None or field.name in update_fields) and hasattr( + value, "resolve_expression" + ): + returning_fields.append(field) + results = self._do_update( + base_qs, + using, + pk_val, + values, + update_fields, + forced_update, + returning_fields, ) - if force_update and not updated: + if updated := bool(results): + self._assign_returned_values(results[0], returning_fields) + elif force_update: raise self.NotUpdated("Forced update did not affect any rows.") - if update_fields and not updated: + elif update_fields: raise self.NotUpdated( "Save with update_fields did not affect any rows." ) @@ -1126,16 +1147,32 @@ def _save_table( for f in meta.local_concrete_fields if not f.generated and (pk_set or f is not meta.auto_field) ] - returning_fields = meta.db_returning_fields + returning_fields = list(meta.db_returning_fields) + for field in fields: + value = ( + getattr(self, field.attname) if raw else field.pre_save(self, False) + ) + if hasattr(value, "resolve_expression"): + returning_fields.append(field) + elif field.db_returning: + returning_fields.remove(field) results = self._do_insert( cls._base_manager, using, fields, returning_fields, raw ) if results: - for value, field in zip(results[0], returning_fields): - setattr(self, field.attname, value) + self._assign_returned_values(results[0], returning_fields) return updated - def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): + def _do_update( + self, + base_qs, + using, + pk_val, + values, + update_fields, + forced_update, + returning_fields, + ): """ Try to update the model. Return True if the model was updated (if an update query was done and a matching row was found in the DB). @@ -1147,22 +1184,23 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat # case we just say the update succeeded. Another case ending up # here is a model with just PK - in that case check that the PK # still exists. - return update_fields is not None or filtered.exists() + if update_fields is not None or filtered.exists(): + return [()] + return [] if self._meta.select_on_save and not forced_update: - return ( - filtered.exists() - and - # It may happen that the object is deleted from the DB right - # after this check, causing the subsequent UPDATE to return - # zero matching rows. The same result can occur in some rare - # cases when the database returns zero despite the UPDATE being - # executed successfully (a row is matched and updated). In - # order to distinguish these two cases, the object's existence - # in the database is again checked for if the UPDATE query - # returns 0. - (filtered._update(values) > 0 or filtered.exists()) - ) - return filtered._update(values) > 0 + # It may happen that the object is deleted from the DB right after + # this check, causing the subsequent UPDATE to return zero matching + # rows. The same result can occur in some rare cases when the + # database returns zero despite the UPDATE being executed + # successfully (a row is matched and updated). In order to + # distinguish these two cases, the object's existence in the + # database is again checked for if the UPDATE query returns 0. + if not filtered.exists(): + return [] + if results := filtered._update(values, returning_fields): + return results + return [()] if filtered.exists() else [] + return filtered._update(values, returning_fields) def _do_insert(self, manager, using, fields, returning_fields, raw): """ @@ -1177,6 +1215,15 @@ def _do_insert(self, manager, using, fields, returning_fields, raw): raw=raw, ) + def _assign_returned_values(self, returned_values, returning_fields): + returning_fields_iter = iter(returning_fields) + for value, field in zip(returned_values, returning_fields_iter): + setattr(self, field.attname, value) + # Defer all fields that were meant to be updated with their database + # resolved values but couldn't as they are effectively stale. + for field in returning_fields_iter: + self.__dict__.pop(field.attname, None) + def _prepare_related_fields_for_save(self, operation_name, fields=None): # Ensure that a model instance without a PK hasn't been assigned to # a ForeignKey, GenericForeignKey or OneToOneField on this model. If diff --git a/django/db/models/fields/generated.py b/django/db/models/fields/generated.py index f6e3445b7142..f89269b5e6e4 100644 --- a/django/db/models/fields/generated.py +++ b/django/db/models/fields/generated.py @@ -66,6 +66,16 @@ def generated_sql(self, connection): sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END" return sql, params + @cached_property + def referenced_fields(self): + resolved_expression = self.expression.resolve_expression( + self._query, allow_joins=False + ) + referenced_fields = [] + for col in self._query._gen_cols([resolved_expression]): + referenced_fields.append(col.target) + return frozenset(referenced_fields) + def check(self, **kwargs): databases = kwargs.get("databases") or [] errors = [ diff --git a/django/db/models/query.py b/django/db/models/query.py index d2f31d15a077..2359ee3bb4a3 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1306,7 +1306,7 @@ async def aupdate(self, **kwargs): aupdate.alters_data = True - def _update(self, values): + def _update(self, values, returning_fields=None): """ A version of update() that accepts field objects instead of field names. Used primarily for model saving and not intended for use by @@ -1320,7 +1320,9 @@ def _update(self, values): # Clear any annotations so that they won't be present in subqueries. query.annotations = {} self._result_cache = None - return query.get_compiler(self.db).execute_sql(ROW_COUNT) + if returning_fields is None: + return query.get_compiler(self.db).execute_sql(ROW_COUNT) + return query.get_compiler(self.db).execute_returning_sql(returning_fields) _update.alters_data = True _update.queryset_only = False diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 73dfa5b87c84..0e483dc4f649 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -2020,6 +2020,9 @@ def as_sql(self): class SQLUpdateCompiler(SQLCompiler): + returning_fields = None + returning_params = () + def as_sql(self): """ Create the SQL for this query. Return the SQL string and list of @@ -2087,6 +2090,15 @@ def as_sql(self): params = [] else: result.append("WHERE %s" % where) + if self.returning_fields: + # Skip empty r_sql to allow subclasses to customize behavior for + # 3rd party backends. Refs #19096. + r_sql, self.returning_params = self.connection.ops.returning_columns( + self.returning_fields + ) + if r_sql: + result.append(r_sql) + params.extend(self.returning_params) return " ".join(result), tuple(update_params + params) def execute_sql(self, result_type): @@ -2110,6 +2122,38 @@ def execute_sql(self, result_type): is_empty = False return row_count + def execute_returning_sql(self, returning_fields): + """ + Execute the specified update and return rows of the returned columns + associated with the specified returning_field if the backend supports + it. + """ + if self.query.get_related_updates(): + raise NotImplementedError( + "Update returning is not implemented for queries with related updates." + ) + + if ( + not returning_fields + or not self.connection.features.can_return_rows_from_update + ): + row_count = self.execute_sql(ROW_COUNT) + return [()] * row_count + + self.returning_fields = returning_fields + with self.connection.cursor() as cursor: + sql, params = self.as_sql() + cursor.execute(sql, params) + rows = self.connection.ops.fetch_returned_rows( + cursor, self.returning_params + ) + opts = self.query.get_meta() + cols = [field.get_col(opts.db_table) for field in self.returning_fields] + converters = self.get_converters(cols) + if converters: + rows = self.apply_converters(rows, converters) + return list(rows) + def pre_sql_setup(self): """ If the update depends on results from other tables, munge the "where" diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 77e8b165da7f..a1b8984a9b14 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -69,8 +69,6 @@ Some examples # Create a new company using expressions. >>> company = Company.objects.create(name="Google", ticker=Upper(Value("goog"))) - # Be sure to refresh it if you need to access the field. - >>> company.refresh_from_db() >>> company.ticker 'GOOG' @@ -157,12 +155,6 @@ know about it - it is dealt with entirely by the database. All Python does, through Django's ``F()`` class, is create the SQL syntax to refer to the field and describe the operation. -To access the new value saved this way, the object must be reloaded:: - - reporter = Reporters.objects.get(pk=reporter.pk) - # Or, more succinctly: - reporter.refresh_from_db() - As well as being used in operations on single instances as above, ``F()`` can be used with ``update()`` to perform bulk updates on a ``QuerySet``. This reduces the two queries we were using above - the ``get()`` and the @@ -199,7 +191,6 @@ array-slicing syntax. The indices are 0-based and the ``step`` argument to >>> writer = Writers.objects.get(name="Priyansh") >>> writer.name = F("name")[1:5] >>> writer.save() - >>> writer.refresh_from_db() >>> writer.name 'riya' @@ -221,23 +212,27 @@ robust: it will only ever update the field based on the value of the field in the database when the :meth:`~Model.save` or ``update()`` is executed, rather than based on its value when the instance was retrieved. -``F()`` assignments persist after ``Model.save()`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``F()`` assignments are refreshed after ``Model.save()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -``F()`` objects assigned to model fields persist after saving the model -instance and will be applied on each :meth:`~Model.save`. For example:: +``F()`` objects assigned to model fields are refreshed from the database on +:meth:`~Model.save` on backends that support it without incurring a subsequent +query (SQLite, PostgreSQL, and Oracle) and deferred otherwise (MySQL or +MariaDB). For example: - reporter = Reporters.objects.get(name="Tintin") - reporter.stories_filed = F("stories_filed") + 1 - reporter.save() +.. code-block:: pycon - reporter.name = "Tintin Jr." - reporter.save() + >>> reporter = Reporters.objects.get(name="Tintin") + >>> reporter.stories_filed = F("stories_filed") + 1 + >>> reporter.save() + >>> reporter.stories_filed # This triggers a refresh query on MySQL/MariaDB. + 14 # Assuming the database value was 13 when the object was saved. + +.. versionchanged:: 6.0 -``stories_filed`` will be updated twice in this case. If it's initially ``1``, -the final value will be ``3``. This persistence can be avoided by reloading the -model object after saving it, for example, by using -:meth:`~Model.refresh_from_db`. + In previous versions of Django, ``F()`` objects were not refreshed from the + database on :meth:`~Model.save` which resulted in them being evaluated and + persisted every time the instance was saved. Using ``F()`` in filters ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index dd0862926aa0..f105096c8c0a 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1315,12 +1315,6 @@ materialized view. PostgreSQL only supports persisted columns. Oracle only supports virtual columns. -.. admonition:: Refresh the data - - Since the database computes the value, the object must be reloaded to - access the new value after :meth:`~Model.save`, for example, by using - :meth:`~Model.refresh_from_db`. - .. admonition:: Database limitations There are many database-specific restrictions on generated fields that @@ -1338,6 +1332,12 @@ materialized view. .. _PostgreSQL: https://www.postgresql.org/docs/current/ddl-generated-columns.html .. _SQLite: https://www.sqlite.org/gencol.html#limitations +.. versionchanged:: 6.0 + + ``GeneratedField``\s are now automatically refreshed from the database on + backends that support it (SQLite, PostgreSQL, and Oracle) and marked as + deferred otherwise. + ``GenericIPAddressField`` ------------------------- diff --git a/docs/releases/6.0.txt b/docs/releases/6.0.txt index e11a16162a1e..adfac83b8da2 100644 --- a/docs/releases/6.0.txt +++ b/docs/releases/6.0.txt @@ -331,6 +331,13 @@ Models value from the non-null input values. This is supported on SQLite, MySQL, Oracle, and PostgreSQL 16+. +* :class:`~django.db.models.GeneratedField`\s and :ref:`fields assigned + expressions ` are now refreshed from the + database after :meth:`~django.db.models.Model.save` on backends that support + the ``RETURNING`` clause (SQLite, PostgreSQL, and Oracle). On backends that + don't support it (MySQL and MariaDB), the fields are marked as deferred to + trigger a refresh on subsequent accesses. + Pagination ~~~~~~~~~~ @@ -420,6 +427,9 @@ backends. ``returning_params`` to be provided just like ``fetch_returned_insert_columns()`` did. +* If the database supports ``UPDATE … RETURNING`` statements, backends can set + ``DatabaseFeatures.can_return_rows_from_update=True``. + Dropped support for MariaDB 10.5 -------------------------------- diff --git a/tests/basic/tests.py b/tests/basic/tests.py index f8ec2715f6c3..38e7278210b9 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -1,5 +1,6 @@ import inspect import threading +import time from datetime import datetime, timedelta from unittest import mock @@ -12,6 +13,7 @@ models, transaction, ) +from django.db.models.functions import Now from django.db.models.manager import BaseManager from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet from django.test import ( @@ -558,6 +560,26 @@ def new_instance(): with self.subTest(case=case): self.assertIs(case._is_pk_set(), True) + def test_save_expressions(self): + article = Article(pub_date=Now()) + article.save() + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + article_pub_date = article.pub_date + self.assertIsInstance(article_pub_date, datetime) + # Sleep slightly to ensure a different database level NOW(). + time.sleep(0.1) + article.pub_date = Now() + article.save() + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertIsInstance(article.pub_date, datetime) + self.assertGreater(article.pub_date, article_pub_date) + class ModelLookupTest(TestCase): @classmethod diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 27d88be6213c..6f18321aa7b3 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -420,8 +420,11 @@ def test_object_update(self): # F expressions can be used to update attributes on single objects self.gmbh.num_employees = F("num_employees") + 4 self.gmbh.save() - self.gmbh.refresh_from_db() - self.assertEqual(self.gmbh.num_employees, 36) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(self.gmbh.num_employees, 36) def test_new_object_save(self): # We should be able to use Funcs when inserting new data @@ -1644,8 +1647,11 @@ def test_decimal_expression(self): n = Number.objects.create(integer=1, decimal_value=Decimal("0.5")) n.decimal_value = F("decimal_value") - Decimal("0.4") n.save() - n.refresh_from_db() - self.assertEqual(n.decimal_value, Decimal("0.1")) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(n.decimal_value, Decimal("0.1")) class ExpressionOperatorTests(TestCase): diff --git a/tests/field_defaults/tests.py b/tests/field_defaults/tests.py index e914adfc5183..7f85d946f6bb 100644 --- a/tests/field_defaults/tests.py +++ b/tests/field_defaults/tests.py @@ -15,13 +15,7 @@ ) from django.db.models.functions import Collate from django.db.models.lookups import GreaterThan -from django.test import ( - SimpleTestCase, - TestCase, - override_settings, - skipIfDBFeature, - skipUnlessDBFeature, -) +from django.test import SimpleTestCase, TestCase, override_settings, skipUnlessDBFeature from django.utils import timezone from .models import ( @@ -44,47 +38,56 @@ def test_field_defaults(self): self.assertEqual(a.headline, "Default headline") self.assertLess((now - a.pub_date).seconds, 5) - @skipUnlessDBFeature( - "can_return_columns_from_insert", "supports_expression_defaults" - ) + @skipUnlessDBFeature("supports_expression_defaults") def test_field_db_defaults_returning(self): a = DBArticle() a.save() self.assertIsInstance(a.id, int) - self.assertEqual(a.headline, "Default headline") - self.assertIsInstance(a.pub_date, datetime) - self.assertEqual(a.cost, Decimal("3.33")) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 3 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + self.assertEqual(a.cost, Decimal("3.33")) - @skipIfDBFeature("can_return_columns_from_insert") @skipUnlessDBFeature("supports_expression_defaults") def test_field_db_defaults_refresh(self): a = DBArticle() a.save() - a.refresh_from_db() + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 3 + ) self.assertIsInstance(a.id, int) - self.assertEqual(a.headline, "Default headline") - self.assertIsInstance(a.pub_date, datetime) - self.assertEqual(a.cost, Decimal("3.33")) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(a.headline, "Default headline") + self.assertIsInstance(a.pub_date, datetime) + self.assertEqual(a.cost, Decimal("3.33")) def test_null_db_default(self): obj1 = DBDefaults.objects.create() - if not connection.features.can_return_columns_from_insert: - obj1.refresh_from_db() - self.assertEqual(obj1.null, 1.1) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj1.null, 1.1) obj2 = DBDefaults.objects.create(null=None) - self.assertIsNone(obj2.null) + with self.assertNumQueries(0): + self.assertIsNone(obj2.null) @skipUnlessDBFeature("supports_expression_defaults") @override_settings(USE_TZ=True) def test_db_default_function(self): m = DBDefaultsFunction.objects.create() - if not connection.features.can_return_columns_from_insert: - m.refresh_from_db() - self.assertAlmostEqual(m.number, pi) - self.assertEqual(m.year, timezone.now().year) - self.assertAlmostEqual(m.added, pi + 4.5) - self.assertEqual(m.multiple_subfunctions, 4.5) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 4 + ) + with self.assertNumQueries(expected_num_queries): + self.assertAlmostEqual(m.number, pi) + self.assertEqual(m.year, timezone.now().year) + self.assertAlmostEqual(m.added, pi + 4.5) + self.assertEqual(m.multiple_subfunctions, 4.5) @skipUnlessDBFeature("insert_test_table_with_defaults") def test_both_default(self): @@ -125,14 +128,15 @@ def test_foreign_key_db_default(self): child2 = DBDefaultsFK.objects.create(language_code=parent2) self.assertEqual(child2.language_code, parent2) - @skipUnlessDBFeature( - "can_return_columns_from_insert", "supports_expression_defaults" - ) + @skipUnlessDBFeature("supports_expression_defaults") def test_case_when_db_default_returning(self): m = DBDefaultsFunction.objects.create() - self.assertEqual(m.case_when, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.case_when, 3) - @skipIfDBFeature("can_return_columns_from_insert") @skipUnlessDBFeature("supports_expression_defaults") def test_case_when_db_default_no_returning(self): m = DBDefaultsFunction.objects.create() diff --git a/tests/model_fields/test_generatedfield.py b/tests/model_fields/test_generatedfield.py index b6a933451dd5..f0ac6eecb5fd 100644 --- a/tests/model_fields/test_generatedfield.py +++ b/tests/model_fields/test_generatedfield.py @@ -173,11 +173,6 @@ class Sum(Model): class GeneratedFieldTestMixin: - def _refresh_if_needed(self, m): - if not connection.features.can_return_columns_from_insert: - m.refresh_from_db() - return m - def test_unsaved_error(self): m = self.base_model(a=1, b=2) msg = "Cannot retrieve deferred field 'field' from an unsaved model." @@ -189,8 +184,11 @@ def test_full_clean(self): # full_clean() ignores GeneratedFields. m.full_clean() m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) @skipUnlessDBFeature("supports_table_check_constraints") def test_full_clean_with_check_constraint(self): @@ -199,8 +197,11 @@ def test_full_clean_with_check_constraint(self): m = self.check_constraint_model(a=2) m.full_clean() m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.a_squared, 4) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.a_squared, 4) m = self.check_constraint_model(a=-1) with self.assertRaises(ValidationError) as cm: @@ -217,8 +218,11 @@ def test_full_clean_with_unique_constraint_expression(self): m = self.unique_constraint_model(a=2) m.full_clean() m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.a_squared, 4) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.a_squared, 4) m = self.unique_constraint_model(a=2) with self.assertRaises(ValidationError) as cm: @@ -230,8 +234,11 @@ def test_full_clean_with_unique_constraint_expression(self): def test_create(self): m = self.base_model.objects.create(a=1, b=2) - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) def test_non_nullable_create(self): with self.assertRaises(IntegrityError): @@ -241,26 +248,52 @@ def test_save(self): # Insert. m = self.base_model(a=2, b=4) m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 6) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 6) # Update. m.a = 4 m.save() - m.refresh_from_db() - self.assertEqual(m.field, 8) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 8) + # Update non-dependent field. + self.base_model.objects.filter(pk=m.pk).update(a=6) + m.save(update_fields=["fk"]) + with self.assertNumQueries(0): + self.assertEqual(m.field, 8) + # Update dependent field without persisting local changes. + m.save(update_fields=["b"]) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 10) + # Update dependent field while persisting local changes. + m.a = 8 + m.save(update_fields=["a"]) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 12) def test_save_model_with_pk(self): m = self.base_model(pk=1, a=1, b=2) m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) def test_save_model_with_foreign_key(self): fk_object = Foo.objects.create(a="abc", d=Decimal("12.34")) m = self.base_model(a=1, b=2, fk=fk_object) m.save() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, 3) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, 3) def test_generated_fields_can_be_deferred(self): fk_object = Foo.objects.create(a="abc", d=Decimal("12.34")) @@ -330,17 +363,23 @@ def test_db_type_parameters(self): def test_model_with_params(self): m = self.params_model.objects.create() - m = self._refresh_if_needed(m) - self.assertEqual(m.field, "Constant") + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m.field, "Constant") def test_nullable(self): m1 = self.nullable_model.objects.create() - m1 = self._refresh_if_needed(m1) none_val = "" if connection.features.interprets_empty_strings_as_nulls else None - self.assertEqual(m1.lower_name, none_val) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m1.lower_name, none_val) m2 = self.nullable_model.objects.create(name="NaMe") - m2 = self._refresh_if_needed(m2) - self.assertEqual(m2.lower_name, "name") + with self.assertNumQueries(expected_num_queries): + self.assertEqual(m2.lower_name, "name") @skipUnlessDBFeature("supports_stored_generated_columns") @@ -354,8 +393,21 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase): def test_create_field_with_db_converters(self): obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4()) - obj = self._refresh_if_needed(obj) - self.assertEqual(obj.field, obj.field_copy) + expected_num_queries = ( + 0 if connection.features.can_return_columns_from_insert else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj.field, obj.field_copy) + + def test_save_field_with_db_converters(self): + obj = GeneratedModelFieldWithConverters.objects.create(field=uuid.uuid4()) + obj.field = uuid.uuid4() + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + obj.save(update_fields={"field"}) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj.field, obj.field_copy) def test_create_with_non_auto_pk(self): obj = GeneratedModelNonAutoPk.objects.create(id=1, a=2) diff --git a/tests/update_only_fields/tests.py b/tests/update_only_fields/tests.py index 9595c767ebca..1c7ef88832b7 100644 --- a/tests/update_only_fields/tests.py +++ b/tests/update_only_fields/tests.py @@ -1,5 +1,6 @@ from django.core.exceptions import ObjectNotUpdated -from django.db import DatabaseError, transaction +from django.db import DatabaseError, connection, transaction +from django.db.models import F from django.db.models.signals import post_save, pre_save from django.test import TestCase @@ -308,3 +309,16 @@ def test_update_fields_not_updated(self): transaction.atomic(), ): obj.save(update_fields=["name"]) + + def test_update_fields_expression(self): + obj = Person.objects.create(name="Valerie", gender="F", pid=42) + updated_pid = F("pid") + 1 + obj.pid = updated_pid + obj.save(update_fields={"gender"}) + self.assertIs(obj.pid, updated_pid) + obj.save(update_fields={"pid"}) + expected_num_queries = ( + 0 if connection.features.can_return_rows_from_update else 1 + ) + with self.assertNumQueries(expected_num_queries): + self.assertEqual(obj.pid, 43)