From 1331ce15e830852f039ab44eab500f3621b606fd Mon Sep 17 00:00:00 2001 From: Federico Bond Date: Tue, 12 May 2026 18:26:34 +1000 Subject: [PATCH] Preserve row type when .annotate() is called on a generic QuerySet --- mypy_django_plugin/transformers/querysets.py | 48 +++++++++++-------- .../managers/querysets/test_annotate.yml | 13 +++++ 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 38711f0dc..52278ee41 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -408,6 +408,31 @@ def _extract_model_type_var_upper_bound(ctx: MethodContext) -> Instance | None: return None +def _resolve_annotate_row_type( + api: TypeChecker, + default_return_type: Instance, + annotated_model: ProperType, + expression_types: dict[str, MypyType], +) -> MypyType: + if len(default_return_type.args) <= 1: + return annotated_model + original_row_type = get_proper_type(default_return_type.args[1]) + if isinstance(original_row_type, TypedDictType): + return api.named_generic_type( + "builtins.dict", + [api.named_generic_type("builtins.str", []), AnyType(TypeOfAny.from_omitted_generics)], + ) + if isinstance(original_row_type, TupleType): + if original_row_type.partial_fallback.type.has_base("typing.NamedTuple"): + # Rebuild the NamedTuple with existing fields + annotation fields. + annotation_fields = {name: AnyType(TypeOfAny.from_omitted_generics) for name in expression_types} + return helpers.extend_oneoff_named_tuple(api, "Row", original_row_type, annotation_fields) + return api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.from_omitted_generics)]) + if isinstance(original_row_type, Instance) and helpers.is_model_type(original_row_type.type): + return annotated_model + return original_row_type + + def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: DjangoContext) -> MypyType: django_model = helpers.get_model_info_from_qs_ctx(ctx, django_context) if django_model is None: @@ -424,7 +449,8 @@ def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: Dj if expression_types: fields_dict = helpers.make_typeddict(api, expression_types) upper_annotated = get_annotated_type(api, upper_bound, fields_dict=fields_dict) - return default_return_type.copy_modified(args=[upper_annotated, upper_annotated]) + row_type = _resolve_annotate_row_type(api, default_return_type, upper_annotated, expression_types) + return default_return_type.copy_modified(args=[upper_annotated, row_type]) return AnyType(TypeOfAny.from_omitted_generics) default_return_type = get_proper_type(ctx.default_return_type) @@ -452,25 +478,7 @@ def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: Dj fields_dict = helpers.make_typeddict(api, all_fields) annotated_type = get_annotated_type(api, django_model.typ, fields_dict=fields_dict) - row_type: MypyType - if len(default_return_type.args) > 1: - original_row_type = get_proper_type(default_return_type.args[1]) - row_type = original_row_type - if isinstance(original_row_type, TypedDictType): - row_type = api.named_generic_type( - "builtins.dict", [api.named_generic_type("builtins.str", []), AnyType(TypeOfAny.from_omitted_generics)] - ) - elif isinstance(original_row_type, TupleType): - if original_row_type.partial_fallback.type.has_base("typing.NamedTuple"): - # Rebuild the NamedTuple with existing fields + annotation fields. - annotation_fields = {name: AnyType(TypeOfAny.from_omitted_generics) for name in expression_types} - row_type = helpers.extend_oneoff_named_tuple(api, "Row", original_row_type, annotation_fields) - else: - row_type = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.from_omitted_generics)]) - elif isinstance(original_row_type, Instance) and helpers.is_model_type(original_row_type.type): - row_type = annotated_type - else: - row_type = annotated_type + row_type = _resolve_annotate_row_type(api, default_return_type, annotated_type, expression_types) return default_return_type.copy_modified(args=[annotated_type, row_type]) diff --git a/tests/typecheck/managers/querysets/test_annotate.yml b/tests/typecheck/managers/querysets/test_annotate.yml index 19ae48eea..2d66d1674 100644 --- a/tests/typecheck/managers/querysets/test_annotate.yml +++ b/tests/typecheck/managers/querysets/test_annotate.yml @@ -1012,3 +1012,16 @@ class FooModel(models.Model): objects = FooManager() + +- case: values_then_annotate_on_generic_queryset_typevar + main: | + from typing import Any + from typing_extensions import reveal_type, TypeVar + from django.db import models + + _Model = TypeVar("_Model", bound=models.Model) + + def latest_max(qs: models.QuerySet[_Model]) -> dict[str, Any] | None: + annotated = qs.values("id").annotate(foo=models.Max("id")) + reveal_type(annotated) # N: Revealed type is "django.db.models.query.QuerySet[django.db.models.base.Model, dict[str, Any]]" + return annotated.first()