Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions dojo/product_type/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.contrib import messages
from django.contrib.admin.utils import NestedObjects
from django.db import DEFAULT_DB_ALIAS
from django.db.models import Count, IntegerField, OuterRef, Subquery, Value
from django.db.models import OuterRef, Value
from django.db.models.functions import Coalesce
from django.db.models.query import QuerySet
from django.http import HttpResponseRedirect
Expand Down Expand Up @@ -82,13 +82,10 @@ def prefetch_for_product_type(prod_types):
logger.debug("unable to prefetch because query was already executed")
return prod_types

prod_subquery = Subquery(
Product.objects.filter(prod_type_id=OuterRef("pk"))
.values("prod_type_id")
.annotate(c=Count("*"))
.values("c")[:1],
output_field=IntegerField(),
)
prod_subquery = build_count_subquery(
Product.objects.filter(prod_type_id=OuterRef("pk")),
group_field="prod_type_id",
)
base_findings = Finding.objects.filter(test__engagement__product__prod_type_id=OuterRef("pk"))
count_subquery = partial(build_count_subquery, group_field="test__engagement__product__prod_type_id")

Expand Down
6 changes: 5 additions & 1 deletion dojo/query_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

def build_count_subquery(model_qs: QuerySet, group_field: str) -> Subquery:
"""Return a Subquery that yields one aggregated count per `group_field`."""
# Important: slicing (`[:1]`) on an unordered queryset makes Django add an implicit `ORDER BY <pk>`.
# With aggregation, Django then includes that pk in the GROUP BY, which collapses counts to 1.
# Ordering by `group_field` avoids that and keeps the GROUP BY stable.
model_qs = model_qs.order_by()
return Subquery(
model_qs.values(group_field).annotate(c=Count("*")).values("c")[:1], # one row per group_field
model_qs.values(group_field).annotate(c=Count("pk")).order_by(group_field).values("c")[:1], # one row per group_field
output_field=IntegerField(),
)
19 changes: 19 additions & 0 deletions unittests/test_product_type_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dojo.models import Product, Product_Type
from dojo.product_type.views import prefetch_for_product_type
from unittests.dojo_test_case import DojoTestCase, versioned_fixtures


@versioned_fixtures
class TestProductTypeCounts(DojoTestCase):
fixtures = ["dojo_testdata.json"]

def test_prefetch_for_product_type_prod_count_matches_direct_count(self):
product_type = Product_Type.objects.create(name="PT count test")
Product.objects.create(name="PT product 1", description="test", prod_type=product_type)
Product.objects.create(name="PT product 2", description="test", prod_type=product_type)

annotated = prefetch_for_product_type(Product_Type.objects.filter(id=product_type.id))
annotated_count = annotated.values_list("prod_count", flat=True).get()

direct_count = Product.objects.filter(prod_type_id=product_type.id).count()
self.assertEqual(annotated_count, direct_count)
21 changes: 21 additions & 0 deletions unittests/test_query_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from django.db.models import Count

from dojo.engagement.views import prefetch_for_view_tests
from dojo.models import Finding, Test
from unittests.dojo_test_case import DojoTestCase, versioned_fixtures


@versioned_fixtures
class TestQueryUtils(DojoTestCase):
fixtures = ["dojo_testdata.json"]

def test_prefetch_for_view_tests_finding_counts_match_direct_count(self):
test = Test.objects.annotate(finding_count=Count("finding")).filter(finding_count__gt=1).first()
# If fixtures ever change, ensure we still have a representative test case.
self.assertIsNotNone(test)

annotated = prefetch_for_view_tests(Test.objects.filter(id=test.id))
annotated_count = annotated.values_list("count_findings_test_all", flat=True).get()

direct_count = Finding.objects.filter(test_id=test.id).count()
self.assertEqual(annotated_count, direct_count)