From 1bd66d35b28e49b9b304eb74bf9f53166749213f Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Tue, 23 Dec 2025 15:23:18 +0000 Subject: [PATCH 01/14] add select_related and partial prefetch_related support implement support for a single query for select related base fetches across polymorphic models. adds a polymorphic QuerySet Mixin to enable non polymorphic models to fetch related models. fixes: #198 #436 #359 #244 possible fixes: #498: support for prefetch_related cannot fetch attributes not on all child models or via class names related: #531 --- src/polymorphic/query.py | 533 ++++++++++++++++- .../tests/migrations/0001_initial.py | 129 ++++- src/polymorphic/tests/models.py | 67 ++- src/polymorphic/tests/test_orm.py | 544 +++++++++++++++++- 4 files changed, 1242 insertions(+), 31 deletions(-) diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 8d582297..237e2226 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -3,15 +3,21 @@ """ import copy +import functools +import operator from collections import defaultdict from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldDoesNotExist from django.db import connections, models from django.db.models import FilteredRelation -from django.db.models.query import ModelIterable, Q, QuerySet +from django.db.models.constants import LOOKUP_SEP +from django.db.models.query import ModelIterable, Q, QuerySet, RelatedPopulator from .query_translate import ( + _create_base_path, + _get_all_sub_models, + _get_query_related_name, translate_polymorphic_field_path, translate_polymorphic_filter_definitions_in_args, translate_polymorphic_filter_definitions_in_kwargs, @@ -25,31 +31,252 @@ """ -class PolymorphicModelIterable(ModelIterable): - """ - ModelIterable for PolymorphicModel +def merge_dicts(primary, secondary): + """Deep merge two dicts - Yields real instances if qs.polymorphic_disabled is False, - otherwise acts like a regular ModelIterable. + Items from the primary dict are preserved in preference to those on the + secondary dict""" + + for k, v in secondary.items(): + if k in primary: + primary[k] = merge_dicts(primary[k], v) + else: + primary[k] = copy.deepcopy(v) + return primary + + +def search_object_cache(obj, source_model, target_model): + for search_part in _create_base_path(source_model, target_model).split("__"): + try: + obj = obj._state.fields_cache[search_part] + except KeyError: + return + return obj + + +class VanillaRelatedPopulator(RelatedPopulator): + def __init__(self, klass_info, select, db): + super().__init__(klass_info, select, db) + self.field = klass_info["field"] + self.reverse = klass_info["reverse"] + + def build_related(self, row, from_obj, *_): + self.populate(row, from_obj) + + +class RelatedPolymorphicPopulator: + """ + RelatedPopulator is used for select_related() object instantiation. + The idea is that each select_related() model will be populated by a + different RelatedPopulator instance. The RelatedPopulator instances get + klass_info and select (computed in SQLCompiler) plus the used db as + input for initialization. That data is used to compute which columns + to use, how to instantiate the model, and how to populate the links + between the objects. + The actual creation of the objects is done in populate() method. This + method gets row and from_obj as input and populates the select_related() + model instance. """ - def __iter__(self): - base_iter = super().__iter__() - if self.queryset.polymorphic_disabled: - return base_iter - return self._polymorphic_iterator(base_iter) + def __init__(self, klass_info, select, db): + self.db = db + # Pre-compute needed attributes. The attributes are: + # - model_cls: the possibly deferred model class to instantiate + # - either: + # - cols_start, cols_end: usually the columns in the row are + # in the same order model_cls.__init__ expects them, so we + # can instantiate by model_cls(*row[cols_start:cols_end]) + # - reorder_for_init: When select_related descends to a child + # class, then we want to reuse the already selected parent + # data. However, in this case the parent data isn't necessarily + # in the same order that Model.__init__ expects it to be, so + # we have to reorder the parent data. The reorder_for_init + # attribute contains a function used to reorder the field data + # in the order __init__ expects it. + # - pk_idx: the index of the primary key field in the reordered + # model data. Used to check if a related object exists at all. + # - init_list: the field attnames fetched from the database. For + # deferred models this isn't the same as all attnames of the + # model's fields. + # - related_populators: a list of RelatedPopulator instances if + # select_related() descends to related models from this model. + # - local_setter, remote_setter: Methods to set cached values on + # the object being populated and on the remote object. Usually + # these are Field.set_cached_value() methods. + select_fields = klass_info["select_fields"] + from_parent = klass_info["from_parent"] + if not from_parent: + self.cols_start = select_fields[0] + self.cols_end = select_fields[-1] + 1 + self.init_list = [f[0].target.attname for f in select[self.cols_start : self.cols_end]] + self.reorder_for_init = None + else: + attname_indexes = {select[idx][0].target.attname: idx for idx in select_fields} + model_init_attnames = (f.attname for f in klass_info["model"]._meta.concrete_fields) + self.init_list = [ + attname for attname in model_init_attnames if attname in attname_indexes + ] + self.reorder_for_init = operator.itemgetter( + *[attname_indexes[attname] for attname in self.init_list] + ) - def _polymorphic_iterator(self, base_iter): - """ - Here we do the same as:: + self.model_cls = klass_info["model"] + self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) + self.related_populators = get_related_populators(klass_info, select, self.db) + self.local_setter = klass_info["local_setter"] + self.remote_setter = klass_info["remote_setter"] + self.field = klass_info["field"] + self.reverse = klass_info["reverse"] + self.content_type_manager = ContentType.objects.db_manager(self.db) + self.model_class_id = self.content_type_manager.get_for_model( + self.model_cls, for_concrete_model=False + ).pk + self.concrete_model_class_id = self.content_type_manager.get_for_model( + self.model_cls, for_concrete_model=True + ).pk + + def build_related(self, row, from_obj, post_actions): + if self.reorder_for_init: + obj_data = self.reorder_for_init(row) + else: + obj_data = row[self.cols_start : self.cols_end] + + if obj_data[self.pk_idx] is None: + obj = None + else: + obj = self.model_cls.from_db(self.db, self.init_list, obj_data) + self.post_build_modify( + obj, + from_obj, + post_actions, + functools.partial(self._populate, row, from_obj, post_actions), + ) + + def _populate(self, row, from_obj, post_actions, obj): + for rel_iter in self.related_populators: + rel_iter.build_related(row, obj, post_actions) + + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) + + def post_build_modify(self, base_object, from_obj, post_actions, populate_fn): + if base_object.polymorphic_ctype_id == self.model_class_id: + # Real class is exactly the same as base class, go straight to results + populate_fn(base_object) + else: + real_concrete_class = base_object.get_real_instance_class() + real_concrete_class_id = base_object.get_real_concrete_instance_class_id() + + if real_concrete_class_id is None: + # Dealing with a stale content type + populate_fn(None) + return False + elif real_concrete_class_id == self.concrete_model_class_id: + # Real and base classes share the same concrete ancestor, + # upcast it and put it in the results + populate_fn(transmogrify(real_concrete_class, base_object)) + return False + else: + # This model has a concrete derived class: either track it for bulk + # retrieval or if it is already fetched as part of a select_related + # enable pivoting to that object + real_concrete_class = self.content_type_manager.get_for_id( + real_concrete_class_id + ).model_class() + populate_fn(base_object) + post_actions.append( + ( + functools.partial( + self.pivot_onto_cached_subclass, + from_obj, + base_object, + real_concrete_class, + ), + populate_fn, + ) + ) - real_results = queryset._get_real_instances(list(base_iter)) - for o in real_results: yield o + def pivot_onto_cached_subclass(self, from_obj, obj, model_target_cls): + """Pivot to final polymorphic class. - but it requests the objects in chunks from the database, - with QuerySet.iterator(chunk_size) per chunk + Pivot the object created from the base query onto the true polymorphic + instance, we need to ensure that this is only done on objects that are + from non parent-child type relationships. + + If we cannot pivot we return info to be used in the PolymorphicModelIterable + to ensure the correct model loaded from the additional bulk queries """ + original = obj + parents = model_target_cls()._get_inheritance_relation_fields_and_models() + for cls in reversed(model_target_cls.mro()[: -len(self.model_cls.mro())]): + for rel_iter in self.related_populators: + if not isinstance( + rel_iter, (VanillaRelatedPopulator, RelatedPolymorphicPopulator) + ): + continue + if rel_iter.reverse and rel_iter.model_cls is cls: + if rel_iter.field.name in parents.keys(): + obj = getattr(obj, rel_iter.field.remote_field.name) + + if not isinstance(obj, model_target_cls): + # This allow pivoting of object that are descendants of the original field + if not original._meta.get_path_to_parent(from_obj._meta.model): + obj = search_object_cache(original, original._meta.model, model_target_cls) + + if isinstance(obj, model_target_cls): + # We only want to pivot onto a field from a different object, ie not a parent/child + # relationship as this will break the cache and other object relationships + if not original._meta.get_path_to_parent(from_obj._meta.model): + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) + return None, None + + pk_name = self.model_cls.polymorphic_primary_key_name + return model_target_cls, (getattr(original, pk_name), self.field.name) + + +def get_related_populators(klass_info, select, db): + from .models import PolymorphicModel + + iterators = [] + related_klass_infos = klass_info.get("related_klass_infos", []) + for rel_klass_info in related_klass_infos: + model = rel_klass_info["model"] + if issubclass(model, PolymorphicModel): + rel_cls = RelatedPolymorphicPopulator(rel_klass_info, select, db) + else: + rel_cls = VanillaRelatedPopulator(rel_klass_info, select, db) + iterators.append(rel_cls) + return iterators + + +class PolymorphicModelIterable(ModelIterable): + """ + ModelIterable for PolymorphicModel + + Yields real instances if qs.polymorphic_disabled is False, + otherwise acts like a regular ModelIterable. We inherit from + ModelIterable non base BaseIterable even though we completely + replace it, but this allows Django test in Prefetch to work + """ + + def __iter__(self): + queryset = self.queryset + db = queryset.db + compiler = queryset.query.get_compiler(using=db) + # Execute the query. This will also fill compiler.select, klass_info, + # and annotations. + results = compiler.execute_sql( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ) + select, klass_info, annotation_col_map = ( + compiler.select, + compiler.klass_info, + compiler.annotation_col_map, + ) # some databases have a limit on the number of query parameters, we must # respect this for generating get_real_instances queries because those # queries do a large WHERE IN clause with primary keys @@ -64,24 +291,149 @@ def _polymorphic_iterator(self, base_iter): sql_chunk = sql_chunk or Polymorphic_QuerySet_objects_per_request + model_cls = klass_info["model"] + select_fields = klass_info["select_fields"] + model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 + init_list = [f[0].target.attname for f in select[model_fields_start:model_fields_end]] + related_populators = get_related_populators(klass_info, select, db) + known_related_objects = [ + ( + field, + related_objs, + operator.attrgetter( + *[ + field.attname + if from_field == "self" + else queryset.model._meta.get_field(from_field).attname + for from_field in field.from_fields + ] + ), + ) + for field, related_objs in queryset._known_related_objects.items() + ] + base_iter = compiler.results_iter(results) while True: + result_objects = [] base_result_objects = [] reached_end = False # Fetch in chunks - for _ in range(sql_chunk): + post_actions = list() + for i in range(sql_chunk): + # dict contains one entry per unique model type occurring in result, + # in the format idlist_per_model[modelclass]=[list-of-object-ids] try: - o = next(base_iter) - base_result_objects.append(o) + row = next(base_iter) + obj = model_cls.from_db( + db, init_list, row[model_fields_start:model_fields_end] + ) + for rel_populator in related_populators: + rel_populator.build_related(row, obj, post_actions) + base_result_objects.append([row, obj]) except StopIteration: reached_end = True break - yield from self.queryset._get_real_instances(base_result_objects) + if not self.queryset.polymorphic_disabled: + self.fetch_polymorphic(post_actions, base_result_objects) + + for row, obj in base_result_objects: + if annotation_col_map: + for attr_name, col_pos in annotation_col_map.items(): + setattr(obj, attr_name, row[col_pos]) + + # Add the known related objects to the model. + for field, rel_objs, rel_getter in known_related_objects: + # Avoid overwriting objects loaded by, e.g., select_related(). + if field.is_cached(obj): + continue + rel_obj_id = rel_getter(obj) + try: + rel_obj = rel_objs[rel_obj_id] + except KeyError: + pass # May happen in qs1 | qs2 scenarios. + else: + setattr(obj, field.name, rel_obj) + result_objects.append(obj) + + if not self.queryset.polymorphic_disabled: + result_objects = self.queryset._get_real_instances(result_objects) + + for o in result_objects: + yield o if reached_end: return + def apply_select_related(self, qs, relations): + if self.queryset.query.select_related is True: + return qs.select_related() + + model_name = qs.model.__name__.lower() + if isinstance(self.queryset.query.select_related, dict): + select_related = {} + if isinstance(qs.query.select_related, dict): + select_related = qs.query.select_related + for k, v in self.queryset.query.select_related.items(): + if k in relations: + if not isinstance(select_related, dict): + select_related = {} + if isinstance(v, dict): + if model_name in v: + select_related = merge_dicts(select_related, v[model_name]) + else: + for field in qs.model._meta.fields: + if field.name in v: + select_related = merge_dicts(select_related, v[field.name]) + else: + select_related = merge_dicts(select_related, v) + qs.query.select_related = select_related + return qs + + def fetch_polymorphic(self, post_actions, base_result_objects): + update_fn_per_model = defaultdict(list) + idlist_per_model = defaultdict(list) + + for action, populate_fn in post_actions: + target_class, pk_info = action() + if target_class: + pk, name = pk_info + idlist_per_model[target_class].append((pk, name)) + update_fn_per_model[target_class].append((populate_fn, pk)) + + # For each model in "idlist_per_model" request its objects (the real model) + # from the db and store them in results[]. + # Then we copy the annotate fields from the base objects to the real objects. + # Then we copy the extra() select fields from the base objects to the real objects. + # TODO: defer(), only(): support for these would be around here + for real_concrete_class, data in idlist_per_model.items(): + idlist, names = zip(*data) + updates = update_fn_per_model[real_concrete_class] + pk_name = real_concrete_class.polymorphic_primary_key_name + real_objects = real_concrete_class._base_objects.db_manager(self.queryset.db).filter( + **{("%s__in" % pk_name): idlist} + ) + + real_objects = self.apply_select_related(real_objects, set(names)) + real_objects_dict = { + getattr(real_object, pk_name): real_object for real_object in real_objects + } + + for populate_fn, o_pk in updates: + real_object = real_objects_dict.get(o_pk) + if real_object is None: + continue + + # need shallow copy to avoid duplication in caches (see PR #353) + real_object = copy.copy(real_object) + real_class = real_object.get_real_instance_class() + + # If the real class is a proxy, upcast it + if real_class != real_concrete_class: + real_object = transmogrify(real_class, real_object) + + populate_fn(real_object) + def transmogrify(cls, obj): """ @@ -103,7 +455,64 @@ def transmogrify(cls, obj): # PolymorphicQuerySet -class PolymorphicQuerySet(QuerySet): +class PolymorphicQuerySetMixin(QuerySet): + def select_related(self, *fields): + if fields == (None,) or not len(fields): + return super().select_related(*fields) + field_with_poly = list(self.convert_related_fieldnames(fields)) + return super().select_related(*field_with_poly) + + def _convert_field_name_part(self, field_parts, model): + """ + recursively convert a fieldname into (model, filedname) + """ + field = None + part = field_parts[0] + next_parts = field_parts[1:] + field_path = [] + rel_model = None + try: + field = model._meta.get_field(part) + field_path = [part] + yield field_path + + if field.is_relation: + rel_model = field.related_model + if next_parts: + self._convert_field_name_part(next_parts, rel_model) + else: + rel_model = model + + except FieldDoesNotExist: + submodels = _get_all_sub_models(model) + rel_model = submodels.get(part, None) + field_path = list(_create_base_path(model, rel_model).split("__")) + for field_part_idx in range(0, len(field_path)): + yield field_path[0 : 1 + field_part_idx] + + if next_parts: + child_selectors = self._convert_field_name_part(next_parts, rel_model) + for selector in child_selectors: + all_field_path = field_path + selector + for field_part_idx in range(0, len(all_field_path)): + yield all_field_path[0 : 1 + field_part_idx] + + def convert_related_fieldnames(self, fields, opts=None): + """ + convert the field name which may contain polymorphic models names into + raw filed names that can be used with django select_related and + prefetch_related. + """ + if not opts: + opts = self.model + for field_name in fields: + field_parts = field_name.split(LOOKUP_SEP) + selectors = self._convert_field_name_part(field_parts, opts) + for selector in selectors: + yield "__".join(selector) + + +class PolymorphicQuerySet(PolymorphicQuerySetMixin, QuerySet): """ QuerySet for PolymorphicModel @@ -412,22 +821,37 @@ class self.model, but as a class derived from self.model. We want to re-fetch real_concrete_class = content_type_manager.get_for_id( real_concrete_class_id ).model_class() - idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) - indexlist_per_model[real_concrete_class].append((i, len(resultlist))) - resultlist.append(None) + + cached_obj = search_object_cache(base_object, self.model, real_concrete_class) + if cached_obj: + resultlist.append(cached_obj) + else: + idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) + indexlist_per_model[real_concrete_class].append((i, len(resultlist))) + resultlist.append(None) # For each model in "idlist_per_model" request its objects (the real model) # from the db and store them in results[]. # Then we copy the annotate fields from the base objects to the real objects. # Then we copy the extra() select fields from the base objects to the real objects. # TODO: defer(), only(): support for these would be around here + # Also see PolymorphicModelIterable.fetch_polymorphic + + filter_relations = [ + _get_query_related_name(mdl_cls) + for mdl_cls in _get_all_sub_models(self.model).values() + ] + for real_concrete_class, idlist in idlist_per_model.items(): indices = indexlist_per_model[real_concrete_class] real_objects = real_concrete_class._base_objects.db_manager(self.db).filter( **{(f"{pk_name}__in"): idlist} ) # copy select related configuration to new qs - real_objects.query.select_related = self.query.select_related + current_relation = real_objects.model.__name__.lower() + real_objects = self.apply_select_related( + real_objects, current_relation, filter_relations + ) # Copy deferred fields configuration to the new queryset deferred_loading_fields = [] @@ -516,6 +940,37 @@ class self.model, but as a class derived from self.model. We want to re-fetch return resultlist + def apply_select_related(self, qs, relation, filtered): + if self.query.select_related is True: + return qs.select_related() + + model_name = qs.model.__name__.lower() + if isinstance(self.query.select_related, dict): + select_related = {} + if isinstance(qs.query.select_related, dict): + select_related = qs.query.select_related + for k, v in self.query.select_related.items(): + if k in filtered and k != relation: + continue + else: + if not isinstance(select_related, dict): + select_related = {} + if k == relation: + if isinstance(v, dict): + if model_name in v: + select_related = merge_dicts(select_related, v[model_name]) + else: + for field in qs.model._meta.fields: + if field.name in v: + select_related = merge_dicts(select_related, v[field.name]) + else: + select_related = merge_dicts(select_related, v) + else: + select_related[k] = v + + qs.query.select_related = select_related + return qs + def __repr__(self, *args, **kwargs): if self.model.polymorphic_query_multiline_output: result = ",\n ".join(repr(o) for o in self.all()) @@ -557,3 +1012,27 @@ def delete(self): disrupts the model hierarchy/relationship traversal. """ return QuerySet.delete(self.non_polymorphic()) + + +################################################################################### +# PolymorphicRelatedQuerySet + + +class PolymorphicRelatedQuerySetMixin(PolymorphicQuerySetMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._iterable_class = PolymorphicModelIterable + self.polymorphic_disabled = False + + def _clone(self, *args, **kwargs): + # Django's _clone only copies its own variables, so we need to copy ours here + new = super()._clone(*args, **kwargs) + new.polymorphic_disabled = self.polymorphic_disabled + return new + + def _get_real_instances(self, base_result_objects): + return base_result_objects + + +class PolymorphicRelatedQuerySet(PolymorphicRelatedQuerySetMixin, QuerySet): + pass diff --git a/src/polymorphic/tests/migrations/0001_initial.py b/src/polymorphic/tests/migrations/0001_initial.py index 99c9ef8e..8686900d 100644 --- a/src/polymorphic/tests/migrations/0001_initial.py +++ b/src/polymorphic/tests/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2 on 2025-12-21 10:54 +# Generated by Django 4.2 on 2025-12-23 09:09 from django.conf import settings from django.db import migrations, models @@ -267,6 +267,20 @@ class Migration(migrations.Migration): 'base_manager_name': 'objects', }, ), + migrations.CreateModel( + name='NonSymRelationBase', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('field_base', models.CharField(max_length=10)), + ('fk', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='relationbase_set', to='tests.nonsymrelationbase')), + ('m2m', models.ManyToManyField(to='tests.nonsymrelationbase')), + ('polymorphic_ctype', models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_%(app_label)s.%(class)s_set+', to='contenttypes.contenttype')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + ), migrations.CreateModel( name='NormalBase', fields=[ @@ -287,6 +301,18 @@ class Migration(migrations.Migration): 'base_manager_name': 'objects', }, ), + migrations.CreateModel( + name='ParentModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=10)), + ('polymorphic_ctype', models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_%(app_label)s.%(class)s_set+', to='contenttypes.contenttype')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + ), migrations.CreateModel( name='Participant', fields=[ @@ -387,6 +413,19 @@ class Migration(migrations.Migration): }, bases=(polymorphic.showfields.ShowFieldTypeAndContent, models.Model), ), + migrations.CreateModel( + name='AltChildModel', + fields=[ + ('parentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.parentmodel')), + ('other_name', models.CharField(max_length=10)), + ('link_on_altchild', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='tests.plaina')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.parentmodel',), + ), migrations.CreateModel( name='BlogA', fields=[ @@ -710,6 +749,42 @@ class Migration(migrations.Migration): }, bases=('tests.proxybase',), ), + migrations.CreateModel( + name='NonSymRelationA', + fields=[ + ('nonsymrelationbase_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.nonsymrelationbase')), + ('field_a', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.nonsymrelationbase',), + ), + migrations.CreateModel( + name='NonSymRelationB', + fields=[ + ('nonsymrelationbase_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.nonsymrelationbase')), + ('field_b', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.nonsymrelationbase',), + ), + migrations.CreateModel( + name='NonSymRelationBC', + fields=[ + ('nonsymrelationbase_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.nonsymrelationbase')), + ('field_c', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.nonsymrelationbase',), + ), migrations.CreateModel( name='NormalExtension', fields=[ @@ -927,6 +1002,21 @@ class Migration(migrations.Migration): }, bases=(polymorphic.showfields.ShowFieldTypeAndContent, models.Model), ), + migrations.CreateModel( + name='PlainModelWithM2M', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('field1', models.CharField(max_length=10)), + ('m2m', models.ManyToManyField(to='tests.parentmodel')), + ], + ), + migrations.CreateModel( + name='PlainModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('relation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.parentmodel')), + ], + ), migrations.CreateModel( name='PlainChildModelWithManager', fields=[ @@ -1269,6 +1359,18 @@ class Migration(migrations.Migration): }, bases=('tests.subclassselectorproxybasemodel',), ), + migrations.CreateModel( + name='AltChildAsBaseModel', + fields=[ + ('altchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.altchildmodel')), + ('more_name', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.altchildmodel',), + ), migrations.CreateModel( name='Bottom', fields=[ @@ -1437,6 +1539,19 @@ class Migration(migrations.Migration): }, bases=('tests.inlinemodela',), ), + migrations.CreateModel( + name='ChildModel', + fields=[ + ('parentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.parentmodel')), + ('other_name', models.CharField(max_length=10)), + ('link_on_child', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='tests.modelextraexternal')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.parentmodel',), + ), migrations.CreateModel( name='BlogEntry', fields=[ @@ -1497,6 +1612,18 @@ class Migration(migrations.Migration): }, bases=('tests.uuidartprojecta',), ), + migrations.CreateModel( + name='AltChildWithM2MModel', + fields=[ + ('altchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.altchildmodel')), + ('m2m', models.ManyToManyField(to='tests.plaina')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.altchildmodel',), + ), migrations.CreateModel( name='UUIDArtProjectC', fields=[ diff --git a/src/polymorphic/tests/models.py b/src/polymorphic/tests/models.py index 1bb10380..4e0a0595 100644 --- a/src/polymorphic/tests/models.py +++ b/src/polymorphic/tests/models.py @@ -10,7 +10,11 @@ from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicModel -from polymorphic.query import PolymorphicQuerySet +from polymorphic.query import ( + PolymorphicQuerySet, + PolymorphicRelatedQuerySetMixin, + PolymorphicRelatedQuerySet, +) from polymorphic.showfields import ShowFieldContent, ShowFieldType, ShowFieldTypeAndContent @@ -781,3 +785,64 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) self.old_status_id = self.status_id + + +class NonSymRelationBase(PolymorphicModel): + field_base = models.CharField(max_length=10) + fk = models.ForeignKey( + "self", on_delete=models.CASCADE, null=True, related_name="relationbase_set" + ) + m2m = models.ManyToManyField("self", symmetrical=False) + + +class NonSymRelationA(NonSymRelationBase): + field_a = models.CharField(max_length=10) + + +class NonSymRelationB(NonSymRelationBase): + field_b = models.CharField(max_length=10) + + +class NonSymRelationBC(NonSymRelationBase): + field_c = models.CharField(max_length=10) + + +class CustomPolySupportingQuerySet(PolymorphicRelatedQuerySetMixin, models.QuerySet): + pass + + +class ParentModel(PolymorphicModel): + name = models.CharField(max_length=10) + + +class ChildModel(ParentModel): + other_name = models.CharField(max_length=10) + link_on_child = models.ForeignKey( + ModelExtraExternal, on_delete=models.CASCADE, null=True, related_name="+" + ) + + +class AltChildModel(ParentModel): + other_name = models.CharField(max_length=10) + link_on_altchild = models.ForeignKey( + PlainA, on_delete=models.CASCADE, null=True, related_name="+" + ) + + +class AltChildAsBaseModel(AltChildModel): + more_name = models.CharField(max_length=10) + + +class PlainModel(models.Model): + relation = models.ForeignKey(ParentModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + +class PlainModelWithM2M(models.Model): + field1 = models.CharField(max_length=10) + m2m = models.ManyToManyField(ParentModel) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + +class AltChildWithM2MModel(AltChildModel): + m2m = models.ManyToManyField(PlainA) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 412422cd..1a065506 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -1,11 +1,22 @@ import pytest import uuid - +from unittest import expectedFailure from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.db import models, connection -from django.db.models import Case, Count, FilteredRelation, Q, Sum, When, Exists, OuterRef +from django.db.models import ( + Case, + Count, + FilteredRelation, + Prefetch, + Q, + Sum, + When, + Exists, + OuterRef, +) from django.db.utils import IntegrityError, NotSupportedError + from django.test import TransactionTestCase from django.test.utils import CaptureQueriesContext @@ -13,6 +24,9 @@ from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicTypeInvalid, PolymorphicTypeUndefined from polymorphic.tests.models import ( + AltChildAsBaseModel, + AltChildModel, + AltChildWithM2MModel, ArtProject, Base, BlogA, @@ -20,6 +34,7 @@ BlogBase, BlogEntry, BlogEntry_limit_choices_to, + ChildModel, ChildModelWithManager, CustomPkBase, CustomPkInherit, @@ -59,13 +74,20 @@ MyManagerQuerySet, NonPolymorphicParent, NonProxyChild, + NonSymRelationA, + NonSymRelationB, + NonSymRelationBase, + NonSymRelationBC, One2OneRelatingModel, One2OneRelatingModelDerived, + ParentModel, ParentModelWithManager, PlainA, PlainB, PlainC, PlainChildModelWithManager, + PlainModel, + PlainModelWithM2M, PlainMyManager, PlainMyManagerQuerySet, PlainParentModelWithManager, @@ -1967,3 +1989,521 @@ def test_infinite_recursion_with_only(self): RecursionBug.objects.filter(id=item.id).update(status=closed) item.refresh_from_db(fields=("status",)) assert item.status == closed + + def test_normal_django_to_poly_related_give_poly_type(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="m2") + obj3 = ChildModel.objects.create(name="m1") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + + ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + + with self.assertNumQueries(6): + # Queries will be + # * 1 for All PlainModels object (1) + # * 1 for each relations ParentModel (4) + # * 1 for each relations ChilModel is needed (3) + multi_q = [ + # these obj.relation values will have their proper sub type + obj.relation + for obj in PlainModel.objects.all() + ] + multi_q_types = [type(obj) for obj in multi_q] + + with self.assertNumQueries(2): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation") + ] + grouped_q_types = [type(obj) for obj in grouped_q] + + self.assertListEqual(multi_q_types, grouped_q_types) + self.assertListEqual(grouped_q, [obj1, obj2, obj3]) + + def test_normal_django_to_poly_related_give_poly_type_using_select_related_true(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="m2") + obj3 = ChildModel.objects.create(name="m1") + obj4 = AltChildAsBaseModel.objects.create( + name="ac2", other_name="ac2name", more_name="ac2morename" + ) + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + + with self.assertNumQueries(8): + # Queries will be + # * 1 for All PlainModels object (x1) + # * 1 for each relations ParentModel (x4) + # * 1 for each relations ChildModel is needed (x2) + # * 1 for each relations AltChildAsBaseModel is needed (x1) + multi_q = [ + # these obj.relation values will have their proper sub type + obj.relation + for obj in PlainModel.objects.all() + ] + multi_q_types = [type(obj) for obj in multi_q] + + with self.assertNumQueries(3): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + # ATM: we require 1 query fro each type. Although this can + # be reduced by specifying the relations to the polymorphic + # classes. BUT this has the downside of making the query have + # a large number of joins + obj.relation + for obj in PlainModel.objects.select_related() + ] + grouped_q_types = [type(obj) for obj in grouped_q] + + self.assertListEqual(multi_q_types, grouped_q_types) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4]) + + def test_prefetch_base_load_359(self): + obj1_1 = ModelShow1_plain.objects.create(field1="1") + obj2_1 = ModelShow2_plain.objects.create(field1="2", field2="1") + obj3_2 = ModelShow2_plain.objects.create(field1="3", field2="2") + + with self.assertNumQueries(1): + obj = ModelShow2_plain.objects.filter(pk=obj2_1.pk)[0] + _ = (obj.field1, obj.field1) + + def test_select_related_on_poly_classes(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__childmodel__link_on_child", + "relation__altchildmodel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + obj_list[3].relation.link_on_altchild + + def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_acab2 = AltChildAsBaseModel.objects.create( + name="ac2ab", + other_name="acab2name", + more_name="acab2morename", + link_on_altchild=plain_a_obj_2, + ) + + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_acab2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__childmodel__link_on_child", + "relation__altchildmodel__link_on_altchild", + "relation__altchildmodel__altchildasbasemodel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2ab") + self.assertEqual(obj_list[3].relation.more_name, "acab2morename") + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + obj_list[3].relation.link_on_altchild + + def test_select_related_on_poly_classes_with_modelname(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_acab2 = AltChildAsBaseModel.objects.create( + name="acab2", + other_name="acab2name", + more_name="acab2morename", + link_on_altchild=plain_a_obj_1, + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_acab2) + + ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + + with self.assertNumQueries(1): + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__ChildModel__link_on_child", + "relation__AltChildAsBaseModel__link_on_altchild", + ).order_by("pk") + ) + + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "acab2") + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + + def test_prefetch_related_from_basepoly(self): + obja1 = NonSymRelationA.objects.create(field_a="fa1", field_base="fa1") + obja2 = NonSymRelationA.objects.create(field_a="fa2", field_base="fa2") + objb1 = NonSymRelationB.objects.create(field_b="fb1", field_base="fb1") + objbc1 = NonSymRelationBC.objects.create(field_c="fbc1", field_base="fbc1") + + obja3 = NonSymRelationA.objects.create(field_a="fa3", field_base="fa3") + # NOTE: these are symmetric links + obja3.m2m.add(obja2) + obja3.m2m.add(objb1) + obja2.m2m.add(objbc1) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(NonSymRelationBase, for_concrete_model=True) + + with self.assertNumQueries(10): + # query for NonSymRelationBase (base) + # query for NonSymRelationA # level 1 (base) + # query for NonSymRelationB # level 1 (base) + # query for NonSymRelationBC # level 1 (base) + # query for prefetch links (m2m) + # query for NonSymRelationA # level 2 (m2m) + # query for NonSymRelationB # level 2 (m2m) + # query for NonSymRelationBC # level 2 (m2m) + # query for prefetch links (m2m__m2m) + # query for NonSymRelationA # level 3 (m2m__m2m) + # query for NonSymRelationB # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + # query for NonSymRelationC # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + + all_objs = { + obj.pk: obj + for obj in NonSymRelationBase.objects.prefetch_related("m2m", "m2m__m2m") + } + + with self.assertNumQueries(0): + relations = {obj.pk: set(obj.m2m.all()) for obj in all_objs.values()} + + with self.assertNumQueries(0): + sub_relations = {a.pk: set(a.m2m.all()) for a in all_objs.get(obja3.pk).m2m.all()} + + self.assertDictEqual( + { + obja1.pk: set(), + obja2.pk: set([objbc1]), + obja3.pk: set([obja2, objb1]), + objb1.pk: set([]), + objbc1.pk: set([]), + }, + relations, + ) + + self.assertDictEqual( + { + obja2.pk: set([objbc1]), + objb1.pk: set([]), + }, + sub_relations, + ) + + def test_prefetch_related_from_subclass(self): + obja1 = NonSymRelationA.objects.create(field_a="fa1", field_base="fa1") + obja2 = NonSymRelationA.objects.create(field_a="fa2", field_base="fa2") + objb1 = NonSymRelationB.objects.create(field_b="fb1", field_base="fb1") + objbc1 = NonSymRelationBC.objects.create(field_c="fbc1", field_base="fbc1") + + obja3 = NonSymRelationA.objects.create(field_a="fa3", field_base="fa3") + # NOTE: these are symmetric links + obja3.m2m.add(obja2) + obja3.m2m.add(objb1) + obja2.m2m.add(objbc1) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(NonSymRelationBase, for_concrete_model=True) + + with self.assertNumQueries(7): + # query for NonSymRelationA # level 1 (base) + # query for prefetch links (m2m) + # query for NonSymRelationA # level 2 (m2m) + # query for NonSymRelationB # level 2 (m2m) + # query for NonSymRelationBC # level 2 (m2m) + # query for prefetch links (m2m__m2m) + # query for NonSymRelationA # level 3 (m2m__m2m) + # query for NonSymRelationB # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + # query for NonSymRelationC # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + + all_objs = { + obj.pk: obj for obj in NonSymRelationA.objects.prefetch_related("m2m", "m2m__m2m") + } + + with self.assertNumQueries(0): + relations = {obj.pk: set(obj.m2m.all()) for obj in all_objs.values()} + + with self.assertNumQueries(0): + sub_relations = {a.pk: set(a.m2m.all()) for a in all_objs.get(obja3.pk).m2m.all()} + + self.assertDictEqual( + { + obja1.pk: set(), + obja2.pk: set([objbc1]), + obja3.pk: set([obja2, objb1]), + }, + relations, + ) + + self.assertDictEqual( + { + obja2.pk: set([objbc1]), + objb1.pk: set([]), + }, + sub_relations, + ) + + def test_select_related_field_from_polymorphic_child_class(self): + # 198 + obj_p1 = ParentModel.objects.create(name="p1") + obj_p2 = ParentModel.objects.create(name="p2") + obj_p3 = ParentModel.objects.create(name="p4") + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_c2 = ChildModel.objects.create(name="c2", other_name="c2name") + obj_ac1 = AltChildModel.objects.create(name="ac1", other_name="ac1name") + obj_ac2 = AltChildModel.objects.create(name="ac2", other_name="ac2name") + obj_ac3 = AltChildModel.objects.create(name="ac3", other_name="ac3name") + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x2) + # * 0 for AltChildModel object as from select_related (x3) + all_objs = [ + obj + for obj in ParentModel.objects.select_related( + "altchildmodel", + ) + ] + + def test_select_related_field_from_polymorphic_child_class_using_modelnames_level1(self): + # 198 + obj_p1 = ParentModel.objects.create(name="p1") + obj_p2 = ParentModel.objects.create(name="p2") + obj_p3 = ParentModel.objects.create(name="p4") + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_c2 = ChildModel.objects.create(name="c2", other_name="c2name") + obj_ac1 = AltChildModel.objects.create(name="ac1", other_name="ac1name") + obj_ac2 = AltChildModel.objects.create(name="ac2", other_name="ac2name") + obj_ac3 = AltChildModel.objects.create(name="ac3", other_name="ac3name") + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x2) + # * 0 for AltChildModel object as from select_related (x3) + all_objs = [ + obj + for obj in ParentModel.objects.select_related( + "AltChildModel", + ) + ] + + def test_select_related_field_from_polymorphic_child_class_using_modelnames_multi_level(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + + obj_p1 = ParentModel.objects.create(name="p1") + obj_acab2 = AltChildAsBaseModel.objects.create( + name="acab2", + other_name="acab2name", + more_name="acab2morename", + link_on_altchild=plain_a_obj_1, + ) + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_ac3 = ChildModel.objects.create(name="c2", other_name="c3name") + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x1) + # * 0 for AltChildAsBaseModel object as from select_related (x1) + # * 0 for AltChildModel object as part of select_related form + # AltChildAsBaseModel (x1) + all_objs = [obj for obj in ParentModel.objects.select_related("AltChildAsBaseModel")] + + def test_prefetch_object_is_supported(self): + b1 = RelatingModel.objects.create() + b2 = RelatingModel.objects.create() + + rel1 = Model2A.objects.create(field1="A1") + rel2 = Model2B.objects.create(field1="A2", field2="B2") + + b1.many2many.add(rel1) + b2.many2many.add(rel2) + + rel2.delete(keep_parents=True) + + qs = RelatingModel.objects.order_by("pk").prefetch_related( + Prefetch("many2many", queryset=Model2A.objects.all(), to_attr="poly"), + Prefetch("many2many", queryset=Model2A.objects.non_polymorphic(), to_attr="non_poly"), + ) + + objects = list(qs) + self.assertEqual(len(objects[0].poly), 1) + + # derived object was not fetched + self.assertEqual(len(objects[1].poly), 0) + + # base object always found + self.assertEqual(len(objects[0].non_poly), 1) + self.assertEqual(len(objects[1].non_poly), 1) + + def test_select_related_on_poly_classes_preserves_on_relations_annotations(self): + b1 = RelatingModel.objects.create() + b2 = RelatingModel.objects.create() + b3 = RelatingModel.objects.create() + + rel1 = Model2A.objects.create(field1="A1") + rel2 = Model2B.objects.create(field1="A2", field2="B2") + + b1.many2many.add(rel1) + b2.many2many.add(rel2) + b3.many2many.add(rel2) + + qs = RelatingModel.objects.order_by("pk").prefetch_related( + Prefetch( + "many2many", + queryset=Model2A.objects.annotate(Count("relatingmodel")), + to_attr="poly", + ) + ) + + objects = list(qs) + self.assertEqual(objects[0].poly[0].relatingmodel__count, 1) + self.assertEqual(objects[1].poly[0].relatingmodel__count, 2) + self.assertEqual(objects[2].poly[0].relatingmodel__count, 2) + + @expectedFailure + def test_prefetch_loading_relation_only_on_some_poly_model(self): + plain_a_obj_1 = PlainA.objects.create(field1="p1") + plain_a_obj_2 = PlainA.objects.create(field1="p2") + plain_a_obj_3 = PlainA.objects.create(field1="p3") + plain_a_obj_4 = PlainA.objects.create(field1="p4") + plain_a_obj_5 = PlainA.objects.create(field1="p5") + + ac_m2m_obj = AltChildWithM2MModel.objects.create( + other_name="o1", + ) + ac_m2m_obj.m2m.set([plain_a_obj_1, plain_a_obj_2, plain_a_obj_3]) + + cm_1 = ChildModel.objects.create(other_name="c1") + cm_2 = ChildModel.objects.create(other_name="c2") + cm_3 = ChildModel.objects.create(other_name="c3") + + acm_1 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_4) + acm_2 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_5) + + pm_1 = PlainModelWithM2M.objects.create(field1="pm1") + pm_2 = PlainModelWithM2M.objects.create(field1="pm2") + + pm_1.m2m.set([cm_1, cm_2]) + pm_2.m2m.set( + [ + cm_3, + ] + ) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(ParentModel, for_concrete_model=True) + + pm_2.m2m.set([ac_m2m_obj]) + with self.assertNumQueries(4): + # query for PlainModelWithM2M # level 1 (base) + # query for prefetch links (m2m) + # query for ChildModel # level 2 (m2m) + # query for AltChildWithM2MModel # level 2 (m2m) + qs = PlainModelWithM2M.objects.all() + qs = qs.prefetch_related("m2m__altchildmodel__altchildWithm2mmodel__m2m") + all_objs = list(qs) + + @expectedFailure + def test_prefetch_loading_relation_only_on_some_poly_model_using_modelnames(self): + plain_a_obj_1 = PlainA.objects.create(field1="p1") + plain_a_obj_2 = PlainA.objects.create(field1="p2") + plain_a_obj_3 = PlainA.objects.create(field1="p3") + plain_a_obj_4 = PlainA.objects.create(field1="p4") + plain_a_obj_5 = PlainA.objects.create(field1="p5") + + ac_m2m_obj = AltChildWithM2MModel.objects.create( + other_name="o1", + ) + ac_m2m_obj.m2m.set([plain_a_obj_1, plain_a_obj_2, plain_a_obj_3]) + + cm_1 = ChildModel.objects.create(other_name="c1") + cm_2 = ChildModel.objects.create(other_name="c2") + cm_3 = ChildModel.objects.create(other_name="c3") + + acm_1 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_4) + acm_2 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_5) + + pm_1 = PlainModelWithM2M.objects.create(field1="pm1") + pm_2 = PlainModelWithM2M.objects.create(field1="pm2") + + pm_1.m2m.set([cm_1, cm_2]) + pm_2.m2m.set( + [ + cm_3, + ] + ) + + # NOTE: prefetch content types so query asserts test data fetched. + ct = ContentType.objects.get_for_model(ParentModel, for_concrete_model=True) + + pm_2.m2m.set([ac_m2m_obj]) + with self.assertNumQueries(4): + # query for PlainModelWithM2M # level 1 (base) + # query for prefetch links (m2m) + # query for ChildModel # level 2 (m2m) + # query for AltChildWithM2MModel # level 2 (m2m) + qs = PlainModelWithM2M.objects.all() + qs = qs.prefetch_related("m2m__AltChildWithM2MModel__m2m") + all_objs = list(qs) From e11c2a4e22a68887796eaab541cf232e65956570 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Tue, 23 Dec 2025 15:25:51 +0000 Subject: [PATCH 02/14] fix and convert merged to plain asserts --- src/polymorphic/tests/test_orm.py | 51 ++++++++++++++++--------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 1a065506..a3b531d2 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -109,6 +109,7 @@ SubclassSelectorProxyBaseModel, SubclassSelectorProxyConcreteModel, ParentLinkAndRelatedName, + TestParentLinkAndRelatedName, UUIDArtProject, UUIDArtProjectA, UUIDArtProjectB, @@ -1020,21 +1021,21 @@ def test_polymorphic__accessor_caching(self): blog_a = BlogA.objects.get(id=blog_a.id) # test reverse accessor & check that we get back cached object on repeated access - self.assertEqual(blog_base.bloga, blog_a) - self.assertIs(blog_base.bloga, blog_base.bloga) + assert blog_base.bloga == blog_a + assert blog_base.bloga is blog_base.bloga cached_blog_a = blog_base.bloga # test forward accessor & check that we get back cached object on repeated access - self.assertEqual(blog_a.blogbase_ptr, blog_base) - self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr) + assert blog_a.blogbase_ptr == blog_base + assert blog_a.blogbase_ptr is blog_a.blogbase_ptr cached_blog_base = blog_a.blogbase_ptr # check that refresh_from_db correctly clears cached related objects blog_base.refresh_from_db() blog_a.refresh_from_db() - self.assertIsNot(cached_blog_a, blog_base.bloga) - self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr) + assert cached_blog_a is not blog_base.bloga + assert cached_blog_base is not blog_a.blogbase_ptr def test_polymorphic__aggregate(self): """test ModelX___field syntax on aggregate (should work for annotate either)""" @@ -2105,10 +2106,10 @@ def test_select_related_on_poly_classes(self): ).order_by("pk") ) with self.assertNumQueries(0): - self.assertEqual(obj_list[0].relation.name, "p1") - self.assertEqual(obj_list[1].relation.name, "c1") - self.assertEqual(obj_list[2].relation.name, "ac1") - self.assertEqual(obj_list[3].relation.name, "ac2") + assert obj_list[0].relation.name == "p1" + assert obj_list[1].relation.name == "c1" + assert obj_list[2].relation.name == "ac1" + assert obj_list[3].relation.name == "ac2" obj_list[1].relation.link_on_child obj_list[2].relation.link_on_altchild obj_list[3].relation.link_on_altchild @@ -2145,11 +2146,11 @@ def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): ).order_by("pk") ) with self.assertNumQueries(0): - self.assertEqual(obj_list[0].relation.name, "p1") - self.assertEqual(obj_list[1].relation.name, "c1") - self.assertEqual(obj_list[2].relation.name, "ac1") - self.assertEqual(obj_list[3].relation.name, "ac2ab") - self.assertEqual(obj_list[3].relation.more_name, "acab2morename") + assert obj_list[0].relation.name == "p1" + assert obj_list[1].relation.name == "c1" + assert obj_list[2].relation.name == "ac1" + assert obj_list[3].relation.name == "ac2ab" + assert obj_list[3].relation.more_name == "acab2morename" obj_list[1].relation.link_on_child obj_list[2].relation.link_on_altchild obj_list[3].relation.link_on_altchild @@ -2181,9 +2182,9 @@ def test_select_related_on_poly_classes_with_modelname(self): ) with self.assertNumQueries(0): - self.assertEqual(obj_list[0].relation.name, "p1") - self.assertEqual(obj_list[1].relation.name, "c1") - self.assertEqual(obj_list[2].relation.name, "acab2") + assert obj_list[0].relation.name == "p1" + assert obj_list[1].relation.name == "c1" + assert obj_list[2].relation.name == "acab2" obj_list[1].relation.link_on_child obj_list[2].relation.link_on_altchild @@ -2388,14 +2389,14 @@ def test_prefetch_object_is_supported(self): ) objects = list(qs) - self.assertEqual(len(objects[0].poly), 1) + assert len(objects[0].poly) == 1 # derived object was not fetched - self.assertEqual(len(objects[1].poly), 0) + assert len(objects[1].poly) == 0 # base object always found - self.assertEqual(len(objects[0].non_poly), 1) - self.assertEqual(len(objects[1].non_poly), 1) + assert len(objects[0].non_poly) == 1 + assert len(objects[1].non_poly) == 1 def test_select_related_on_poly_classes_preserves_on_relations_annotations(self): b1 = RelatingModel.objects.create() @@ -2418,9 +2419,9 @@ def test_select_related_on_poly_classes_preserves_on_relations_annotations(self) ) objects = list(qs) - self.assertEqual(objects[0].poly[0].relatingmodel__count, 1) - self.assertEqual(objects[1].poly[0].relatingmodel__count, 2) - self.assertEqual(objects[2].poly[0].relatingmodel__count, 2) + assert objects[0].poly[0].relatingmodel__count == 1 + assert objects[1].poly[0].relatingmodel__count == 2 + assert objects[2].poly[0].relatingmodel__count == 2 @expectedFailure def test_prefetch_loading_relation_only_on_some_poly_model(self): From 9016d7dc4508739a892aaa0295a9a47835a8d5d6 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Tue, 12 Aug 2025 11:34:29 +0100 Subject: [PATCH 03/14] style: fix linting not to use type(X) == / != etc --- src/polymorphic/query.py | 2 +- src/polymorphic/showfields.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 237e2226..7e1dc4bf 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -700,7 +700,7 @@ def tree_node_test___lookup(my_model, node): for i in range(len(node.children)): child = node.children[i] - if type(child) is tuple: + if isinstance(child, tuple): # this Q object child is a tuple => a kwarg like Q( instance_of=ModelB ) assert "___" not in child[0], ___lookup_assert_msg else: diff --git a/src/polymorphic/showfields.py b/src/polymorphic/showfields.py index e894c70b..082caf6b 100644 --- a/src/polymorphic/showfields.py +++ b/src/polymorphic/showfields.py @@ -62,7 +62,7 @@ def _showfields_add_regular_fields(self, parts): out = field.name # if this is the standard primary key named "id", print it as we did with older versions of django_polymorphic - if field.primary_key and field.name == "id" and type(field) is models.AutoField: + if field.primary_key and field.name == "id" and isinstance(field, models.AutoField): out += f" {getattr(self, field.name)}" # otherwise, display it just like all other fields (with correct type, shortened content etc.) From 1b83609cdbe1913baad48313a53ae14cd9d8ee92 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Sat, 23 Mar 2024 14:13:39 +0000 Subject: [PATCH 04/14] fix: allow non poly classes to use RelatedPolymorphicPopulator --- src/polymorphic/query.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 7e1dc4bf..73ef2453 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -59,6 +59,9 @@ def __init__(self, klass_info, select, db): super().__init__(klass_info, select, db) self.field = klass_info["field"] self.reverse = klass_info["reverse"] + # replace replated populator with possibly a polymorphic version + # this is needed for relation across a non poly model + self.related_populators = get_related_populators(klass_info, select, self.db) def build_related(self, row, from_obj, *_): self.populate(row, from_obj) @@ -161,7 +164,9 @@ def _populate(self, row, from_obj, post_actions, obj): self.remote_setter(obj, from_obj) def post_build_modify(self, base_object, from_obj, post_actions, populate_fn): - if base_object.polymorphic_ctype_id == self.model_class_id: + if not hasattr(base_object, "polymorphic_ctype_id"): + populate_fn(base_object) + elif base_object.polymorphic_ctype_id == self.model_class_id: # Real class is exactly the same as base class, go straight to results populate_fn(base_object) else: @@ -245,10 +250,14 @@ def get_related_populators(klass_info, select, db): related_klass_infos = klass_info.get("related_klass_infos", []) for rel_klass_info in related_klass_infos: model = rel_klass_info["model"] + rel_cls = VanillaRelatedPopulator(rel_klass_info, select, db) if issubclass(model, PolymorphicModel): rel_cls = RelatedPolymorphicPopulator(rel_klass_info, select, db) else: - rel_cls = VanillaRelatedPopulator(rel_klass_info, select, db) + for col, *_ in select: + if issubclass(col.target.model, PolymorphicModel): + rel_cls = RelatedPolymorphicPopulator(rel_klass_info, select, db) + break iterators.append(rel_cls) return iterators From e57c2d507f9050459553e4cc968ef33d39336abc Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Mon, 25 Mar 2024 09:08:28 +0000 Subject: [PATCH 05/14] feat: improve sub-model name convertion --- src/polymorphic/query.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 73ef2453..b6da1758 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -468,8 +468,8 @@ class PolymorphicQuerySetMixin(QuerySet): def select_related(self, *fields): if fields == (None,) or not len(fields): return super().select_related(*fields) - field_with_poly = list(self.convert_related_fieldnames(fields)) - return super().select_related(*field_with_poly) + field_with_poly = set(self.convert_related_fieldnames(fields)) + return super().select_related(*sorted(list(field_with_poly))) def _convert_field_name_part(self, field_parts, model): """ @@ -488,23 +488,32 @@ def _convert_field_name_part(self, field_parts, model): if field.is_relation: rel_model = field.related_model if next_parts: - self._convert_field_name_part(next_parts, rel_model) + child_selectors = self._convert_field_name_part(next_parts, rel_model) + for selector in child_selectors: + yield field_path + selector else: rel_model = model - except FieldDoesNotExist: submodels = _get_all_sub_models(model) - rel_model = submodels.get(part, None) - field_path = list(_create_base_path(model, rel_model).split("__")) - for field_part_idx in range(0, len(field_path)): - yield field_path[0 : 1 + field_part_idx] - - if next_parts: - child_selectors = self._convert_field_name_part(next_parts, rel_model) + if part == "*": + for rel_model in submodels.values(): + if model is rel_model: + continue + yield from self._convert_submodel_fields_parts(next_parts, model, rel_model) + else: + rel_model = submodels.get(part, None) + if model is not rel_model: + yield from self._convert_submodel_fields_parts(next_parts, model, rel_model) + + def _convert_submodel_fields_parts(self, field_parts, model, rel_model): + field_path = list(_create_base_path(model, rel_model).split("__")) + for field_part_idx in range(0, len(field_path)): + yield field_path[0 : 1 + field_part_idx] + yield field_path + if field_parts: + child_selectors = self._convert_field_name_part(field_parts, rel_model) for selector in child_selectors: - all_field_path = field_path + selector - for field_part_idx in range(0, len(all_field_path)): - yield all_field_path[0 : 1 + field_part_idx] + yield field_path + selector def convert_related_fieldnames(self, fields, opts=None): """ From 422025e3e11d26a9dfc4c62172b8f653203d1ee3 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Mon, 25 Mar 2024 13:48:39 +0000 Subject: [PATCH 06/14] style: improve linting --- src/polymorphic/formsets/utils.py | 1 + src/polymorphic/query.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/polymorphic/formsets/utils.py b/src/polymorphic/formsets/utils.py index f9c31ea6..c83543ec 100644 --- a/src/polymorphic/formsets/utils.py +++ b/src/polymorphic/formsets/utils.py @@ -3,6 +3,7 @@ """ + def add_media(dest, media): """ Optimized version of django.forms.Media.__add__() that doesn't create new objects. diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index b6da1758..55de23e3 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -232,7 +232,7 @@ def pivot_onto_cached_subclass(self, from_obj, obj, model_target_cls): if isinstance(obj, model_target_cls): # We only want to pivot onto a field from a different object, ie not a parent/child - # relationship as this will break the cache and other object relationships + # relationship as this will break the cache and other object if not original._meta.get_path_to_parent(from_obj._meta.model): self.local_setter(from_obj, obj) if obj is not None: @@ -311,9 +311,11 @@ def __iter__(self): related_objs, operator.attrgetter( *[ - field.attname - if from_field == "self" - else queryset.model._meta.get_field(from_field).attname + ( + field.attname + if from_field == "self" + else queryset.model._meta.get_field(from_field).attname + ) for from_field in field.from_fields ] ), @@ -618,9 +620,9 @@ def _filter_or_exclude(self, negate, args, kwargs): def order_by(self, *field_names): """translate the field paths in the args, then call vanilla order_by.""" field_names = [ - translate_polymorphic_field_path(self.model, a) - if isinstance(a, str) - else a # allow expressions to pass unchanged + ( + translate_polymorphic_field_path(self.model, a) if isinstance(a, str) else a + ) # allow expressions to pass unchanged for a in field_names ] return super().order_by(*field_names) From bba85c2c9a56aed217ee88631e38dd0ae72a9850 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Tue, 12 Aug 2025 11:38:23 +0100 Subject: [PATCH 07/14] testing: fix and add test for more select_related across models --- src/polymorphic/formsets/utils.py | 1 - .../tests/migrations/0001_initial.py | 23 ++- src/polymorphic/tests/models.py | 5 + src/polymorphic/tests/test_orm.py | 162 +++++++++++++++++- 4 files changed, 175 insertions(+), 16 deletions(-) diff --git a/src/polymorphic/formsets/utils.py b/src/polymorphic/formsets/utils.py index c83543ec..f9c31ea6 100644 --- a/src/polymorphic/formsets/utils.py +++ b/src/polymorphic/formsets/utils.py @@ -3,7 +3,6 @@ """ - def add_media(dest, media): """ Optimized version of django.forms.Media.__add__() that doesn't create new objects. diff --git a/src/polymorphic/tests/migrations/0001_initial.py b/src/polymorphic/tests/migrations/0001_initial.py index 8686900d..cc388d88 100644 --- a/src/polymorphic/tests/migrations/0001_initial.py +++ b/src/polymorphic/tests/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2 on 2025-12-23 09:09 +# Generated by Django 4.2 on 2025-12-23 09:38 from django.conf import settings from django.db import migrations, models @@ -331,6 +331,13 @@ class Migration(migrations.Migration): ('field1', models.CharField(max_length=30)), ], ), + migrations.CreateModel( + name='PlainModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('relation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.parentmodel')), + ], + ), migrations.CreateModel( name='PlainParentModelWithManager', fields=[ @@ -968,6 +975,13 @@ class Migration(migrations.Migration): }, bases=(polymorphic.showfields.ShowFieldType, models.Model), ), + migrations.CreateModel( + name='RefPlainModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('plainobj', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.plainmodel')), + ], + ), migrations.CreateModel( name='RecursionBug', fields=[ @@ -1010,13 +1024,6 @@ class Migration(migrations.Migration): ('m2m', models.ManyToManyField(to='tests.parentmodel')), ], ), - migrations.CreateModel( - name='PlainModel', - fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('relation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.parentmodel')), - ], - ), migrations.CreateModel( name='PlainChildModelWithManager', fields=[ diff --git a/src/polymorphic/tests/models.py b/src/polymorphic/tests/models.py index 4e0a0595..6fb9a17b 100644 --- a/src/polymorphic/tests/models.py +++ b/src/polymorphic/tests/models.py @@ -838,6 +838,11 @@ class PlainModel(models.Model): objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() +class RefPlainModel(models.Model): + plainobj = models.ForeignKey(PlainModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + class PlainModelWithM2M(models.Model): field1 = models.CharField(max_length=10) m2m = models.ManyToManyField(ParentModel) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index a3b531d2..6c0a8f80 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -98,6 +98,7 @@ ProxyModelB, ProxyModelBase, RedheadDuck, + RefPlainModel, RelatingModel, RelationA, RelationB, @@ -2000,7 +2001,7 @@ def test_normal_django_to_poly_related_give_poly_type(self): PlainModel.objects.create(relation=obj2) PlainModel.objects.create(relation=obj3) - ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + ContentType.objects.get_for_model(AltChildModel) with self.assertNumQueries(6): # Queries will be @@ -2096,6 +2097,8 @@ def test_select_related_on_poly_classes(self): obj_p_3 = PlainModel.objects.create(relation=obj_ac1) obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + ContentType.objects.get_for_models(PlainA, ModelExtraExternal, AltChildModel) + with self.assertNumQueries(1): # pos 3 if i cannot do optimized select_related obj_list = list( @@ -2114,6 +2117,149 @@ def test_select_related_on_poly_classes(self): obj_list[2].relation.link_on_altchild obj_list[3].relation.link_on_altchild + self.assertIsInstance(obj_list[0].relation, ParentModel) + self.assertIsInstance(obj_list[1].relation, ChildModel) + self.assertIsInstance(obj_list[2].relation, AltChildModel) + self.assertIsInstance(obj_list[3].relation, AltChildModel) + + def test_select_related_on_poly_classes_simple(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__childmodel", + "relation__altchildmodel", + ) + .order_by("pk") + .only( + "relation__name", + "relation__polymorphic_ctype_id", + ) + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + + self.assertIsInstance(obj_list[0].relation, ParentModel) + self.assertIsInstance(obj_list[1].relation, ChildModel) + self.assertIsInstance(obj_list[2].relation, AltChildModel) + self.assertIsInstance(obj_list[3].relation, AltChildModel) + + def test_select_related_on_poly_classes_indirect_related(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + robj_1 = RefPlainModel.objects.create(plainobj=obj_p_1) + robj_2 = RefPlainModel.objects.create(plainobj=obj_p_2) + robj_3 = RefPlainModel.objects.create(plainobj=obj_p_3) + robj_4 = RefPlainModel.objects.create(plainobj=obj_p_4) + + # Prefetch content_types + ContentType.objects.get_for_models(PlainModel, PlainA, ModelExtraExternal) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + RefPlainModel.objects.select_related( + # "plainobj__relation", + "plainobj__relation", + "plainobj__relation__childmodel__link_on_child", + "plainobj__relation__altchildmodel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].plainobj.relation.name, "p1") + self.assertEqual(obj_list[1].plainobj.relation.name, "c1") + self.assertEqual(obj_list[2].plainobj.relation.name, "ac1") + self.assertEqual(obj_list[3].plainobj.relation.name, "ac2") + + self.assertIsInstance(obj_list[0].plainobj.relation, ParentModel) + self.assertIsInstance(obj_list[1].plainobj.relation, ChildModel) + self.assertIsInstance(obj_list[2].plainobj.relation, AltChildModel) + self.assertIsInstance(obj_list[3].plainobj.relation, AltChildModel) + + def test_select_related_fecth_all_poly_classes_indirect_related(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + robj_1 = RefPlainModel.objects.create(plainobj=obj_p_1) + robj_2 = RefPlainModel.objects.create(plainobj=obj_p_2) + robj_3 = RefPlainModel.objects.create(plainobj=obj_p_3) + robj_4 = RefPlainModel.objects.create(plainobj=obj_p_4) + + # Prefetch content_types + ContentType.objects.get_for_models( + PlainModel, PlainA, ModelExtraExternal, AltChildAsBaseModel, AltChildWithM2MModel + ) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + RefPlainModel.objects.select_related( + # "plainobj__relation", + "plainobj__relation", + "plainobj__relation__*", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].plainobj.relation.name, "p1") + self.assertEqual(obj_list[1].plainobj.relation.name, "c1") + self.assertEqual(obj_list[2].plainobj.relation.name, "ac1") + self.assertEqual(obj_list[3].plainobj.relation.name, "ac2") + + self.assertIsInstance(obj_list[0].plainobj.relation, ParentModel) + self.assertIsInstance(obj_list[1].plainobj.relation, ChildModel) + self.assertIsInstance(obj_list[2].plainobj.relation, AltChildModel) + self.assertIsInstance(obj_list[3].plainobj.relation, AltChildModel) + def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): plain_a_obj_1 = PlainA.objects.create(field1="f1") plain_a_obj_2 = PlainA.objects.create(field1="f2") @@ -2135,6 +2281,8 @@ def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): obj_p_3 = PlainModel.objects.create(relation=obj_ac1) obj_p_4 = PlainModel.objects.create(relation=obj_acab2) + ContentType.objects.get_for_models(PlainA, ModelExtraExternal) + with self.assertNumQueries(1): # pos 3 if i cannot do optimized select_related obj_list = list( @@ -2170,7 +2318,7 @@ def test_select_related_on_poly_classes_with_modelname(self): obj_p_2 = PlainModel.objects.create(relation=obj_c) obj_p_3 = PlainModel.objects.create(relation=obj_acab2) - ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + ContentType.objects.get_for_models(PlainA, ModelExtraExternal, AltChildModel) with self.assertNumQueries(1): obj_list = list( @@ -2201,7 +2349,7 @@ def test_prefetch_related_from_basepoly(self): obja2.m2m.add(objbc1) # NOTE: prefetch content types so query asserts test data fetched. - ct = ContentType.objects.get_for_model(NonSymRelationBase, for_concrete_model=True) + ContentType.objects.get_for_model(NonSymRelationBase) with self.assertNumQueries(10): # query for NonSymRelationBase (base) @@ -2260,7 +2408,7 @@ def test_prefetch_related_from_subclass(self): obja2.m2m.add(objbc1) # NOTE: prefetch content types so query asserts test data fetched. - ct = ContentType.objects.get_for_model(NonSymRelationBase, for_concrete_model=True) + ContentType.objects.get_for_model(NonSymRelationBase) with self.assertNumQueries(7): # query for NonSymRelationA # level 1 (base) @@ -2360,7 +2508,7 @@ def test_select_related_field_from_polymorphic_child_class_using_modelnames_mult obj_ac3 = ChildModel.objects.create(name="c2", other_name="c3name") # NOTE: prefetch content types so query asserts test data fetched. - ct = ContentType.objects.get_for_model(AltChildModel, for_concrete_model=True) + ContentType.objects.get_for_model(AltChildModel) with self.assertNumQueries(2): # Queries will be @@ -2454,7 +2602,7 @@ def test_prefetch_loading_relation_only_on_some_poly_model(self): ) # NOTE: prefetch content types so query asserts test data fetched. - ct = ContentType.objects.get_for_model(ParentModel, for_concrete_model=True) + ContentType.objects.get_for_model(ParentModel) pm_2.m2m.set([ac_m2m_obj]) with self.assertNumQueries(4): @@ -2497,7 +2645,7 @@ def test_prefetch_loading_relation_only_on_some_poly_model_using_modelnames(self ) # NOTE: prefetch content types so query asserts test data fetched. - ct = ContentType.objects.get_for_model(ParentModel, for_concrete_model=True) + ContentType.objects.get_for_model(ParentModel) pm_2.m2m.set([ac_m2m_obj]) with self.assertNumQueries(4): From 1600a7f47e807bdef2e8b390004e79901137b184 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Wed, 27 Mar 2024 16:53:20 +0000 Subject: [PATCH 08/14] doc: add relect_related basic docs --- docs/advanced.rst | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/docs/advanced.rst b/docs/advanced.rst index f04535ec..d2189453 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -193,9 +193,27 @@ About Queryset Methods * :meth:`~django.db.models.query.QuerySet.distinct` works as expected. It only regards the fields of the base class, but this should never make a difference. -* :meth:`~django.db.models.query.QuerySet.select_related` works just as usual, but it can not - (yet) be used to select relations in inherited models (like - ``ModelA.objects.select_related('ModelC___fieldxy')`` ) +* :meth:`~django.db.models.query.QuerySet.select_related` works just as usual with the + exception that the query set must be derived from a PolymorphicRelatedQuerySetMixin + or PolymorphicRelatedQuerySet. + + This can be achieved by using a custom manager + + class NonPolyModel(models.Model): + relation = models.ForeignKey(BasePolyModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + + To select related fields the model name comes after the field name and set the + field. + ``ModelA.objects.filter(....).select_related('field___TargetModel__subfield')``. + or using the polymorphic added related fieldname which is normally the lowercase + version of the model name. + ``ModelA.objects.filter(....).select_related('field__targetmodel__subfield')`` + + This automatically manages the via models between the model specified in the related + field and the target model. + ``ModelA.objects.filter(....).select_related('field__targetparentmodel__targetmodel__subfield')`` * :meth:`~django.db.models.query.QuerySet.extra` works as expected (it returns polymorphic results) but currently has one restriction: The resulting objects are required to have a unique From 32f117e23a4248b2f12b747dcad8f332f5ba6d64 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Thu, 18 Dec 2025 11:42:58 +0000 Subject: [PATCH 09/14] feat: add support to convert from normal manager/qs to poly versions --- docs/advanced.rst | 7 ++ src/polymorphic/query.py | 35 +++++++- .../tests/migrations/0001_initial.py | 9 ++- src/polymorphic/tests/models.py | 7 +- src/polymorphic/tests/test_orm.py | 79 ++++++++++++++++++- 5 files changed, 131 insertions(+), 6 deletions(-) diff --git a/docs/advanced.rst b/docs/advanced.rst index d2189453..f7240d3e 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -203,6 +203,13 @@ About Queryset Methods relation = models.ForeignKey(BasePolyModel, on_delete=models.CASCADE) objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + or by converting a models queryset using + + class NonPolyModel(models.Model): + relation = models.ForeignKey(BasePolyModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(QuerySet)() + + ``convert_to_polymorphic_queryset(NonPolyModel.objects).filter(...)`` To select related fields the model name comes after the field name and set the field. diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 55de23e3..e557d670 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -10,7 +10,7 @@ from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldDoesNotExist from django.db import connections, models -from django.db.models import FilteredRelation +from django.db.models import FilteredRelation, Manager from django.db.models.constants import LOOKUP_SEP from django.db.models.query import ModelIterable, Q, QuerySet, RelatedPopulator @@ -1056,3 +1056,36 @@ def _get_real_instances(self, base_result_objects): class PolymorphicRelatedQuerySet(PolymorphicRelatedQuerySetMixin, QuerySet): pass + + +def convert_to_polymorphic_queryset(qs): + "Convert a queryset to one that support polymorphic evaluation" + + if isinstance(qs, Manager): + qs = qs.get_queryset() + + if issubclass(qs.__class__, PolymorphicQuerySetMixin): + return qs + + assert issubclass(QuerySet, qs.__class__), ( + f"PolymorphicModel: cannot guarantee conversion of {qs.__class__} to polymorphic queryset" + ) + + class RelatedPolyQuerySet(PolymorphicRelatedQuerySetMixin, qs.__class__): + @classmethod + def _convert_to(cls, qs): + c = cls( + model=qs.model, + query=qs.query.chain(), + using=qs._db, + hints=qs._hints, + ) + c._sticky_filter = qs._sticky_filter + c._for_write = qs._for_write + c._prefetch_related_lookups = qs._prefetch_related_lookups[:] + c._known_related_objects = qs._known_related_objects + c._fields = qs._fields + return c + + poly_qs = RelatedPolyQuerySet._convert_to(qs) + return poly_qs diff --git a/src/polymorphic/tests/migrations/0001_initial.py b/src/polymorphic/tests/migrations/0001_initial.py index cc388d88..1ddc4dac 100644 --- a/src/polymorphic/tests/migrations/0001_initial.py +++ b/src/polymorphic/tests/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2 on 2025-12-23 09:38 +# Generated by Django 4.2 on 2025-12-23 09:40 from django.conf import settings from django.db import migrations, models @@ -933,6 +933,13 @@ class Migration(migrations.Migration): }, bases=('tests.uuidproject',), ), + migrations.CreateModel( + name='VanillaPlainModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('relation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.parentmodel')), + ], + ), migrations.CreateModel( name='SwappedModel', fields=[ diff --git a/src/polymorphic/tests/models.py b/src/polymorphic/tests/models.py index 6fb9a17b..8b5d5a55 100644 --- a/src/polymorphic/tests/models.py +++ b/src/polymorphic/tests/models.py @@ -838,9 +838,14 @@ class PlainModel(models.Model): objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() +class VanillaPlainModel(models.Model): + relation = models.ForeignKey(ParentModel, on_delete=models.CASCADE) + + class RefPlainModel(models.Model): plainobj = models.ForeignKey(PlainModel, on_delete=models.CASCADE) - objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + objects = models.Manager.from_queryset(QuerySet)() + poly_objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() class PlainModelWithM2M(models.Model): diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 6c0a8f80..5d2fcd79 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -21,6 +21,7 @@ from django.test.utils import CaptureQueriesContext from polymorphic import query_translate +from polymorphic.query import convert_to_polymorphic_queryset, PolymorphicRelatedQuerySetMixin from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicTypeInvalid, PolymorphicTypeUndefined from polymorphic.tests.models import ( @@ -38,6 +39,7 @@ ChildModelWithManager, CustomPkBase, CustomPkInherit, + Duck, Enhance_Base, Enhance_Plain, Enhance_Inherit, @@ -121,7 +123,7 @@ UUIDPlainC, UUIDProject, UUIDResearchProject, - Duck, + VanillaPlainModel, PurpleHeadDuck, Account, SpecialAccount1, @@ -2165,6 +2167,77 @@ def test_select_related_on_poly_classes_simple(self): self.assertIsInstance(obj_list[2].relation, AltChildModel) self.assertIsInstance(obj_list[3].relation, AltChildModel) + def test_we_can_upgrade_a_query_set_to_polymorphic_supports_already_ploy_qs(self): + base_qs = RefPlainModel.poly_objects.get_queryset() + self.assertIs(convert_to_polymorphic_queryset(base_qs), base_qs) + + def test_we_can_upgrade_a_query_set_to_polymorphic_supports_non_ploy_qs_on_ploy_object(self): + base_qs = RefPlainModel.objects.get_queryset() + self.assertIsNot(convert_to_polymorphic_queryset(base_qs), base_qs) + self.assertIsInstance( + convert_to_polymorphic_queryset(base_qs), PolymorphicRelatedQuerySetMixin + ) + + def test_we_can_upgrade_a_query_set_to_polymorphic_supports_non_ploy_managers_on_ploy_object( + self, + ): + base_qs = RefPlainModel.objects + self.assertIsNot(convert_to_polymorphic_queryset(base_qs), base_qs) + self.assertIsInstance( + convert_to_polymorphic_queryset(base_qs), PolymorphicRelatedQuerySetMixin + ) + + def test_we_can_upgrade_a_query_set_to_polymorphic(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = VanillaPlainModel.objects.create(relation=obj_p) + obj_p_2 = VanillaPlainModel.objects.create(relation=obj_c) + obj_p_3 = VanillaPlainModel.objects.create(relation=obj_ac1) + obj_p_4 = VanillaPlainModel.objects.create(relation=obj_ac2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list(VanillaPlainModel.objects.order_by("pk")) + + with self.assertNumQueries(7): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + convert_to_polymorphic_queryset(VanillaPlainModel.objects) + .select_related( + "relation", + "relation__childmodel", + "relation__altchildmodel", + ) + .order_by("pk") + ) + + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + + self.assertIsInstance(obj_list[0].relation, ParentModel) + self.assertIsInstance(obj_list[1].relation, ChildModel) + self.assertIsInstance(obj_list[2].relation, AltChildModel) + self.assertIsInstance(obj_list[3].relation, AltChildModel) + def test_select_related_on_poly_classes_indirect_related(self): # can we fetch the related object but only the minimal 'common' values plain_a_obj_1 = PlainA.objects.create(field1="f1") @@ -2194,7 +2267,7 @@ def test_select_related_on_poly_classes_indirect_related(self): with self.assertNumQueries(1): # pos 3 if i cannot do optimized select_related obj_list = list( - RefPlainModel.objects.select_related( + RefPlainModel.poly_objects.select_related( # "plainobj__relation", "plainobj__relation", "plainobj__relation__childmodel__link_on_child", @@ -2243,7 +2316,7 @@ def test_select_related_fecth_all_poly_classes_indirect_related(self): with self.assertNumQueries(1): # pos 3 if i cannot do optimized select_related obj_list = list( - RefPlainModel.objects.select_related( + RefPlainModel.poly_objects.select_related( # "plainobj__relation", "plainobj__relation", "plainobj__relation__*", From 3b31727ab677d8ac86715234b9145b3e6853892e Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Thu, 4 Apr 2024 15:39:10 +0100 Subject: [PATCH 10/14] feat: move to using upper case class names on search test --- src/polymorphic/tests/test_orm.py | 53 +++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 5d2fcd79..10ce0ae0 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -2106,8 +2106,8 @@ def test_select_related_on_poly_classes(self): obj_list = list( PlainModel.objects.select_related( "relation", - "relation__childmodel__link_on_child", - "relation__altchildmodel__link_on_altchild", + "relation__ChildModel__link_on_child", + "relation__AltChildModel__link_on_altchild", ).order_by("pk") ) with self.assertNumQueries(0): @@ -2124,6 +2124,31 @@ def test_select_related_on_poly_classes(self): self.assertIsInstance(obj_list[2].relation, AltChildModel) self.assertIsInstance(obj_list[3].relation, AltChildModel) + def test_select_related_can_merge_fields(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + ContentType.objects.get_for_models(PlainA, ModelExtraExternal, AltChildModel) + base_query = PlainModel.objects.select_related( + "relation__ChildModel", + ) + base_query = base_query.select_related("relation__AltChildModel") + with self.assertNumQueries(1): + list(base_query) + def test_select_related_on_poly_classes_simple(self): # can we fetch the related object but only the minimal 'common' values plain_a_obj_1 = PlainA.objects.create(field1="f1") @@ -2147,13 +2172,13 @@ def test_select_related_on_poly_classes_simple(self): obj_list = list( PlainModel.objects.select_related( "relation", - "relation__childmodel", - "relation__altchildmodel", + "relation__ChildModel_", + "relation__AltChildModel_", ) .order_by("pk") .only( "relation__name", - "relation__polymorphic_ctype_id", + "relation__polymorphic_ctype", ) ) with self.assertNumQueries(0): @@ -2221,8 +2246,8 @@ def test_we_can_upgrade_a_query_set_to_polymorphic(self): convert_to_polymorphic_queryset(VanillaPlainModel.objects) .select_related( "relation", - "relation__childmodel", - "relation__altchildmodel", + "relation__ChildModel", + "relation__AltChildModel", ) .order_by("pk") ) @@ -2270,8 +2295,8 @@ def test_select_related_on_poly_classes_indirect_related(self): RefPlainModel.poly_objects.select_related( # "plainobj__relation", "plainobj__relation", - "plainobj__relation__childmodel__link_on_child", - "plainobj__relation__altchildmodel__link_on_altchild", + "plainobj__relation__ChildModel__link_on_child", + "plainobj__relation__AltChildModel__link_on_altchild", ).order_by("pk") ) with self.assertNumQueries(0): @@ -2361,9 +2386,9 @@ def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): obj_list = list( PlainModel.objects.select_related( "relation", - "relation__childmodel__link_on_child", - "relation__altchildmodel__link_on_altchild", - "relation__altchildmodel__altchildasbasemodel__link_on_altchild", + "relation__ChildModel__link_on_child", + "relation__AltChildModel__link_on_altchild", + "relation__AltChildAsBaseModel__link_on_altchild", ).order_by("pk") ) with self.assertNumQueries(0): @@ -2540,7 +2565,7 @@ def test_select_related_field_from_polymorphic_child_class(self): all_objs = [ obj for obj in ParentModel.objects.select_related( - "altchildmodel", + "AltChildModel", ) ] @@ -2619,7 +2644,7 @@ def test_prefetch_object_is_supported(self): assert len(objects[0].non_poly) == 1 assert len(objects[1].non_poly) == 1 - def test_select_related_on_poly_classes_preserves_on_relations_annotations(self): + def test_prefetch_related_on_poly_classes_preserves_on_relations_annotations(self): b1 = RelatingModel.objects.create() b2 = RelatingModel.objects.create() b3 = RelatingModel.objects.create() From b076e11159797c232953f700137a5624184c5899 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Thu, 4 Apr 2024 15:42:22 +0100 Subject: [PATCH 11/14] fix: raise unknown select_related field part --- src/polymorphic/query.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index e557d670..05e78df2 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -220,6 +220,7 @@ def pivot_onto_cached_subclass(self, from_obj, obj, model_target_cls): if not isinstance( rel_iter, (VanillaRelatedPopulator, RelatedPolymorphicPopulator) ): + # NOTE: We don't know how to handle this type of populator! continue if rel_iter.reverse and rel_iter.model_cls is cls: if rel_iter.field.name in parents.keys(): @@ -506,6 +507,8 @@ def _convert_field_name_part(self, field_parts, model): rel_model = submodels.get(part, None) if model is not rel_model: yield from self._convert_submodel_fields_parts(next_parts, model, rel_model) + else: + raise def _convert_submodel_fields_parts(self, field_parts, model, rel_model): field_path = list(_create_base_path(model, rel_model).split("__")) From 0221c82cda8b24cb4c23689f2bed7eb169b3c901 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Thu, 4 Apr 2024 16:36:22 +0100 Subject: [PATCH 12/14] test fix extra _ in path --- src/polymorphic/tests/test_orm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 10ce0ae0..7d675bc7 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -2172,8 +2172,8 @@ def test_select_related_on_poly_classes_simple(self): obj_list = list( PlainModel.objects.select_related( "relation", - "relation__ChildModel_", - "relation__AltChildModel_", + "relation__ChildModel", + "relation__AltChildModel", ) .order_by("pk") .only( From 83b48aa7cbdb29753c27cf005d6db91e2b91b988 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Mon, 29 Dec 2025 11:02:19 +0000 Subject: [PATCH 13/14] fix: issues caused by #620 change of polymorphic_primary_key_name --- src/polymorphic/query.py | 39 ++++-- .../tests/migrations/0001_initial.py | 26 +++- src/polymorphic/tests/models.py | 8 ++ src/polymorphic/tests/test_orm.py | 132 ++++++++++++++++-- 4 files changed, 185 insertions(+), 20 deletions(-) diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 05e78df2..890633ea 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -8,7 +8,7 @@ from collections import defaultdict from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldDoesNotExist +from django.core.exceptions import FieldDoesNotExist, FieldError from django.db import connections, models from django.db.models import FilteredRelation, Manager from django.db.models.constants import LOOKUP_SEP @@ -240,8 +240,19 @@ def pivot_onto_cached_subclass(self, from_obj, obj, model_target_cls): self.remote_setter(obj, from_obj) return None, None - pk_name = self.model_cls.polymorphic_primary_key_name - return model_target_cls, (getattr(original, pk_name), self.field.name) + local_pk_name = original.__class__.polymorphic_primary_key_name + target_pk_name = original.__class__.polymorphic_primary_key_name + original_pk = getattr(original, local_pk_name) + + # NOTE: We could use a recursive function on model_target_cls._meta.parents + # PolymorphicModel.much _get_inheritance_relation_fields_and_models.like add_all_sub_models + for field in model_target_cls._meta.fields: + if field.is_relation is True: + for rel_field in field.foreign_related_fields: + if rel_field.name is local_pk_name and rel_field.model is original._meta.model: + target_pk_name = field.attname + + return model_target_cls, (original_pk, self.field.name, target_pk_name) def get_related_populators(klass_info, select, db): @@ -409,8 +420,8 @@ def fetch_polymorphic(self, post_actions, base_result_objects): for action, populate_fn in post_actions: target_class, pk_info = action() if target_class: - pk, name = pk_info - idlist_per_model[target_class].append((pk, name)) + pk, name, pk_name = pk_info + idlist_per_model[target_class].append(pk_info) update_fn_per_model[target_class].append((populate_fn, pk)) # For each model in "idlist_per_model" request its objects (the real model) @@ -419,16 +430,26 @@ def fetch_polymorphic(self, post_actions, base_result_objects): # Then we copy the extra() select fields from the base objects to the real objects. # TODO: defer(), only(): support for these would be around here for real_concrete_class, data in idlist_per_model.items(): - idlist, names = zip(*data) + idlist, names, pk_attr_names = zip(*data) updates = update_fn_per_model[real_concrete_class] - pk_name = real_concrete_class.polymorphic_primary_key_name + + if len(set(pk_attr_names)) != 1: + raise FieldError( + "PolymorphicModel: cannot convert model type as non " + f"upk_namesnique related key names {pk_attr_names}" + ) + + pk_attr_name = pk_attr_names[0] + # FIXME: this seams to get extra field already fetch in base + # initial query, we may need to add defer? + real_objects = real_concrete_class._base_objects.db_manager(self.queryset.db).filter( - **{("%s__in" % pk_name): idlist} + **{("%s__in" % pk_attr_name): idlist}, ) real_objects = self.apply_select_related(real_objects, set(names)) real_objects_dict = { - getattr(real_object, pk_name): real_object for real_object in real_objects + getattr(real_object, pk_attr_name): real_object for real_object in real_objects } for populate_fn, o_pk in updates: diff --git a/src/polymorphic/tests/migrations/0001_initial.py b/src/polymorphic/tests/migrations/0001_initial.py index 1ddc4dac..08d4f187 100644 --- a/src/polymorphic/tests/migrations/0001_initial.py +++ b/src/polymorphic/tests/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2 on 2025-12-23 09:40 +# Generated by Django 4.2 on 2025-12-23 08:21 from django.conf import settings from django.db import migrations, models @@ -1433,6 +1433,30 @@ class Migration(migrations.Migration): }, bases=('tests.mrobase2', 'tests.mrobase3'), ), + migrations.CreateModel( + name='NonAutoPKChild', + fields=[ + ('altchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, to='tests.altchildmodel')), + ('uuid_primary_key', models.UUIDField(default=uuid.uuid1, primary_key=True, serialize=False)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.altchildmodel',), + ), + migrations.CreateModel( + name='NonUUIDArtProject', + fields=[ + ('uuidresearchproject_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, to='tests.uuidresearchproject')), + ('idkey', models.AutoField(primary_key=True, serialize=False)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.uuidresearchproject',), + ), migrations.CreateModel( name='PlainC', fields=[ diff --git a/src/polymorphic/tests/models.py b/src/polymorphic/tests/models.py index 8b5d5a55..9c43e458 100644 --- a/src/polymorphic/tests/models.py +++ b/src/polymorphic/tests/models.py @@ -346,6 +346,10 @@ class UUIDResearchProject(UUIDProject): supervisor = models.CharField(max_length=30) +class NonUUIDArtProject(UUIDResearchProject): + idkey = models.AutoField(primary_key=True) + + class UUIDArtProjectA(UUIDArtProject): ... @@ -833,6 +837,10 @@ class AltChildAsBaseModel(AltChildModel): more_name = models.CharField(max_length=10) +class NonAutoPKChild(AltChildModel): + uuid_primary_key = models.UUIDField(primary_key=True, default=uuid.uuid1) + + class PlainModel(models.Model): relation = models.ForeignKey(ParentModel, on_delete=models.CASCADE) objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 7d675bc7..386571d7 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -74,12 +74,14 @@ MultiTableDerived, MyManager, MyManagerQuerySet, + NonAutoPKChild, NonPolymorphicParent, NonProxyChild, NonSymRelationA, NonSymRelationB, NonSymRelationBase, NonSymRelationBC, + NonUUIDArtProject, One2OneRelatingModel, One2OneRelatingModelDerived, ParentModel, @@ -112,7 +114,6 @@ SubclassSelectorProxyBaseModel, SubclassSelectorProxyConcreteModel, ParentLinkAndRelatedName, - TestParentLinkAndRelatedName, UUIDArtProject, UUIDArtProjectA, UUIDArtProjectB, @@ -1997,19 +1998,22 @@ def test_infinite_recursion_with_only(self): def test_normal_django_to_poly_related_give_poly_type(self): obj1 = ParentModel.objects.create(name="m1") obj2 = ChildModel.objects.create(name="m2", other_name="m2") - obj3 = ChildModel.objects.create(name="m1") + obj3 = ChildModel.objects.create(name="m3") + obj4 = ChildModel.objects.create(name="m3") + obj5 = AltChildModel.objects.create(name="m4") PlainModel.objects.create(relation=obj1) PlainModel.objects.create(relation=obj2) PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + PlainModel.objects.create(relation=obj5) - ContentType.objects.get_for_model(AltChildModel) - - with self.assertNumQueries(6): + with self.assertNumQueries(10): # Queries will be # * 1 for All PlainModels object (1) - # * 1 for each relations ParentModel (4) - # * 1 for each relations ChilModel is needed (3) + # * 1 for each relations ParentModel (5) + # * 1 for each relations ChildModel is needed (3) + # * 1 for each relations AltChildModel is needed (1) multi_q = [ # these obj.relation values will have their proper sub type obj.relation @@ -2017,18 +2021,121 @@ def test_normal_django_to_poly_related_give_poly_type(self): ] multi_q_types = [type(obj) for obj in multi_q] - with self.assertNumQueries(2): + with self.assertNumQueries(3): grouped_q = [ # these obj.relation values will all be ParentModel's # unless we fix select related but should be their proper # sub type by using PolymorphicRelatedQuerySetMixin + # 1 query for each relation type obj.relation for obj in PlainModel.objects.select_related("relation") ] grouped_q_types = [type(obj) for obj in grouped_q] self.assertListEqual(multi_q_types, grouped_q_types) - self.assertListEqual(grouped_q, [obj1, obj2, obj3]) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4, obj5]) + + def test_normal_django_to_multi_level_poly_related_give_poly_type(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="c1") + obj3 = AltChildModel.objects.create(name="m3") + obj4 = AltChildAsBaseModel.objects.create(name="m4", more_name="acab1") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + + with self.assertNumQueries(4): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation") + ] + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4]) + + def test_related_fetch_of_different_type_pks(self): + "pk on child is not same field type as pk on parent and thus prt field" + obj1 = ChildModel.objects.create(name="m1", other_name="c1") + obj2 = ParentModel.objects.create(name="m2") + obj3 = NonAutoPKChild.objects.create(name="m3", other_name="napk1") + obj4 = AltChildModel.objects.create(name="m4", other_name="acm1") + obj5 = ChildModel.objects.create(name="m5", other_name="c3") + obj6 = AltChildAsBaseModel.objects.create(name="m6", more_name="acab1") + obj7 = NonAutoPKChild.objects.create(name="m7", other_name="napk2") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + PlainModel.objects.create(relation=obj5) + PlainModel.objects.create(relation=obj6) + PlainModel.objects.create(relation=obj7) + + def object_info(obj): + return { + "pk": obj.pk, + "parentmodel_ptr": getattr(obj, "parentmodel_ptr_id", None), + "altchildmodel_ptr": getattr(obj, "altchildmodel_ptr_id", None), + } + + with self.assertNumQueries(5): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation").order_by("pk") + ] + grouped_info = [object_info(obj) for obj in grouped_q] + self.assertListEqual( + grouped_info, [object_info(obj) for obj in [obj1, obj2, obj3, obj4, obj5, obj6, obj7]] + ) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4, obj5, obj6, obj7]) + + def test_related_fetch_of_non_sequential_pks(self): + obj1 = ChildModel.objects.create(name="m1", other_name="c1") + obj2 = ParentModel.objects.create(name="m2") + + # FIXME use PK from table to get in sequential PKS + # from django.db import connection + # with connection.cursor() as cursor: + # cursor.execute('INSERT INTO "tests_childmodel" ("other_name", "parentmodel_ptr_id") VALUES (%s, %s)', ['fake', 1]) + + obj3 = ChildModel.objects.create(name="m3", other_name="c2") + obj4 = AltChildModel.objects.create(name="m4", other_name="acm1") + obj5 = ChildModel.objects.create(name="m5", other_name="c3") + obj6 = AltChildAsBaseModel.objects.create(name="m6", more_name="acab1") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + PlainModel.objects.create(relation=obj5) + PlainModel.objects.create(relation=obj6) + + def object_info(obj): + return { + "pk": obj.pk, + "parentmodel_ptr": getattr(obj, "parentmodel_ptr_id", None), + "altchildmodel_ptr": getattr(obj, "altchildmodel_ptr_id", None), + } + + with self.assertNumQueries(4): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation") + ] + grouped_info = [object_info(obj) for obj in grouped_q] + self.assertListEqual( + grouped_info, [object_info(obj) for obj in [obj1, obj2, obj3, obj4, obj5, obj6]] + ) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4, obj5, obj6]) def test_normal_django_to_poly_related_give_poly_type_using_select_related_true(self): obj1 = ParentModel.objects.create(name="m1") @@ -2335,7 +2442,12 @@ def test_select_related_fecth_all_poly_classes_indirect_related(self): # Prefetch content_types ContentType.objects.get_for_models( - PlainModel, PlainA, ModelExtraExternal, AltChildAsBaseModel, AltChildWithM2MModel + AltChildAsBaseModel, + AltChildWithM2MModel, + ModelExtraExternal, + NonAutoPKChild, + PlainA, + PlainModel, ) with self.assertNumQueries(1): From 022efb364b54a3c4783e4bbd8d5936ffc36bd6a9 Mon Sep 17 00:00:00 2001 From: Paul Gammans Date: Mon, 29 Dec 2025 16:22:30 +0000 Subject: [PATCH 14/14] testing: postgres issues with test object names --- src/polymorphic/tests/test_orm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 386571d7..cb1735fb 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -2142,7 +2142,7 @@ def test_normal_django_to_poly_related_give_poly_type_using_select_related_true( obj2 = ChildModel.objects.create(name="m2", other_name="m2") obj3 = ChildModel.objects.create(name="m1") obj4 = AltChildAsBaseModel.objects.create( - name="ac2", other_name="ac2name", more_name="ac2morename" + name="ac2", other_name="ac2name", more_name="ac2_mn" ) PlainModel.objects.create(relation=obj1) @@ -2482,7 +2482,7 @@ def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): obj_acab2 = AltChildAsBaseModel.objects.create( name="ac2ab", other_name="acab2name", - more_name="acab2morename", + more_name="acab2_mn", link_on_altchild=plain_a_obj_2, ) @@ -2508,7 +2508,7 @@ def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): assert obj_list[1].relation.name == "c1" assert obj_list[2].relation.name == "ac1" assert obj_list[3].relation.name == "ac2ab" - assert obj_list[3].relation.more_name == "acab2morename" + assert obj_list[3].relation.more_name == "acab2_mn" obj_list[1].relation.link_on_child obj_list[2].relation.link_on_altchild obj_list[3].relation.link_on_altchild @@ -2521,7 +2521,7 @@ def test_select_related_on_poly_classes_with_modelname(self): obj_acab2 = AltChildAsBaseModel.objects.create( name="acab2", other_name="acab2name", - more_name="acab2morename", + more_name="acab2_mn", link_on_altchild=plain_a_obj_1, ) obj_p_1 = PlainModel.objects.create(relation=obj_p) @@ -2711,7 +2711,7 @@ def test_select_related_field_from_polymorphic_child_class_using_modelnames_mult obj_acab2 = AltChildAsBaseModel.objects.create( name="acab2", other_name="acab2name", - more_name="acab2morename", + more_name="acab2_mn", link_on_altchild=plain_a_obj_1, ) obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name")