Skip to content

Commit adca4a1

Browse files
authored
Preserve row type in merge_annotations_from_custom_method (#3383)
1 parent 18413d5 commit adca4a1

2 files changed

Lines changed: 37 additions & 1 deletion

File tree

mypy_django_plugin/transformers/querysets.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,16 @@ def merge_annotations_from_custom_method(ctx: MethodContext, django_context: Dja
505505

506506
api = helpers.get_typechecker_api(ctx)
507507
annotated_type = get_annotated_type(api, django_model.typ, fields_dict=new_td)
508-
return default_return_type.copy_modified(args=[annotated_type, annotated_type])
508+
new_args: list[MypyType] = [annotated_type]
509+
if len(default_return_type.args) > 1:
510+
original_row = get_proper_type(default_return_type.args[1])
511+
if isinstance(original_row, Instance) and helpers.is_model_type(original_row.type):
512+
new_args.append(annotated_type)
513+
else:
514+
new_args.append(default_return_type.args[1])
515+
else:
516+
new_args.append(annotated_type)
517+
return default_return_type.copy_modified(args=new_args)
509518

510519

511520
def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> list[str] | None:

tests/typecheck/managers/querysets/test_annotate.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,33 @@
362362
num_posts = models.IntegerField()
363363
text = models.CharField(max_length=100)
364364
365+
- case: annotate_then_values_list_flat_slice_preserves_row_type
366+
main: |
367+
from typing import Iterable
368+
from typing_extensions import reveal_type
369+
from myapp.models import Blog
370+
from django.db.models import F
371+
372+
qs = Blog.objects.annotate(a=F("id")).values_list("id", flat=True)
373+
reveal_type(qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Blog@AnnotatedWith[TypedDict({'a': Any})], int]"
374+
reveal_type(qs[:]) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Blog@AnnotatedWith[TypedDict({'a': Any})], int]"
375+
376+
def get_blog_ids() -> Iterable[int]:
377+
return Blog.objects.annotate(a=F("id")).values_list("id", flat=True)[:]
378+
379+
# Also preserve non-flat tuple row types and TypedDict row types from values().
380+
reveal_type(Blog.objects.annotate(a=F("id")).values_list("id", "text")[:]) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Blog@AnnotatedWith[TypedDict({'a': Any})], tuple[int, str]]"
381+
reveal_type(Blog.objects.annotate(a=F("id")).values("id")[:]) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Blog@AnnotatedWith[TypedDict({'a': Any})], TypedDict({'id': int})]"
382+
installed_apps:
383+
- myapp
384+
files:
385+
- path: myapp/__init__.py
386+
- path: myapp/models.py
387+
content: |
388+
from django.db import models
389+
class Blog(models.Model):
390+
text = models.CharField(max_length=100)
391+
365392
- case: test_annotate_with_filter
366393
main: |
367394
from django.db import models

0 commit comments

Comments
 (0)