diff --git a/metrics/api/serializers/__init__.py b/metrics/api/serializers/__init__.py
index 7519667b58..e948a5a2b1 100644
--- a/metrics/api/serializers/__init__.py
+++ b/metrics/api/serializers/__init__.py
@@ -1,4 +1,8 @@
from .charts import ChartsSerializer
+from .dual_category_tables import (
+ DualCategoryTablesSerializer,
+ DualCategoryTablesResponseSerializer,
+)
from .headlines import HeadlinesQuerySerializer, CoreHeadlineSerializer
from .trends import TrendsQuerySerializer, TrendsResponseSerializer
from .downloads import (
diff --git a/metrics/api/serializers/charts/__init__.py b/metrics/api/serializers/charts/__init__.py
index c70380d523..239e302625 100644
--- a/metrics/api/serializers/charts/__init__.py
+++ b/metrics/api/serializers/charts/__init__.py
@@ -2,7 +2,6 @@
ChartPlotSerializer,
ChartPlotsListSerializer,
ChartsResponseSerializer,
- EncodedChartResponseSerializer,
EncodedChartsRequestSerializer,
ChartsSerializer,
)
diff --git a/metrics/api/serializers/charts/common.py b/metrics/api/serializers/charts/common.py
index decec6be74..bfee70811f 100644
--- a/metrics/api/serializers/charts/common.py
+++ b/metrics/api/serializers/charts/common.py
@@ -14,6 +14,8 @@
class BaseChartsSerializer(serializers.Serializer):
+ """Base serializer for chart request payloads, containing common fields across different chart types."""
+
file_format = serializers.ChoiceField(
choices=FILE_FORMAT_CHOICES,
help_text=help_texts.CHART_FILE_FORMAT_FIELD,
@@ -94,3 +96,21 @@ class BaseChartsSerializer(serializers.Serializer):
allow_null=True,
default="",
)
+
+
+class ChartPreviewQueryParamsSerializer(serializers.Serializer):
+ """Serializer for query parameters when requesting a chart preview."""
+
+ preview = serializers.BooleanField(required=False)
+
+
+class EncodedChartResponseSerializer(serializers.Serializer):
+ """Serializer for the response of an encoded chart generation, containing the encoded chart and related metadata."""
+
+ last_updated = serializers.CharField(
+ help_text=help_texts.ENCODED_CHARTS_LAST_UPDATED,
+ allow_blank=True,
+ )
+ chart = serializers.CharField(help_text=help_texts.ENCODED_CHARTS_RESPONSE)
+ alt_text = serializers.CharField(help_text=help_texts.CHARTS_ALT_TEXT)
+ figure = serializers.DictField(help_text=help_texts.CHARTS_FIGURE_OUTPUT)
diff --git a/metrics/api/serializers/charts/dual_category_charts.py b/metrics/api/serializers/charts/dual_category_charts.py
index dc6f7fa5cd..be03ab2677 100644
--- a/metrics/api/serializers/charts/dual_category_charts.py
+++ b/metrics/api/serializers/charts/dual_category_charts.py
@@ -10,10 +10,15 @@
DEFAULT_CHART_WIDTH,
DEFAULT_X_AXIS,
DEFAULT_Y_AXIS,
+ ChartAxisFields,
ChartTypes,
+ DataSourceFileType,
DEFAULT_Y_AXIS_MINIMUM_VAlUE,
+ extract_metric_group_from_metric,
+)
+from metrics.domain.models.charts.dual_category_charts import (
+ DualCategoryChartRequestParams,
)
-from metrics.domain.models.charts import DualCategoryChartRequestParams
class DualCategoryChartSegmentSerializer(serializers.Serializer):
@@ -33,30 +38,17 @@ class DualCategoryChartSegmentSerializer(serializers.Serializer):
)
-class StaticFieldsSerializer(PlotSerializer):
- theme = serializers.CharField(
- required=True,
- allow_blank=True,
- allow_null=True,
- )
- sub_theme = serializers.CharField(
- required=True,
- allow_blank=True,
- allow_null=True,
- )
-
-
class DualCategoryChartSerializer(BaseChartsSerializer):
chart_type = serializers.ChoiceField(
help_text=help_texts.CHART_TYPE_FIELD,
- choices=ChartTypes.selectable_choices(),
+ choices=ChartTypes.dual_category_chart_options(),
required=True,
)
primary_field_values = serializers.ListField(
child=serializers.CharField(),
help_text="List of primary field values for this segment",
- required=True,
- allow_empty=False,
+ required=False,
+ allow_empty=True,
)
secondary_category = serializers.CharField(
@@ -64,7 +56,7 @@ class DualCategoryChartSerializer(BaseChartsSerializer):
required=True,
)
- static_fields = StaticFieldsSerializer()
+ static_fields = PlotSerializer()
segments = serializers.ListField(
child=DualCategoryChartSegmentSerializer(),
@@ -72,20 +64,99 @@ class DualCategoryChartSerializer(BaseChartsSerializer):
required=True,
)
+ @classmethod
+ def validate(cls, attrs: dict) -> dict:
+ """Validate primary_field_values based on the selected x-axis."""
+ x_axis = attrs.get("x_axis") or DEFAULT_X_AXIS
+ primary_field_values = attrs.get("primary_field_values") or []
+ metric = attrs["static_fields"]["metric"]
+ metric_group = extract_metric_group_from_metric(metric=metric)
+ is_timeseries_data = DataSourceFileType[metric_group].is_timeseries
+
+ if is_timeseries_data:
+ if primary_field_values:
+ raise serializers.ValidationError(
+ {
+ "primary_field_values": (
+ "This field should not be provided for timeseries data."
+ )
+ }
+ )
+ if x_axis != ChartAxisFields.date.name:
+ raise serializers.ValidationError(
+ {
+ "x_axis": (
+ "This field should be set to 'date' for timeseries data."
+ )
+ }
+ )
+
+ elif not is_timeseries_data and not primary_field_values:
+ raise serializers.ValidationError(
+ {"primary_field_values": ("This field is required for headline data.")}
+ )
+
+ return attrs
+
def to_models(self, request: Request) -> DualCategoryChartRequestParams:
x_axis = self.data.get("x_axis") or DEFAULT_X_AXIS
y_axis = self.data.get("y_axis") or DEFAULT_Y_AXIS
- for plot in self.data["segments"]:
- plot["x_axis"] = x_axis
- plot["y_axis"] = y_axis
+ primary_field_values = self.data.get("primary_field_values") or []
+ secondary_category = self.data["secondary_category"]
+ static_fields: dict[str, str | int] = self.validated_data.pop("static_fields")
+
+ if static_fields["date_to"]:
+ static_fields["date_to"] = static_fields["date_to"].isoformat()
+
+ if static_fields["date_from"]:
+ static_fields["date_from"] = static_fields["date_from"].isoformat()
+
+ groups_plots = []
+ segments: list[dict] = self.data["segments"]
+
+ metric_group = extract_metric_group_from_metric(metric=static_fields["metric"])
+ is_timeseries_data = DataSourceFileType[metric_group].is_timeseries
+
+ # If timeseries data
+ if is_timeseries_data:
+ plots = [
+ {
+ "x_axis": x_axis,
+ "y_axis": y_axis,
+ "line_colour": segment["colour"],
+ **static_fields,
+ secondary_category: segment["secondary_field_value"],
+ "chart_type": self.data["chart_type"],
+ "label": segment["label"],
+ }
+ for segment in segments
+ ]
+ groups_plots.extend(plots)
+
+ else:
+ for primary_field_value in primary_field_values:
+ plots = [
+ {
+ "x_axis": x_axis,
+ "y_axis": y_axis,
+ "line_colour": segment["colour"],
+ **static_fields,
+ x_axis: primary_field_value,
+ secondary_category: segment["secondary_field_value"],
+ "chart_type": self.data["chart_type"],
+ "label": segment["label"],
+ }
+ for segment in segments
+ ]
+ groups_plots.extend(plots)
return DualCategoryChartRequestParams(
chart_type=self.data["chart_type"],
- primary_field_values=self.data["primary_field_values"],
+ primary_field_values=primary_field_values,
secondary_category=self.data["secondary_category"],
static_fields=self.data["static_fields"],
- segments=self.data["segments"],
+ plots=groups_plots,
file_format=self.data["file_format"],
chart_height=self.data["chart_height"] or DEFAULT_CHART_HEIGHT,
chart_width=self.data["chart_width"] or DEFAULT_CHART_WIDTH,
@@ -97,4 +168,5 @@ def to_models(self, request: Request) -> DualCategoryChartRequestParams:
or DEFAULT_Y_AXIS_MINIMUM_VAlUE,
y_axis_maximum_value=self.data["y_axis_maximum_value"],
request=request,
+ legend_title=self.data.get("legend_title", ""),
)
diff --git a/metrics/api/serializers/charts/single_category_charts.py b/metrics/api/serializers/charts/single_category_charts.py
index 46b9d958f0..233c348b09 100644
--- a/metrics/api/serializers/charts/single_category_charts.py
+++ b/metrics/api/serializers/charts/single_category_charts.py
@@ -112,13 +112,3 @@ class EncodedChartsRequestSerializer(ChartsSerializer):
help_text=help_texts.ENCODED_CHARTS_FILE_FORMAT_FIELD,
default="svg",
)
-
-
-class EncodedChartResponseSerializer(serializers.Serializer):
- last_updated = serializers.CharField(
- help_text=help_texts.ENCODED_CHARTS_LAST_UPDATED,
- allow_blank=True,
- )
- chart = serializers.CharField(help_text=help_texts.ENCODED_CHARTS_RESPONSE)
- alt_text = serializers.CharField(help_text=help_texts.CHARTS_ALT_TEXT)
- figure = serializers.DictField(help_text=help_texts.CHARTS_FIGURE_OUTPUT)
diff --git a/metrics/api/serializers/charts/subplot_charts.py b/metrics/api/serializers/charts/subplot_charts.py
index 99475d04c4..b366b3227e 100644
--- a/metrics/api/serializers/charts/subplot_charts.py
+++ b/metrics/api/serializers/charts/subplot_charts.py
@@ -213,7 +213,3 @@ def to_models(self, request: Request) -> SubplotChartRequestParameters:
subplots=self.validated_data["subplots"],
request=request,
)
-
-
-class ChartPreviewQueryParamsSerializer(serializers.Serializer):
- preview = serializers.BooleanField(required=False)
diff --git a/metrics/api/serializers/dual_category_tables.py b/metrics/api/serializers/dual_category_tables.py
new file mode 100644
index 0000000000..37a038e068
--- /dev/null
+++ b/metrics/api/serializers/dual_category_tables.py
@@ -0,0 +1,139 @@
+import contextlib
+
+from django.db.utils import OperationalError
+from rest_framework import serializers
+from rest_framework.request import Request
+
+from metrics.api.serializers import help_texts
+from metrics.api.serializers.plots import PlotSerializer
+from metrics.domain.common.utils import (
+ DEFAULT_CHART_HEIGHT,
+ DEFAULT_CHART_WIDTH,
+ DEFAULT_X_AXIS,
+ DEFAULT_Y_AXIS,
+ ChartAxisFields,
+)
+from metrics.domain.models import ChartRequestParams
+
+
+class DualCategoryTableSegmentSerializer(serializers.Serializer):
+ colour = serializers.CharField(
+ required=False,
+ allow_blank=True,
+ allow_null=True,
+ default="",
+ help_text=help_texts.LABEL_FIELD,
+ )
+
+ secondary_field_value = serializers.CharField(
+ required=False,
+ allow_blank=True,
+ allow_null=True,
+ default="",
+ help_text=help_texts.LABEL_FIELD,
+ )
+
+ label = serializers.CharField(
+ required=False,
+ allow_blank=True,
+ allow_null=True,
+ default="",
+ help_text=help_texts.LABEL_FIELD,
+ )
+
+
+class DualCategoryTableSegmentListSerializer(serializers.ListSerializer):
+ child = DualCategoryTableSegmentSerializer()
+
+
+class DualCategoryTablesSerializer(serializers.Serializer):
+
+ segments = DualCategoryTableSegmentListSerializer()
+
+ static_fields = PlotSerializer()
+
+ x_axis = serializers.ChoiceField(
+ choices=ChartAxisFields.choices(),
+ required=False,
+ allow_blank=True,
+ allow_null=True,
+ help_text=help_texts.CHART_X_AXIS,
+ default=DEFAULT_X_AXIS,
+ )
+
+ y_axis = serializers.ChoiceField(
+ choices=ChartAxisFields.choices(),
+ required=False,
+ allow_blank=True,
+ allow_null=True,
+ help_text=help_texts.CHART_Y_AXIS,
+ default=DEFAULT_Y_AXIS,
+ )
+
+ primary_field_values = serializers.ListField(
+ child=serializers.CharField(),
+ help_text="List of primary field values for this segment",
+ required=True,
+ allow_empty=False,
+ )
+
+ secondary_category = serializers.CharField(
+ help_text="Secondary category field for the chart",
+ required=True,
+ )
+
+ def __init__(self, *args, **kwargs):
+ with contextlib.suppress(OperationalError):
+ super().__init__(*args, **kwargs)
+
+ def to_models(self, request: Request) -> ChartRequestParams:
+
+ groups_plots = []
+ primary_field_values = self.data.get("primary_field_values")
+ x_axis = self.data.get("x_axis") or DEFAULT_X_AXIS
+ y_axis = self.data.get("y_axis") or DEFAULT_Y_AXIS
+ static_fields = self.data.get("static_fields")
+ print(f"AIDAN static_fields {static_fields}")
+ topic = static_fields.get("topic")
+
+ for primary_field_value in primary_field_values:
+ for segment in self.data["segments"]:
+ plot = {
+ "y_axis": y_axis,
+ "x_axis": primary_field_value,
+ self.data.get("secondary_category"): segment[
+ "secondary_field_value"
+ ],
+ **static_fields,
+ }
+ groups_plots.append(plot)
+ print(f"AIDAN: plots {groups_plots}")
+ return ChartRequestParams(
+ chart_height=DEFAULT_CHART_HEIGHT,
+ chart_width=DEFAULT_CHART_WIDTH,
+ file_format="svg",
+ plots=groups_plots,
+ request=request,
+ x_axis=x_axis,
+ y_axis=y_axis,
+ )
+
+
+class DualCategoryTablesResponseValueSerializer(serializers.Serializer):
+ label = serializers.CharField()
+ value = serializers.CharField()
+ in_reporting_delay_period = serializers.BooleanField()
+ # Confidence intervals aren't implemented for dual category charts
+
+
+class DualCategoryTablesResponseValuesListSerializer(serializers.ListSerializer):
+ child = DualCategoryTablesResponseValueSerializer()
+
+
+class DualCategoryTablesResponsePlotsListSerializer(serializers.Serializer):
+ reference = serializers.CharField()
+ values = DualCategoryTablesResponseValuesListSerializer()
+
+
+class DualCategoryTablesResponseSerializer(serializers.ListSerializer):
+ child = DualCategoryTablesResponsePlotsListSerializer()
diff --git a/metrics/api/urls_construction.py b/metrics/api/urls_construction.py
index fd42ad32a0..afcddffdeb 100644
--- a/metrics/api/urls_construction.py
+++ b/metrics/api/urls_construction.py
@@ -24,6 +24,7 @@
BulkDownloadsView,
ChartsView,
ColdAlertViewSet,
+ DualCategoryTablesView,
DownloadsView,
EncodedChartsView,
HeadlinesView,
@@ -206,6 +207,7 @@ def construct_public_api_urlpatterns(
re_path(f"^{API_PREFIX}maps/v1", MapsView.as_view()),
re_path(f"^{API_PREFIX}tables/v4", TablesView.as_view()),
re_path(f"^{API_PREFIX}tables/subplot/v1", TablesSubplotView.as_view()),
+ re_path(f"^{API_PREFIX}tables/dual-category/v1", DualCategoryTablesView.as_view()),
re_path(f"^{API_PREFIX}trends/v3", TrendsView.as_view()),
]
diff --git a/metrics/api/views/__init__.py b/metrics/api/views/__init__.py
index d7dafa127a..2b827c9b2f 100644
--- a/metrics/api/views/__init__.py
+++ b/metrics/api/views/__init__.py
@@ -3,7 +3,7 @@
from .headlines import HeadlinesView
from .downloads import DownloadsView, BulkDownloadsView, SubplotDownloadsView
from .health import HealthView
-from .tables import TablesView, TablesSubplotView
+from .tables import DualCategoryTablesView, TablesView, TablesSubplotView
from .trends import TrendsView
from .audit import (
AuditAPITimeSeriesViewSet,
diff --git a/metrics/api/views/charts/dual_category_charts.py b/metrics/api/views/charts/dual_category_charts.py
index 618fd7a0cf..2a7c79ce69 100644
--- a/metrics/api/views/charts/dual_category_charts.py
+++ b/metrics/api/views/charts/dual_category_charts.py
@@ -1,20 +1,42 @@
+import io
import logging
from http import HTTPStatus
+from django.http import FileResponse
from drf_spectacular.utils import OpenApiExample, extend_schema
from rest_framework import permissions
from rest_framework.response import Response
from rest_framework.views import APIView
import config
+from caching.private_api.decorators import cache_response
from metrics.api.enums import AppMode
from metrics.api.serializers.charts import (
ChartsResponseSerializer,
)
+from metrics.api.serializers.charts.common import (
+ ChartPreviewQueryParamsSerializer,
+ EncodedChartResponseSerializer,
+)
from metrics.api.serializers.charts.dual_category_charts import (
DualCategoryChartSerializer,
)
from metrics.domain.charts.colour_scheme import RGBAChartLineColours
+from metrics.domain.models.charts.dual_category_charts import (
+ DualCategoryChartRequestParams,
+)
+from metrics.interfaces.charts.common.generation import (
+ ChartResult,
+ generate_chart_as_file,
+ generate_encoded_chart,
+)
+from metrics.interfaces.charts.dual_category_charts.access import (
+ DualCategoryChartsInterface,
+)
+from metrics.interfaces.plots.access import (
+ DataNotFoundForAnyPlotError,
+ InvalidPlotParametersError,
+)
CHARTS_API_TAG = "charts"
@@ -25,18 +47,15 @@
"file_format": "svg",
"chart_height": 200,
"chart_width": 320,
- "x_axis": "metric",
- "primary_field_values": ["m", "f"],
- "y_axis": "sex",
+ "x_axis": "date",
+ "y_axis": "metric",
"x_axis_title": "",
"y_axis_title": "",
"y_axis_minimum_value": None,
"y_axis_maximum_value": None,
- "chart_type": "bar",
+ "chart_type": "stacked_bar",
"secondary_category": "age",
"static_fields": {
- "theme": "infectious_disease",
- "sub_theme": "respiratory",
"topic": "COVID-19",
"metric": "COVID-19_cases_rateRollingMean",
"stratum": "default",
@@ -70,7 +89,6 @@ def get_permissions(self) -> list[type[permissions.BasePermission]]:
return [permissions.IsAuthenticated()]
return super().get_permissions()
- @classmethod
@extend_schema(
request=DualCategoryChartSerializer,
responses={HTTPStatus.OK.value: ChartsResponseSerializer},
@@ -83,15 +101,109 @@ def get_permissions(self) -> list[type[permissions.BasePermission]]:
)
],
)
- def post(cls, request, *args, **kwargs):
- request_serializer = DualCategoryChartSerializer(data=request.data)
- request_serializer.is_valid(raise_exception=True)
+ def post(self, request, *args, **kwargs) -> FileResponse | Response:
+ chart_preview_serializer = ChartPreviewQueryParamsSerializer(
+ data=request.query_params
+ )
+ chart_preview_serializer.is_valid(raise_exception=True)
+ payload = chart_preview_serializer.validated_data
+
+ if payload.get("preview", False):
+ return self._process_post_request_as_preview(request, *args, **kwargs)
+ return self._process_post_request_as_encoded_svg(request, *args, **kwargs)
+
+ @cache_response(timeout=0)
+ def _process_post_request_as_preview(
+ self, request, *args, **kwargs
+ ) -> FileResponse:
+ """Handles the inbound request as `preview=true` in this case we don't use the cache
+
+ Notes:
+ - With a timeout of `0`, the response is never
+ actually put into the cache
+
+ Returns:
+ `Response` containing the rendered chart as an image file
+
+ """
+ serializer = DualCategoryChartSerializer(data=request.data)
+ serializer.is_valid(raise_exception=True)
+
+ chart_request_params: DualCategoryChartRequestParams = serializer.to_models(
+ request=request
+ )
+
+ return self._handle_chart_as_file(chart_request_params=chart_request_params)
+
+ @classmethod
+ def _handle_chart_as_file(
+ cls, chart_request_params: DualCategoryChartRequestParams
+ ) -> FileResponse | Response:
+ """
+ Handles the process of generating a chart and returning it as a file response.
+
+ Args:
+ chart_request_params: A `DualCategoryChartRequestParams` model containing all the necessary parameters to generate the chart.
+
+ Returns:
+ A `FileResponse` containing the generated chart image, or a `Response` with an error message if chart generation fails due to invalid parameters or missing data.
+ """
+ try:
+ chart_image: bytes = generate_chart_as_file(
+ chart_request_params=chart_request_params,
+ interface=DualCategoryChartsInterface,
+ )
+ except (InvalidPlotParametersError, DataNotFoundForAnyPlotError) as error:
+ return Response(
+ status=HTTPStatus.BAD_REQUEST, data={"error_message": str(error)}
+ )
+
+ return FileResponse(
+ io.BytesIO(chart_image),
+ content_type=f"image/{chart_request_params.file_format}",
+ )
+
+ @classmethod
+ def _handle_encoded_svg(
+ cls, chart_request_params: DualCategoryChartRequestParams
+ ) -> Response:
+ """Handles the process of generating a chart, encoding it, and returning it in the response.
+
+ Args:
+ chart_request_params: A `DualCategoryChartRequestParams` model containing all the necessary parameters to generate the chart.
+
+ Returns:
+ A `Response` containing the encoded chart and related metadata, or an error message if chart generation fails due to invalid parameters or missing data.
+ """
+ try:
+ chart_result: ChartResult = generate_encoded_chart(
+ chart_request_params=chart_request_params,
+ interface=DualCategoryChartsInterface,
+ )
+
+ except (InvalidPlotParametersError, DataNotFoundForAnyPlotError) as error:
+ return Response(
+ status=HTTPStatus.BAD_REQUEST, data={"error_message": str(error)}
+ )
+
+ serializer = EncodedChartResponseSerializer(data=chart_result.output())
+ serializer.is_valid(raise_exception=True)
- chart_request_params = request_serializer.to_models(request=request)
+ return Response(data=serializer.data)
- logger.info("This endpoint is not yet complete")
+ @cache_response()
+ def _process_post_request_as_encoded_svg(
+ self, request, *args, **kwargs
+ ) -> Response:
+ """Handles the inbound request as `preview=false` in this case we use the cache
- temporary_dict_representation = chart_request_params.model_dump()
- temporary_dict_representation.pop("request")
+ Returns:
+ `FileResponse` containing the chart image
+ """
+ serializer = DualCategoryChartSerializer(data=request.data)
+ serializer.is_valid(raise_exception=True)
- return Response(data=temporary_dict_representation)
+ chart_request_params: DualCategoryChartRequestParams = serializer.to_models(
+ request=request
+ )
+ return self._handle_encoded_svg(chart_request_params=chart_request_params)
diff --git a/metrics/api/views/charts/single_category_charts.py b/metrics/api/views/charts/single_category_charts.py
index 2767c3982b..41b9e86d42 100644
--- a/metrics/api/views/charts/single_category_charts.py
+++ b/metrics/api/views/charts/single_category_charts.py
@@ -14,9 +14,9 @@
from metrics.api.serializers import ChartsSerializer
from metrics.api.serializers.charts import (
ChartsResponseSerializer,
- EncodedChartResponseSerializer,
EncodedChartsRequestSerializer,
)
+from metrics.api.serializers.charts.common import EncodedChartResponseSerializer
from metrics.domain.models import ChartRequestParams
from metrics.interfaces.charts.single_category_charts import access
from metrics.interfaces.plots.access import (
diff --git a/metrics/api/views/charts/subplot_charts/api_view.py b/metrics/api/views/charts/subplot_charts/api_view.py
index 3fe72ad50a..19ce77eb0c 100644
--- a/metrics/api/views/charts/subplot_charts/api_view.py
+++ b/metrics/api/views/charts/subplot_charts/api_view.py
@@ -8,11 +8,11 @@
from rest_framework.views import APIView
from caching.private_api.decorators import cache_response
-from metrics.api.serializers.charts import (
+from metrics.api.serializers.charts.common import (
+ ChartPreviewQueryParamsSerializer,
EncodedChartResponseSerializer,
)
from metrics.api.serializers.charts.subplot_charts import (
- ChartPreviewQueryParamsSerializer,
SubplotChartRequestSerializer,
)
from metrics.api.views.charts.subplot_charts.request_example import (
@@ -66,8 +66,7 @@ def _process_post_request_as_preview(
actually put into the cache
Returns:
- `Response` containing the JSON data for the
- chart and all of its associated deliverables
+ `Response` containing the rendered chart as an image file
"""
serializer = SubplotChartRequestSerializer(data=request.data)
diff --git a/metrics/api/views/tables/__init__.py b/metrics/api/views/tables/__init__.py
index 5740425784..d0cd77e7a6 100644
--- a/metrics/api/views/tables/__init__.py
+++ b/metrics/api/views/tables/__init__.py
@@ -1,2 +1,3 @@
+from .dual_category_tables import DualCategoryTablesView
from .single_category_tables import TablesView
from .subplot_tables.api_view import TablesSubplotView
diff --git a/metrics/api/views/tables/dual_category_tables.py b/metrics/api/views/tables/dual_category_tables.py
new file mode 100644
index 0000000000..cf1596c477
--- /dev/null
+++ b/metrics/api/views/tables/dual_category_tables.py
@@ -0,0 +1,112 @@
+from http import HTTPStatus
+
+from drf_spectacular.utils import extend_schema
+from rest_framework.response import Response
+from rest_framework.views import APIView
+
+from caching.private_api.decorators import cache_response
+from metrics.api.decorators.auth import require_authorisation
+from metrics.api.serializers.dual_category_tables import (
+ DualCategoryTablesResponseSerializer,
+ DualCategoryTablesSerializer,
+)
+from metrics.api.views.tables.common import TABLES_API_TAG
+from metrics.interfaces.plots.access import (
+ DataNotFoundForAnyPlotError,
+ InvalidPlotParametersError,
+)
+from metrics.interfaces.tables import access
+
+
+class DualCategoryTablesView(APIView):
+ permission_classes = []
+
+ @classmethod
+ @extend_schema(
+ request=DualCategoryTablesSerializer,
+ responses={HTTPStatus.OK.value: DualCategoryTablesResponseSerializer},
+ tags=[TABLES_API_TAG],
+ )
+ @require_authorisation
+ def post(cls, request, *args, **kwargs):
+ """This endpoint can be used to generate chart data in tabular format.
+
+ Multiple plots can be added as an array of objects from the request body.
+
+ This payload takes the following set of parameters for each plot:
+
+ | Parameter name | Description | Example | Mandatory |
+ |------------------|----------------------------------------------------------------------------|--------------------------|-----------|
+ | `topic` | The name of the disease/threat | COVID-19 | Yes |
+ | `metric` | The name of the metric being queried for | COVID-19_deaths_ONSByDay | Yes |
+ | `stratum` | The smallest subgroup a metric can be broken down into | default | No |
+ | `geography` | The geography constraints to apply any data filtering to | London | No |
+ | `geography_type` | The type of geographical categorisation to apply any data filtering to | Nation | No |
+ | `age` | The patient age band | 0_4 | No |
+ | `date_from` | The date from which to start the data slice from. In the format YYYY-MM-DD | 2023-01-01 | No |
+ | `date_to` | The date to end the data slice to. In the format YYYY-MM-DD | 2023-05-01 | No |
+ | `label` | The label to assign on the legend for this individual plot | Daily Covid deaths | No |
+
+ ---
+
+ # Main errors
+ There are certain combination of `topic / metric` which do not make sense.
+ This is primarily because a set of `metric` values are not available for every `topic`.
+ As well as this, certain `metric` names reference data of a certain profile.
+
+ ---
+
+ ## Selected metric not available for topic
+
+ In these cases, this endpoint will return an HTTP 400 BAD REQUEST.
+ For example, if a metric like `COVID-19_deaths_ONSByDay` (which is only used for `COVID-19`)
+ is being asked for with a topic of `Influenza`.
+
+ Then an HTTP 400 BAD REQUEST is returned with the following error message:
+ `Influenza` does not have a corresponding metric of `COVID-19`
+
+ ---
+
+ ## Ordering of data
+
+ Note that for tables which are `date` based i.e. where the `x_axis` field is set to `date`.
+
+ Then the data will be returned in descending order from newest -> oldest:
+
+ ```
+ | 2023-09-29 | 1 |
+ | 2023-09-28 | 2 |
+ | 2023-09-27 | 3 |
+ ```
+
+ For tables which **not** `date` based i.e. where the `x_axis` field is set to something like `age`.
+
+ Then the data will be returned in ascending order:
+
+ ```
+ | 00 - 04 | 1 |
+ | 05 - 09 | 2 |
+ | 10 - 14 | 3 |
+ ```
+
+ """
+ print(f"AIDAN: serialising {request.data}")
+ request_serializer = DualCategoryTablesSerializer(data=request.data)
+ print(f"AIDAN: validating serialiser")
+ request_serializer.is_valid(raise_exception=True)
+
+ print(f"AIDAN: converting to models")
+ request_params = request_serializer.to_models(request=request)
+
+ try:
+ print(f"AIDAN: generating table")
+ tabular_data: list[dict[str, str]] = access.generate_table_for_full_plots(
+ request_params=request_params
+ )
+ except (InvalidPlotParametersError, DataNotFoundForAnyPlotError) as error:
+ print(f"AIDAN: error due to {error}")
+ return Response(
+ status=HTTPStatus.BAD_REQUEST, data={"error_message": str(error)}
+ )
+
+ return Response(tabular_data)
diff --git a/metrics/domain/charts/chart_settings/dual_category.py b/metrics/domain/charts/chart_settings/dual_category.py
new file mode 100644
index 0000000000..bd2978cd3e
--- /dev/null
+++ b/metrics/domain/charts/chart_settings/dual_category.py
@@ -0,0 +1,50 @@
+from metrics.domain.charts.chart_settings.single_category import (
+ SingleCategoryChartSettings,
+)
+from metrics.domain.models.plots import ChartGenerationPayload, PlotGenerationData
+
+
+class DualCategoryChartSettings(SingleCategoryChartSettings):
+ def __init__(self, *, chart_generation_payload: ChartGenerationPayload):
+ super().__init__(chart_generation_payload=chart_generation_payload)
+ self.plots_data: list[PlotGenerationData] = chart_generation_payload.plots
+
+ def get_stacked_bar_chart_config(self) -> dict:
+ """
+ Builds the configuration for the stacked bar chart.
+
+ Returns:
+ The configuration for the stacked bar chart.
+ """
+ chart_config = self._get_base_chart_config()
+
+ chart_config["barmode"] = "stack"
+
+ return {**chart_config, **self._get_legend_config()}
+
+ def _get_legend_config(self) -> dict:
+ """
+ Builds the configuration for the legend.
+
+ Returns:
+ The configuration for the legend.
+ """
+ legend_config = {
+ "font": self._get_tick_font_config(),
+ "orientation": "h",
+ "y": 1.0,
+ "x": 0.5,
+ "xanchor": "center",
+ "yanchor": "bottom",
+ "entrywidth": 80,
+ }
+
+ if legend_title := self._chart_generation_payload.legend_title:
+ legend_config["title"] = {
+ "text": f"{legend_title}",
+ "side": "top",
+ }
+
+ return {
+ "legend": legend_config,
+ }
diff --git a/metrics/domain/charts/reporting_delay_period.py b/metrics/domain/charts/reporting_delay_period.py
index 8c7f16c812..00edd127db 100644
--- a/metrics/domain/charts/reporting_delay_period.py
+++ b/metrics/domain/charts/reporting_delay_period.py
@@ -1,6 +1,6 @@
import contextlib
import logging
-from datetime import datetime
+from datetime import date, datetime
import plotly
from plotly.graph_objs import Scatter
@@ -55,9 +55,13 @@ def _get_last_x_value_at_end_of_reporting_delay_period(
def get_x_value_at_start_of_reporting_delay_period(
chart_plots_data: list[PlotGenerationData],
-) -> str:
- index: int = chart_plots_data[0].start_of_reporting_delay_period_index
- return chart_plots_data[0].x_axis_values[index]
+) -> date:
+ values = [
+ plot.x_axis_values[plot.start_of_reporting_delay_period_index]
+ for plot in chart_plots_data
+ ]
+
+ return min(values)
def add_reporting_delay_period(
diff --git a/metrics/domain/charts/stacked_bar/__init__.py b/metrics/domain/charts/stacked_bar/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/metrics/domain/charts/stacked_bar/generation.py b/metrics/domain/charts/stacked_bar/generation.py
new file mode 100644
index 0000000000..94a4874b2a
--- /dev/null
+++ b/metrics/domain/charts/stacked_bar/generation.py
@@ -0,0 +1,57 @@
+from collections import defaultdict
+
+import plotly.graph_objects as go
+
+from metrics.domain.charts.chart_settings.dual_category import DualCategoryChartSettings
+from metrics.domain.charts.reporting_delay_period import add_reporting_delay_period
+from metrics.domain.models.plots import ChartGenerationPayload
+
+
+def generate_stacked_bar(
+ *,
+ chart_generation_payload: ChartGenerationPayload,
+):
+ """
+ Generates a stacked bar chart.
+
+ Args:
+ chart_generation_payload: The payload for the chart generation.
+
+ Returns:
+ The figure for the stacked bar chart.
+ """
+ figure = go.Figure()
+ settings = DualCategoryChartSettings(
+ chart_generation_payload=chart_generation_payload,
+ )
+
+ grouped: dict[str, dict] = defaultdict(
+ lambda: {"x_axis_values": [], "y_axis_values": []}
+ )
+
+ secondary_category = chart_generation_payload.secondary_category
+ for plot in chart_generation_payload.plots:
+ group = getattr(plot.parameters, secondary_category)
+ grouped[group]["x_axis_values"].extend(plot.x_axis_values)
+ grouped[group]["y_axis_values"].extend(plot.y_axis_values)
+ grouped[group]["label"] = plot.parameters.label
+
+ for label, data in grouped.items():
+ figure.add_trace(
+ go.Bar(
+ x=data["x_axis_values"],
+ y=data["y_axis_values"],
+ name=data["label"] or label,
+ )
+ )
+
+ if settings.is_date_type_x_axis:
+ add_reporting_delay_period(
+ chart_plots_data=chart_generation_payload.plots,
+ figure=figure,
+ )
+
+ layout_args = settings.get_stacked_bar_chart_config()
+ figure.update_layout(**layout_args)
+
+ return figure
diff --git a/metrics/domain/models/charts/__init__.py b/metrics/domain/models/charts/__init__.py
index cd9c2e2b85..e69de29bb2 100644
--- a/metrics/domain/models/charts/__init__.py
+++ b/metrics/domain/models/charts/__init__.py
@@ -1 +0,0 @@
-from .dual_category_charts import DualCategoryChartRequestParams
diff --git a/metrics/domain/models/charts/dual_category_charts.py b/metrics/domain/models/charts/dual_category_charts.py
index 09e6f73ccb..6fb6191288 100644
--- a/metrics/domain/models/charts/dual_category_charts.py
+++ b/metrics/domain/models/charts/dual_category_charts.py
@@ -1,12 +1,10 @@
from pydantic import BaseModel
from metrics.domain.models.charts.common import BaseChartRequestParams
-from metrics.domain.models.charts.segments import SegmentParameters
+from metrics.domain.models.plots import PlotParameters
class StaticFields(BaseModel):
- theme: str
- sub_theme: str
topic: str
metric: str
geography: str
@@ -23,4 +21,5 @@ class DualCategoryChartRequestParams(BaseChartRequestParams):
secondary_category: str
primary_field_values: list[str]
static_fields: StaticFields
- segments: list[SegmentParameters]
+ plots: list[PlotParameters]
+ legend_title: str | None = ""
diff --git a/metrics/domain/models/plots.py b/metrics/domain/models/plots.py
index 07d614160b..82077d8a3a 100644
--- a/metrics/domain/models/plots.py
+++ b/metrics/domain/models/plots.py
@@ -231,6 +231,7 @@ class ChartGenerationPayload(BaseModel):
legend_title: str | None = ""
confidence_intervals: bool | None = False
confidence_colour: str | None = ""
+ secondary_category: str | None = ""
class CompletePlotData(BaseModel):
diff --git a/metrics/domain/models/plots_text.py b/metrics/domain/models/plots_text.py
index 94b9f8fb14..becc5a926c 100644
--- a/metrics/domain/models/plots_text.py
+++ b/metrics/domain/models/plots_text.py
@@ -20,6 +20,7 @@
READABLE_DATE_FORMAT = "%d %B %Y"
TREND_CHART_TYPE = "line_single_simplified"
+DUAL_CATEGORY_CHART_TYPE = "stacked_bar"
class PlotsText:
@@ -343,6 +344,8 @@ def _stringify_chart_type(cls, *, plot_parameters: PlotParameters) -> str:
match chart_type:
case ChartTypes.bar.value:
return "bar"
+ case ChartTypes.stacked_bar.value:
+ return "stacked bar"
case ChartTypes.line_multi_coloured.value:
return "line"
case ChartTypes.line_single_simplified.value:
@@ -377,6 +380,9 @@ def _describe_plot_type(self, *, plot_parameters: PlotParameters) -> str:
return plot_description
+ if self._plot_is_dual_category_chart(plot_parameters=plot_parameters):
+ return f"This is a {line_colour} {plot_type} chart. "
+
return f"This is a {line_colour} {line_type} {plot_type} plot. "
@classmethod
@@ -430,6 +436,10 @@ def _plot_is_date_based_timeseries_data(
def _plot_is_simplified_chart(cls, *, plot_parameters: PlotParameters) -> bool:
return plot_parameters.chart_type == TREND_CHART_TYPE
+ @classmethod
+ def _plot_is_dual_category_chart(cls, *, plot_parameters: PlotParameters) -> bool:
+ return plot_parameters.chart_type == DUAL_CATEGORY_CHART_TYPE
+
@classmethod
def _plot_is_headline_data(cls, *, plot_data: PlotGenerationData) -> bool:
return plot_data.parameters.is_headline_data
diff --git a/metrics/interfaces/charts/common/exceptions.py b/metrics/interfaces/charts/common/exceptions.py
new file mode 100644
index 0000000000..35dc57b57e
--- /dev/null
+++ b/metrics/interfaces/charts/common/exceptions.py
@@ -0,0 +1,4 @@
+class InvalidFileFormatError(Exception):
+ def __init__(self):
+ message = "Invalid file format, must be `svg`"
+ super().__init__(message)
diff --git a/metrics/interfaces/charts/common/generation.py b/metrics/interfaces/charts/common/generation.py
index 9579ab3b5d..f93faea360 100644
--- a/metrics/interfaces/charts/common/generation.py
+++ b/metrics/interfaces/charts/common/generation.py
@@ -6,6 +6,8 @@
import plotly
from scour import scour
+from metrics.interfaces.charts.common.chart_output import ChartOutput
+
@dataclass
class ChartResult:
diff --git a/metrics/interfaces/charts/dual_category_charts/__init__.py b/metrics/interfaces/charts/dual_category_charts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/metrics/interfaces/charts/dual_category_charts/access.py b/metrics/interfaces/charts/dual_category_charts/access.py
new file mode 100644
index 0000000000..9523d38570
--- /dev/null
+++ b/metrics/interfaces/charts/dual_category_charts/access.py
@@ -0,0 +1,198 @@
+from datetime import datetime
+
+import plotly.graph_objects as go
+
+from metrics.data.models.core_models import CoreHeadline, CoreTimeSeries
+from metrics.domain.charts.stacked_bar.generation import generate_stacked_bar
+from metrics.domain.common.utils import (
+ extract_metric_group_from_metric,
+)
+from metrics.domain.models.charts.dual_category_charts import (
+ DualCategoryChartRequestParams,
+)
+from metrics.domain.models.plots import (
+ ChartGenerationPayload,
+ PlotGenerationData,
+)
+from metrics.domain.models.plots_text import PlotsText
+from metrics.interfaces.charts.common.chart_output import ChartOutput
+from metrics.interfaces.plots.access import PlotsInterface
+from metrics.utils.type_hints import CORE_MODEL_MANAGER_TYPE
+
+DEFAULT_CORE_TIME_SERIES_MANAGER = CoreTimeSeries.objects
+DEFAULT_CORE_HEADLINE_MANAGER = CoreHeadline.objects
+
+
+class DualCategoryChartsInterface:
+ def __init__(
+ self,
+ *,
+ chart_request_params: DualCategoryChartRequestParams,
+ core_model_manager: CORE_MODEL_MANAGER_TYPE | None = None,
+ plots_interface: PlotsInterface | None = None,
+ ):
+ self.chart_request_params = chart_request_params
+ self.chart_type = self.chart_request_params.chart_type
+ self.metric_group = extract_metric_group_from_metric(
+ metric=self.chart_request_params.static_fields.metric,
+ )
+ self.core_model_manager = core_model_manager or self._set_core_model_manager()
+ self.plots_interface = plots_interface or PlotsInterface(
+ chart_request_params=self.chart_request_params,
+ core_model_manager=self.core_model_manager,
+ )
+
+ self.last_updated: str = ""
+
+ @property
+ def is_headline_data(self) -> bool:
+ return self.chart_request_params.plots[0].is_headline_data
+
+ @staticmethod
+ def _build_chart_figure(
+ chart_generation_payload: ChartGenerationPayload,
+ ) -> go.Figure:
+ """Build a Plotly chart `Figure` object for a `DualCategory` chart.
+
+ Args:
+ chart_generation_payload: An enriched `ChartGenerationPayload` model
+ which holds all the parameters like colour and plot labels
+ along with the corresponding x and y values
+ which are needed to be able to generate the chart in full.
+
+ Returns:
+ A plotly `Figure` object for the created dual-category chart.
+ """
+ return generate_stacked_bar(
+ chart_generation_payload=chart_generation_payload,
+ )
+
+ def _build_chart_plots_data(self) -> list[PlotGenerationData]:
+ """Creates a list of `PlotData` models which hold the params and corresponding data for the requested plots
+
+ Notes:
+ The corresponding timeseries data is used to enrich a
+ pydantic model which also holds the corresponding params.
+ These models can then be passed into the domain libraries.
+
+ If no data is returned for a particular plot,
+ that chart plot is skipped and
+ an enriched model will not be provided.
+
+ Returns:
+ A list of `PlotData` models for each of the requested chart plots.
+
+ Raises:
+ `DataNotFoundForAnyPlotError`: If no plots
+ returned any data from the underlying queries
+
+ """
+ plots_data: list[PlotGenerationData] = self.plots_interface.build_plots_data()
+ self._set_latest_date_from_plots_data(plots_data=plots_data)
+
+ return plots_data
+
+ def _set_latest_date_from_plots_data(
+ self, *, plots_data: list[PlotGenerationData]
+ ) -> None:
+ """Extracts the latest date from the list of given `plots_data`
+
+ Notes:
+ This extracted value is set on the `_latest_date`
+ instance attribute on this object
+
+ Args:
+ plots_data: List of `PlotData` models,
+ where each model represents a requested plot.
+ Note that each `PlotData` model is enriched
+ with the according x and y values along with
+ requests parameters like colour and plot label.
+
+ Returns:
+ None
+
+ """
+ try:
+ latest_date: datetime.date = max(plot.latest_date for plot in plots_data)
+ except (ValueError, TypeError):
+ return
+
+ self.last_updated: str = datetime.strftime(latest_date, "%Y-%m-%d")
+
+ def _build_chart_generation_payload(self) -> ChartGenerationPayload:
+ plots_data: list[PlotGenerationData] = self._build_chart_plots_data()
+
+ return ChartGenerationPayload(
+ plots=plots_data,
+ x_axis_title=self.chart_request_params.x_axis_title,
+ y_axis_title=self.chart_request_params.y_axis_title,
+ chart_height=self.chart_request_params.chart_height,
+ chart_width=self.chart_request_params.chart_width,
+ y_axis_minimum_value=self.chart_request_params.y_axis_minimum_value,
+ y_axis_maximum_value=self.chart_request_params.y_axis_maximum_value,
+ legend_title=self.chart_request_params.legend_title,
+ confidence_intervals=self.chart_request_params.confidence_intervals,
+ confidence_colour=self.chart_request_params.confidence_colour,
+ secondary_category=self.chart_request_params.secondary_category,
+ )
+
+ def _set_core_model_manager(self) -> bool:
+ """Returns `core_model_manger` based on the `metric_group`.
+
+ Notes:
+ The charts interface can be used to generate charts for
+ either `CoreTimeSeries` or `CoreHeadline` data.
+ this function returns the Django manager to match the
+ current `metric_group` or defaults to `CoreTimeSeries`
+ manager.
+
+ Returns:
+ Manager: either `CoreTimeSeries` or `CoreHeadline`
+ """
+ if self.is_headline_data:
+ return DEFAULT_CORE_HEADLINE_MANAGER
+
+ return DEFAULT_CORE_TIME_SERIES_MANAGER
+
+ @classmethod
+ def build_chart_description(cls, *, plots_data: list[PlotGenerationData]) -> str:
+ """Creates a description to summarize the contents of the chart.
+
+ Args:
+ plots_data: List of `PlotData` models,
+ where each model represents a requested plot.
+ Note that each `PlotData` model is enriched with data
+ with the according x and y values along with
+ requests parameters like colour and plot label.
+
+ Returns:
+ Single string describing the entire chart
+
+ """
+ plots_text = PlotsText(plots_data=plots_data)
+ return plots_text.construct_text()
+
+ def generate_chart_output(self) -> ChartOutput:
+ """Generates a `plotly` chart figure and a corresponding description
+
+ Returns:
+ An enriched `ChartOutput` model containing:
+ figure - a plotly `Figure` object for the created chart
+ description - a string representation
+ which summarises the produced chart
+ """
+ chart_generation_payload: ChartGenerationPayload = (
+ self._build_chart_generation_payload()
+ )
+ description = self.build_chart_description(
+ plots_data=chart_generation_payload.plots
+ )
+ figure = self._build_chart_figure(
+ chart_generation_payload=chart_generation_payload
+ )
+
+ return ChartOutput(
+ figure=figure,
+ description=description,
+ is_headline=self.is_headline_data,
+ )
diff --git a/metrics/interfaces/charts/single_category_charts/access.py b/metrics/interfaces/charts/single_category_charts/access.py
index 471083c159..2ede45ce59 100644
--- a/metrics/interfaces/charts/single_category_charts/access.py
+++ b/metrics/interfaces/charts/single_category_charts/access.py
@@ -22,6 +22,7 @@
)
from metrics.domain.models.plots_text import PlotsText
from metrics.interfaces.charts.common.chart_output import ChartOutput
+from metrics.interfaces.charts.common.exceptions import InvalidFileFormatError
from metrics.interfaces.plots.access import PlotsInterface
from metrics.utils.type_hints import CORE_MODEL_MANAGER_TYPE
@@ -29,12 +30,6 @@
DEFAULT_CORE_HEADLINE_MANAGER = CoreHeadline.objects
-class InvalidFileFormatError(Exception):
- def __init__(self):
- message = "Invalid file format, must be `svg`"
- super().__init__(message)
-
-
class InvalidChartTypeCombinationError(Exception):
def __init__(self, invalid_chart_types: list[str]):
message = f"There has been an invalid combination of plots selected, Please review your plot data. {', '.join(invalid_chart_types)}"
diff --git a/tests/integration/metrics/api/views/charts/test_dual_category_charts.py b/tests/integration/metrics/api/views/charts/test_dual_category_charts.py
index 0ec1ad0733..070f8a9c6c 100644
--- a/tests/integration/metrics/api/views/charts/test_dual_category_charts.py
+++ b/tests/integration/metrics/api/views/charts/test_dual_category_charts.py
@@ -16,41 +16,54 @@ class TestChartsView:
def path(self) -> str:
return "/api/charts/dual-category/v1"
+ @pytest.mark.parametrize(
+ "query_params,response_header",
+ [
+ ({}, "application/json"),
+ ({"preview": False}, "application/json"),
+ ({"preview": True}, "image/svg"),
+ ],
+ )
@pytest.mark.django_db
- def test_returns_correct_response_for_age_based_chart(
+ def test_returns_correct_response_for_timeseries_chart(
self,
+ query_params: dict[str, bool],
+ response_header: str,
admin_user: User,
):
"""
- Given a valid payload to create a chart
+ Given a valid payload to create a timeseries dual-category chart
When the `POST /api/charts/dual-category/v1/` endpoint is hit
+ with the given preview query params
Then an HTTP 200 OK response is returned
"""
# Given
client = APIClient()
client.force_authenticate(user=admin_user)
valid_payload = EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD
+ static_fields = valid_payload["static_fields"]
+
for age in ("00-04", "05-11"):
- for sex in ("m", "f"):
- CoreTimeSeriesFactory.create_record(
- topic_name=valid_payload["static_fields"]["topic"],
- metric_name=valid_payload["static_fields"]["metric"],
- stratum_name=valid_payload["static_fields"]["stratum"],
- age_name=age,
- geography_name=valid_payload["static_fields"]["geography"],
- geography_type_name=valid_payload["static_fields"][
- "geography_type"
- ],
- sex=sex,
- )
+ CoreTimeSeriesFactory.create_record(
+ topic_name=static_fields["topic"],
+ metric_name=static_fields["metric"],
+ stratum_name=static_fields["stratum"],
+ age_name=age,
+ geography_name=static_fields["geography"],
+ geography_type_name=static_fields["geography_type"],
+ sex=static_fields["sex"],
+ date=static_fields["date_from"],
+ metric_value=100,
+ )
# When
response: Response = client.post(
path=self.path,
+ query_params=query_params,
data=valid_payload,
format="json",
)
# Then
assert response.status_code == HTTPStatus.OK
- assert response.headers["Content-Type"] == "application/json"
+ assert response.headers["Content-Type"] == response_header
diff --git a/tests/integration/metrics/api/views/tables/test_dual_category.py b/tests/integration/metrics/api/views/tables/test_dual_category.py
new file mode 100644
index 0000000000..d4895abeeb
--- /dev/null
+++ b/tests/integration/metrics/api/views/tables/test_dual_category.py
@@ -0,0 +1,152 @@
+import datetime
+from http import HTTPStatus
+
+import pytest
+from rest_framework.response import Response
+from rest_framework.test import APIClient
+
+from metrics.data.models.core_models import CoreTimeSeries, CoreHeadline
+
+
+@pytest.mark.django_db
+class TestTablesView:
+ @property
+ def path(self) -> str:
+ return "/api/tables/dual-category/v1/"
+
+ def test_timeseries_plot(
+ self,
+ core_timeseries_example: list[CoreTimeSeries],
+ ):
+ """
+ Given a valid payload to create a chart
+ When the `POST /api/tables/v4/` endpoint is hit with a single plot
+ Then the response is of the correct format
+ """
+ # Given
+ client = APIClient()
+ core_timeseries: CoreTimeSeries = core_timeseries_example[0]
+ topic: str = core_timeseries.metric.metric_group.topic.name
+ metric: str = core_timeseries.metric.name
+
+ valid_payload = {
+ "x_axis": "age",
+ "label": "Dual category by age / sex",
+ "x_axis_title": "",
+ "primary_field_values": [
+ "00-04",
+ ],
+ "y_axis": "metric",
+ "y_axis_title": "",
+ "static_fields": {
+ "date_from": "2000-01-01",
+ "date_to": datetime.date.today(),
+ "metric": metric,
+ "topic": topic,
+ },
+ "secondary_category": "sex",
+ "segments": [
+ {
+ "secondary_field_value": "f",
+ "label": "0 to 4 years female",
+ },
+ {
+ "secondary_field_value": "m",
+ "label": "0 to 4 years male",
+ },
+ ],
+ }
+
+ # When
+ response: Response = client.post(
+ path=self.path, data=valid_payload, format="json"
+ )
+
+ # Then
+ expected_response = [
+ {
+ "reference": "all",
+ "values": [
+ {
+ "label": "Plot1",
+ "value": "123.0000",
+ "in_reporting_delay_period": False,
+ },
+ {
+ "label": "Plot2",
+ "value": "123.0000",
+ "in_reporting_delay_period": False,
+ },
+ ],
+ }
+ ]
+
+ assert expected_response == response.data
+
+ def test_headline_plot(
+ self,
+ core_headline_example: CoreHeadline,
+ ):
+ """
+ Given a valid payload to create a chart
+ When the `POST /api/tables/v4/` endpoint is hit with a single plot
+ Then the response is of the correct format
+ """
+ # Given
+ client = APIClient()
+ topic: str = core_headline_example.metric.metric_group.topic.name
+ metric: str = core_headline_example.metric.name
+
+ valid_payload = {
+ "x_axis": "age",
+ "label": "Dual category by age / sex",
+ "x_axis_title": "",
+ "primary_field_values": [
+ "00-04",
+ ],
+ "y_axis": "metric",
+ "y_axis_title": "",
+ "static_fields": {
+ "date_from": "2000-01-01",
+ "date_to": datetime.date.today(),
+ "metric": metric,
+ "topic": topic,
+ },
+ "secondary_category": "sex",
+ "segments": [
+ {
+ "secondary_field_value": "f",
+ "label": "0 to 4 years female",
+ },
+ {
+ "secondary_field_value": "m",
+ "label": "0 to 4 years male",
+ },
+ ],
+ }
+
+ # When
+ response: Response = client.post(
+ path=self.path, data=valid_payload, format="json"
+ )
+
+ # Then
+ expected_response = [
+ {
+ "reference": "default",
+ "values": [
+ {
+ "label": "Amount",
+ "value": "123.0000",
+ "in_reporting_delay_period": False,
+ },
+ {
+ "label": "Amount",
+ "value": "123.0000",
+ "in_reporting_delay_period": False,
+ },
+ ],
+ }
+ ]
+
+ assert expected_response == response.data
diff --git a/tests/unit/metrics/api/serializers/charts/test_dual_category_charts.py b/tests/unit/metrics/api/serializers/charts/test_dual_category_charts.py
index 6220f1da4e..a55a88cee5 100644
--- a/tests/unit/metrics/api/serializers/charts/test_dual_category_charts.py
+++ b/tests/unit/metrics/api/serializers/charts/test_dual_category_charts.py
@@ -1,10 +1,27 @@
+import copy
+
import pytest
from rest_framework.exceptions import ValidationError
from metrics.api.serializers.charts.dual_category_charts import (
DualCategoryChartSegmentSerializer,
+ DualCategoryChartSerializer,
+)
+from metrics.api.serializers.plots import PlotSerializer
+from metrics.api.views.charts.dual_category_charts import (
+ EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD,
)
from metrics.domain.charts import colour_scheme
+from metrics.domain.common.utils import ChartAxisFields
+from metrics.domain.models.charts.dual_category_charts import (
+ DualCategoryChartRequestParams,
+)
+from tests.fakes.factories.metrics.metric_factory import FakeMetricFactory
+from tests.fakes.managers.metric_manager import FakeMetricManager
+from tests.fakes.managers.topic_manager import FakeTopicManager
+
+HEADLINE_METRIC = "COVID-19_headline_vaccines_spring24Uptake"
+TIMESERIES_METRIC = "COVID-19_cases_rateRollingMean"
class TestDualCategoryChartSegmentSerializer:
@@ -185,3 +202,213 @@ def test_missing_label_is_still_deemed_valid(self):
# Then
assert is_serializer_valid
+
+
+class TestDualCategoryChartSerializer:
+ # Success cases
+ @pytest.mark.parametrize(
+ "metric,x_axis,primary_field_values",
+ [
+ (HEADLINE_METRIC, "age", ["m", "f"]),
+ (TIMESERIES_METRIC, "date", None),
+ (TIMESERIES_METRIC, None, None),
+ ],
+ )
+ def test_validation_with_valid_metric_and_primary_field_values_combination(
+ self, metric: str, x_axis: str | None, primary_field_values: list[str] | None
+ ):
+ """
+ Given a payload containing a valid metric and primary_field_values combination
+ passed to a `DualCategoryChartSerializer` object
+ When `validate()` is called from the serializer
+ Then no ValidationError is raised and the data is returned unchanged
+ """
+ # Given
+ valid_payload = EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD.copy()
+ valid_payload["static_fields"] = valid_payload["static_fields"].copy()
+ valid_payload["static_fields"]["metric"] = metric
+ valid_payload["x_axis"] = x_axis
+ valid_payload["primary_field_values"] = primary_field_values
+
+ serializer = DualCategoryChartSerializer()
+
+ # When
+ is_valid = serializer.validate(attrs=valid_payload)
+
+ # Then
+ assert is_valid == valid_payload
+
+ # Failure cases
+ def test_validation_with_primary_field_values_for_timeseries_data(self):
+ """
+ Given a payload containing primary_field_values for a timeseries metric
+ passed to a `DualCategoryChartSerializer` object
+ When `validate()` is called from the serializer
+ Then a `ValidationError` is raised as primary_field_values should not be provided
+ """
+ # Given
+ invalid_payload = EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD.copy()
+ invalid_payload["primary_field_values"] = ["m", "f"]
+
+ serializer = DualCategoryChartSerializer()
+
+ # When / Then
+ with pytest.raises(ValidationError) as exc_info:
+ serializer.validate(attrs=invalid_payload)
+
+ assert exc_info.value.detail["primary_field_values"] == (
+ "This field should not be provided for timeseries data."
+ )
+
+ def test_validation_with_non_date_x_axis_for_timeseries_data(self):
+ """
+ Given a payload containing a non-date x_axis for a timeseries metric
+ passed to a `DualCategoryChartSerializer` object
+ When `validate()` is called from the serializer
+ Then a `ValidationError` is raised as x_axis must be date for timeseries data
+ """
+ # Given
+ invalid_payload = EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD.copy()
+ invalid_payload["x_axis"] = "age"
+
+ serializer = DualCategoryChartSerializer()
+
+ # When / Then
+ with pytest.raises(ValidationError) as exc_info:
+ serializer.validate(attrs=invalid_payload)
+
+ assert exc_info.value.detail["x_axis"] == (
+ "This field should be set to 'date' for timeseries data."
+ )
+
+ def test_validation_with_primary_field_values_missing_for_headline_data(self):
+ """
+ Given a payload containing a headline metric but no primary_field_values
+ passed to a `DualCategoryChartSerializer` object
+ When `validate()` is called from the serializer
+ Then a `ValidationError` is raised as primary_field_values are required
+ """
+ # Given
+ invalid_payload = EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD.copy()
+ invalid_payload["static_fields"] = invalid_payload["static_fields"].copy()
+ invalid_payload["static_fields"]["metric"] = HEADLINE_METRIC
+ invalid_payload["x_axis"] = "age"
+ invalid_payload.pop("primary_field_values", None)
+
+ serializer = DualCategoryChartSerializer()
+
+ # When / Then
+ with pytest.raises(ValidationError) as exc_info:
+ serializer.validate(attrs=invalid_payload)
+
+ assert exc_info.value.detail["primary_field_values"] == (
+ "This field is required for headline data."
+ )
+
+ def test_to_models_builds_timeseries_plots_when_x_axis_is_date(
+ self, plot_serializer_payload_and_model_managers
+ ):
+ """
+ Given a valid payload with x_axis set to date (timeseries)
+ When `to_models()` is called from the serializer
+ Then one plot per segment is returned without primary field values
+ """
+ # Given
+ plot_payload, metric_manager, topic_manager = (
+ plot_serializer_payload_and_model_managers
+ )
+ valid_payload = EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD.copy()
+ valid_payload["x_axis"] = ChartAxisFields.date.name
+ valid_payload.pop("primary_field_values", None)
+ valid_payload["static_fields"]["topic"] = plot_payload["topic"]
+ valid_payload["static_fields"]["metric"] = plot_payload["metric"]
+
+ serializer_context = {
+ "topic_manager": topic_manager,
+ "metric_manager": metric_manager,
+ }
+ serializer = DualCategoryChartSerializer(
+ data=valid_payload,
+ context=serializer_context,
+ )
+ serializer.fields["static_fields"] = PlotSerializer(context=serializer_context)
+ serializer.is_valid(raise_exception=True)
+
+ # When
+ result: DualCategoryChartRequestParams = serializer.to_models(request=None)
+
+ # Then
+ segments = valid_payload["segments"]
+ secondary_category = valid_payload["secondary_category"]
+ static_fields = valid_payload["static_fields"]
+
+ assert result.primary_field_values == []
+ assert len(result.plots) == len(segments)
+
+ for index, segment in enumerate(segments):
+ plot = result.plots[index]
+ assert plot.x_axis == ChartAxisFields.date.name
+ assert plot.y_axis == valid_payload["y_axis"]
+ assert plot.line_colour == segment["colour"]
+ assert plot.chart_type == valid_payload["chart_type"]
+ assert getattr(plot, secondary_category) == segment["secondary_field_value"]
+ assert plot.topic == static_fields["topic"]
+ assert plot.metric == static_fields["metric"]
+ assert plot.date_from == static_fields["date_from"]
+ assert plot.date_to == static_fields["date_to"]
+
+ def test_to_models_builds_headline_plots_with_primary_field_values(self):
+ """
+ Given a valid payload with a headline metric and primary field values
+ When `to_models()` is called from the serializer
+ Then one plot per segment is returned for each primary field value
+ """
+ # Given
+ fake_metric = FakeMetricFactory.build_example_metric(
+ metric_name=HEADLINE_METRIC,
+ metric_group_name="headline",
+ )
+ fake_topic = fake_metric.metric_group.topic
+ metric_manager = FakeMetricManager([fake_metric])
+ topic_manager = FakeTopicManager([fake_topic])
+ valid_payload = EXAMPLE_DUAL_CATEGORY_CHART_REQUEST_PAYLOAD.copy()
+ valid_payload["static_fields"] = valid_payload["static_fields"].copy()
+ valid_payload["static_fields"]["topic"] = fake_topic.name
+ valid_payload["static_fields"]["metric"] = HEADLINE_METRIC
+ valid_payload["x_axis"] = ChartAxisFields.sex.name
+ valid_payload["primary_field_values"] = ["m", "f"]
+
+ serializer_context = {
+ "topic_manager": topic_manager,
+ "metric_manager": metric_manager,
+ }
+ serializer = DualCategoryChartSerializer(
+ data=valid_payload,
+ context=serializer_context,
+ )
+ serializer.fields["static_fields"] = PlotSerializer(context=serializer_context)
+ serializer.is_valid(raise_exception=True)
+
+ # When
+ result: DualCategoryChartRequestParams = serializer.to_models(request=None)
+
+ # Then
+ segments = valid_payload["segments"]
+ secondary_category = valid_payload["secondary_category"]
+ x_axis = valid_payload["x_axis"]
+
+ assert result.primary_field_values == ["m", "f"]
+ assert len(result.plots) == len(segments) * len(
+ valid_payload["primary_field_values"]
+ )
+
+ for index, plot in enumerate(result.plots):
+ segment_index = index % len(segments)
+ primary_index = index // len(segments)
+ segment = segments[segment_index]
+ primary_field_value = valid_payload["primary_field_values"][primary_index]
+
+ assert plot.x_axis == x_axis
+ assert getattr(plot, x_axis) == primary_field_value
+ assert getattr(plot, secondary_category) == segment["secondary_field_value"]
+ assert plot.line_colour == segment["colour"]
diff --git a/tests/unit/metrics/api/serializers/charts/test_single_category_charts.py b/tests/unit/metrics/api/serializers/charts/test_single_category_charts.py
index fbf7aa36e8..0f773316be 100644
--- a/tests/unit/metrics/api/serializers/charts/test_single_category_charts.py
+++ b/tests/unit/metrics/api/serializers/charts/test_single_category_charts.py
@@ -1,10 +1,10 @@
import pytest
from rest_framework.exceptions import ValidationError
+from metrics.api.serializers.charts.common import EncodedChartResponseSerializer
from metrics.api.serializers.charts import (
ChartPlotSerializer,
ChartsSerializer,
- EncodedChartResponseSerializer,
EncodedChartsRequestSerializer,
)
from metrics.domain.charts import colour_scheme
diff --git a/tests/unit/metrics/api/views/charts/test_dual_category_charts.py b/tests/unit/metrics/api/views/charts/test_dual_category_charts.py
index 5ffa586746..578a1b03a9 100644
--- a/tests/unit/metrics/api/views/charts/test_dual_category_charts.py
+++ b/tests/unit/metrics/api/views/charts/test_dual_category_charts.py
@@ -1,8 +1,22 @@
+from http import HTTPStatus
+from unittest import mock
+
+import pytest
from _pytest.monkeypatch import MonkeyPatch
+from django.http import FileResponse
from rest_framework import permissions
+from rest_framework.response import Response
import config
-from metrics.api.views import DualCategoryChartsView
+from metrics.api.views.charts.dual_category_charts import DualCategoryChartsView
+from metrics.domain.models.charts.dual_category_charts import (
+ DualCategoryChartRequestParams,
+)
+from metrics.interfaces.charts.common.generation import ChartResult
+from metrics.interfaces.plots.access import (
+ DataNotFoundForAnyPlotError,
+ InvalidPlotParametersError,
+)
class TestDualCategoryChartsView:
@@ -22,3 +36,125 @@ def test_authentication_is_required(self, monkeypatch: MonkeyPatch):
# Then
assert type(permissions_on_view[0]) is permissions.IsAuthenticated
+
+
+class TestDualCategoryChartsViewErrorHandling:
+ @pytest.mark.parametrize(
+ "exception",
+ [DataNotFoundForAnyPlotError(), InvalidPlotParametersError()],
+ )
+ @mock.patch(
+ "metrics.api.views.charts.dual_category_charts.generate_chart_as_file",
+ )
+ def test_handle_chart_as_file_returns_bad_request_when_chart_generation_fails(
+ self,
+ mocked_generate_chart_as_file: mock.MagicMock,
+ exception: Exception,
+ ):
+ """
+ Given chart generation raises a plot error
+ When `_handle_chart_as_file()` is called
+ Then an HTTP 400 response with the error message is returned
+ """
+ # Given
+ mocked_generate_chart_as_file.side_effect = exception
+ chart_request_params = mock.MagicMock(spec=DualCategoryChartRequestParams)
+
+ # When
+ response = DualCategoryChartsView._handle_chart_as_file(
+ chart_request_params=chart_request_params,
+ )
+
+ # Then
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.data == {"error_message": str(exception)}
+
+ @pytest.mark.parametrize(
+ "exception",
+ [DataNotFoundForAnyPlotError(), InvalidPlotParametersError()],
+ )
+ @mock.patch(
+ "metrics.api.views.charts.dual_category_charts.generate_encoded_chart",
+ )
+ def test_handle_encoded_svg_returns_bad_request_when_chart_generation_fails(
+ self,
+ mocked_generate_encoded_chart: mock.MagicMock,
+ exception: Exception,
+ ):
+ """
+ Given chart generation raises a plot error
+ When `_handle_encoded_svg()` is called
+ Then an HTTP 400 response with the error message is returned
+ """
+ # Given
+ mocked_generate_encoded_chart.side_effect = exception
+ chart_request_params = mock.MagicMock(spec=DualCategoryChartRequestParams)
+
+ # When
+ response = DualCategoryChartsView._handle_encoded_svg(
+ chart_request_params=chart_request_params,
+ )
+
+ # Then
+ assert response.status_code == HTTPStatus.BAD_REQUEST
+ assert response.data == {"error_message": str(exception)}
+
+
+class TestDualCategoryChartsViewSuccessHandling:
+ @mock.patch(
+ "metrics.api.views.charts.dual_category_charts.generate_chart_as_file",
+ )
+ def test_handle_chart_as_file_returns_file_response_when_chart_generation_succeeds(
+ self,
+ mocked_generate_chart_as_file: mock.MagicMock,
+ ):
+ """
+ Given chart generation succeeds
+ When `_handle_chart_as_file()` is called
+ Then a `FileResponse` containing the chart image is returned
+ """
+ # Given
+ mocked_generate_chart_as_file.return_value = b""
+ chart_request_params = mock.MagicMock(spec=DualCategoryChartRequestParams)
+ chart_request_params.file_format = "svg"
+
+ # When
+ response = DualCategoryChartsView._handle_chart_as_file(
+ chart_request_params=chart_request_params,
+ )
+
+ # Then
+ assert isinstance(response, FileResponse)
+ assert response.getvalue() == b""
+ assert response.headers["Content-Type"] == "image/svg"
+
+ @mock.patch(
+ "metrics.api.views.charts.dual_category_charts.generate_encoded_chart",
+ )
+ def test_handle_encoded_svg_returns_response_when_chart_generation_succeeds(
+ self,
+ mocked_generate_encoded_chart: mock.MagicMock,
+ ):
+ """
+ Given chart generation succeeds
+ When `_handle_encoded_svg()` is called
+ Then an HTTP 200 response containing the encoded chart is returned
+ """
+ # Given
+ mocked_generate_encoded_chart.return_value = ChartResult(
+ last_updated="2024-01-01",
+ chart="encoded-svg",
+ alt_text="Chart description",
+ figure={},
+ )
+ chart_request_params = mock.MagicMock(spec=DualCategoryChartRequestParams)
+
+ # When
+ response = DualCategoryChartsView._handle_encoded_svg(
+ chart_request_params=chart_request_params,
+ )
+
+ # Then
+ assert isinstance(response, Response)
+ assert response.status_code == HTTPStatus.OK
+ assert response.data["chart"] == "encoded-svg"
diff --git a/tests/unit/metrics/domain/charts/chart_settings/test_chart_settings_dual_category.py b/tests/unit/metrics/domain/charts/chart_settings/test_chart_settings_dual_category.py
new file mode 100644
index 0000000000..de3b73d95a
--- /dev/null
+++ b/tests/unit/metrics/domain/charts/chart_settings/test_chart_settings_dual_category.py
@@ -0,0 +1,104 @@
+import pytest
+
+from metrics.domain.charts.chart_settings.dual_category import DualCategoryChartSettings
+from metrics.domain.models import ChartGenerationPayload, PlotGenerationData
+from tests.conftest import fake_plot_data
+
+
+@pytest.fixture()
+def fake_dual_category_chart_settings(
+ fake_plot_data: PlotGenerationData,
+) -> DualCategoryChartSettings:
+ payload = ChartGenerationPayload(
+ chart_width=640,
+ chart_height=400,
+ plots=[fake_plot_data],
+ x_axis_title="Date",
+ y_axis_title="Cases",
+ )
+ return DualCategoryChartSettings(chart_generation_payload=payload)
+
+
+class TestDualCategoryChartSettings:
+ def test_get_legend_config(
+ self, fake_dual_category_chart_settings: DualCategoryChartSettings
+ ):
+ """
+ Given an instance of `DualCategoryChartSettings` without a legend title
+ When `_get_legend_config()` is called
+ Then the legend configuration is returned without a title
+ """
+ # Given
+ chart_settings = fake_dual_category_chart_settings
+
+ # When
+ legend_config = chart_settings._get_legend_config()
+
+ # Then
+ assert legend_config == {
+ "legend": {
+ "font": chart_settings._get_tick_font_config(),
+ "orientation": "h",
+ "y": 1.0,
+ "x": 0.5,
+ "xanchor": "center",
+ "yanchor": "bottom",
+ "entrywidth": 80,
+ },
+ }
+
+ def test_get_legend_config_includes_legend_title_when_provided(
+ self, fake_dual_category_chart_settings: DualCategoryChartSettings
+ ):
+ """
+ Given an instance of `DualCategoryChartSettings` with a legend title
+ When `_get_legend_config()` is called
+ Then the legend configuration includes a formatted title
+ """
+ # Given
+ chart_settings = fake_dual_category_chart_settings
+ legend_title = "Age group"
+ chart_settings._chart_generation_payload.legend_title = legend_title
+
+ # When
+ legend_config = chart_settings._get_legend_config()
+
+ # Then
+ assert legend_config == {
+ "legend": {
+ "font": chart_settings._get_tick_font_config(),
+ "orientation": "h",
+ "y": 1.0,
+ "x": 0.5,
+ "xanchor": "center",
+ "yanchor": "bottom",
+ "entrywidth": 80,
+ "title": {
+ "text": f"{legend_title}",
+ "side": "top",
+ },
+ },
+ }
+
+ def test_get_stacked_bar_chart_config(
+ self, fake_dual_category_chart_settings: DualCategoryChartSettings
+ ):
+ """
+ Given an instance of `DualCategoryChartSettings`
+ When `get_stacked_bar_chart_config()` is called
+ Then stacked bar layout and legend settings are merged
+ """
+ # Given
+ chart_settings = fake_dual_category_chart_settings
+
+ # When
+ stacked_bar_config = chart_settings.get_stacked_bar_chart_config()
+
+ expected_config = {
+ **chart_settings._get_base_chart_config(),
+ **chart_settings._get_legend_config(),
+ "barmode": "stack",
+ }
+
+ # Then
+ assert stacked_bar_config == expected_config
diff --git a/tests/unit/metrics/domain/charts/stacked_bar/test_generation.py b/tests/unit/metrics/domain/charts/stacked_bar/test_generation.py
new file mode 100644
index 0000000000..c1b16028d1
--- /dev/null
+++ b/tests/unit/metrics/domain/charts/stacked_bar/test_generation.py
@@ -0,0 +1,87 @@
+from unittest import mock
+
+import plotly.graph_objects as go
+
+from metrics.domain.charts.stacked_bar.generation import generate_stacked_bar
+from metrics.domain.models.plots import ChartGenerationPayload, PlotGenerationData
+
+MODULE_PATH = "metrics.domain.charts.stacked_bar.generation"
+
+HEIGHT = 400
+WIDTH = 640
+
+
+class TestGenerateStackedBar:
+ @mock.patch(f"{MODULE_PATH}.add_reporting_delay_period")
+ def test_adds_reporting_delay_period_when_x_axis_is_date_type(
+ self,
+ mocked_add_reporting_delay_period: mock.MagicMock,
+ fake_plot_data: PlotGenerationData,
+ ):
+ """
+ Given stacked bar plot data with date values on the x-axis
+ When `generate_stacked_bar()` is called
+ Then the reporting delay period is added to the figure
+ """
+ # Given
+ fake_plot_data.parameters.chart_type = "stacked_bar"
+ fake_plot_data.parameters.age = "all"
+
+ # Adding in_reporting_delay_period to the last date in the x-axis values
+ fake_plot_data.additional_values = {
+ "in_reporting_delay_period": [False]
+ * (len(fake_plot_data.x_axis_values) - 1)
+ + [True],
+ }
+ chart_payload = ChartGenerationPayload(
+ chart_width=WIDTH,
+ chart_height=HEIGHT,
+ plots=[fake_plot_data],
+ x_axis_title="Date",
+ y_axis_title="Cases",
+ secondary_category="age",
+ )
+
+ # When
+ figure = generate_stacked_bar(chart_generation_payload=chart_payload)
+
+ # Then
+ mocked_add_reporting_delay_period.assert_called_once_with(
+ chart_plots_data=chart_payload.plots,
+ figure=figure,
+ )
+ assert isinstance(figure, go.Figure)
+ assert len(figure.data) == 1
+
+ @mock.patch(f"{MODULE_PATH}.add_reporting_delay_period")
+ def test_does_not_add_reporting_delay_period_when_x_axis_is_not_date_type(
+ self,
+ mocked_add_reporting_delay_period: mock.MagicMock,
+ fake_plot_data: PlotGenerationData,
+ ):
+ """
+ Given stacked bar plot data with non-date values on the x-axis
+ When `generate_stacked_bar()` is called
+ Then the reporting delay period is not added to the figure
+ """
+ # Given
+ fake_plot_data.parameters.chart_type = "stacked_bar"
+ fake_plot_data.parameters.age = "all"
+ fake_plot_data.x_axis_values = ["0-4", "5-9", "10-14"]
+ fake_plot_data.y_axis_values = fake_plot_data.y_axis_values[
+ : len(fake_plot_data.x_axis_values)
+ ]
+ chart_payload = ChartGenerationPayload(
+ chart_width=WIDTH,
+ chart_height=HEIGHT,
+ plots=[fake_plot_data],
+ x_axis_title="Age",
+ y_axis_title="Cases",
+ secondary_category="age",
+ )
+
+ # When
+ generate_stacked_bar(chart_generation_payload=chart_payload)
+
+ # Then
+ mocked_add_reporting_delay_period.assert_not_called()
diff --git a/tests/unit/metrics/domain/models/test_plots_text.py b/tests/unit/metrics/domain/models/test_plots_text.py
index e74f8450e0..27b35c743f 100644
--- a/tests/unit/metrics/domain/models/test_plots_text.py
+++ b/tests/unit/metrics/domain/models/test_plots_text.py
@@ -484,6 +484,27 @@ def test_can_describe_singular_timeseries_plots(
assert "This plot has a value of '123'" in text
assert "This plot has a value of '456.01'" in text
+ def test_returns_correct_text_for_dual_category_stacked_bar_chart(
+ self, fake_plot_data: PlotGenerationData
+ ):
+ """
+ Given a dual category stacked bar plot
+ When `construct_text()` is called from an instance of `PlotsText`
+ Then the returned text describes a stacked bar chart
+ """
+ # Given
+ fake_plot_data.parameters.chart_type = ChartTypes.stacked_bar.value
+ fake_plot_data.parameters.line_colour = (
+ RGBAChartLineColours.COLOUR_1_DARK_BLUE.name
+ )
+ plots_text = PlotsText(plots_data=[fake_plot_data])
+
+ # When
+ text: str = plots_text.construct_text()
+
+ # Then
+ assert "This is a dark blue stacked bar chart." in text
+
def test_describe_singular_metric_value_returns_empty_string_for_timeseries_plots_with_no_data(
self, fake_plot_data: PlotGenerationData
):
diff --git a/tests/unit/metrics/interfaces/charts/dual_category_charts/__init__.py b/tests/unit/metrics/interfaces/charts/dual_category_charts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/metrics/interfaces/charts/dual_category_charts/test_access.py b/tests/unit/metrics/interfaces/charts/dual_category_charts/test_access.py
new file mode 100644
index 0000000000..461f4df0f4
--- /dev/null
+++ b/tests/unit/metrics/interfaces/charts/dual_category_charts/test_access.py
@@ -0,0 +1,255 @@
+import datetime
+from unittest import mock
+
+import pytest
+
+from metrics.data.managers.core_models.headline import CoreHeadlineManager
+from metrics.data.managers.core_models.time_series import CoreTimeSeriesManager
+from metrics.domain.models.charts.dual_category_charts import (
+ DualCategoryChartRequestParams,
+ StaticFields,
+)
+from metrics.domain.models.plots import ChartGenerationPayload, PlotParameters
+from metrics.interfaces.charts.common.chart_output import ChartOutput
+from metrics.interfaces.charts.dual_category_charts.access import (
+ DualCategoryChartsInterface,
+)
+from metrics.interfaces.plots.access import PlotsInterface
+
+MODULE_PATH = "metrics.interfaces.charts.dual_category_charts.access"
+
+
+@pytest.fixture
+def dual_category_chart_request_params(
+ fake_chart_plot_parameters: PlotParameters,
+) -> DualCategoryChartRequestParams:
+ return DualCategoryChartRequestParams(
+ chart_type="stacked_bar",
+ secondary_category="age",
+ primary_field_values=["m"],
+ static_fields=StaticFields(
+ topic=fake_chart_plot_parameters.topic,
+ metric=fake_chart_plot_parameters.metric,
+ geography="England",
+ geography_type="Nation",
+ age="all",
+ sex="all",
+ stratum="default",
+ date_from="2020-02-01",
+ date_to="2021-02-01",
+ ),
+ plots=[fake_chart_plot_parameters],
+ file_format="svg",
+ chart_width=320,
+ chart_height=200,
+ x_axis="sex",
+ y_axis="metric",
+ )
+
+
+@pytest.fixture
+def dual_category_chart_request_params_headline(
+ fake_chart_plot_parameters_headline_data: PlotParameters,
+) -> DualCategoryChartRequestParams:
+ plot = fake_chart_plot_parameters_headline_data
+ return DualCategoryChartRequestParams(
+ chart_type="stacked_bar",
+ secondary_category="age",
+ primary_field_values=["all"],
+ static_fields=StaticFields(
+ topic=plot.topic,
+ metric=plot.metric,
+ geography="England",
+ geography_type="Nation",
+ age="all",
+ sex="all",
+ stratum="default",
+ date_from="",
+ date_to="",
+ ),
+ plots=[plot],
+ file_format="svg",
+ chart_width=320,
+ chart_height=200,
+ x_axis="age",
+ y_axis="metric",
+ )
+
+
+class TestDualCategoryChartsInterface:
+ def test_set_latest_date_from_plots_data_fails_silently_when_latest_date_not_provided(
+ self,
+ dual_category_chart_request_params: DualCategoryChartRequestParams,
+ ):
+ """
+ Given plot data with no valid latest dates
+ When `_set_latest_date_from_plots_data()` is called
+ Then `last_updated` is left unchanged
+ """
+ # Given
+ charts_interface = DualCategoryChartsInterface(
+ chart_request_params=dual_category_chart_request_params,
+ )
+ original_last_updated = "2024-01-01"
+ charts_interface.last_updated = original_last_updated
+ mocked_plots_data = [mock.MagicMock(latest_date=None) for _ in range(3)]
+
+ # When
+ charts_interface._set_latest_date_from_plots_data(plots_data=mocked_plots_data)
+
+ # Then
+ assert charts_interface.last_updated == original_last_updated
+
+ def test_initialises_core_model_manager_with_headline_manager(
+ self,
+ dual_category_chart_request_params_headline: DualCategoryChartRequestParams,
+ ):
+ """
+ Given a dual category chart request for headline data
+ When a `DualCategoryChartsInterface` is created
+ Then the `core_model_manager` is a `CoreHeadlineManager`
+ """
+ # When
+ charts_interface = DualCategoryChartsInterface(
+ chart_request_params=dual_category_chart_request_params_headline,
+ )
+
+ # Then
+ assert isinstance(charts_interface.core_model_manager, CoreHeadlineManager)
+
+ def test_initialises_core_model_manager_with_timeseries_manager(
+ self,
+ dual_category_chart_request_params: DualCategoryChartRequestParams,
+ ):
+ """
+ Given a dual category chart request for timeseries data
+ When a `DualCategoryChartsInterface` is created
+ Then the `core_model_manager` is a `CoreTimeSeriesManager`
+ """
+ # When
+ charts_interface = DualCategoryChartsInterface(
+ chart_request_params=dual_category_chart_request_params,
+ )
+
+ # Then
+ assert isinstance(charts_interface.core_model_manager, CoreTimeSeriesManager)
+
+ @mock.patch(f"{MODULE_PATH}.generate_stacked_bar")
+ def test_build_chart_figure_delegates_call_to_generate_stacked_bar(
+ self,
+ mock_generate_stacked_bar: mock.MagicMock,
+ dual_category_chart_request_params: DualCategoryChartRequestParams,
+ fake_plot_data,
+ ):
+ """
+ Given a valid `chart_generation_payload`
+ When `_build_chart_figure()` is called
+ Then a call is made to `generate_stacked_bar`
+ """
+ # Given
+ charts_interface = DualCategoryChartsInterface(
+ chart_request_params=dual_category_chart_request_params,
+ )
+ chart_generation_payload = ChartGenerationPayload(
+ chart_width=320,
+ chart_height=200,
+ plots=[fake_plot_data],
+ x_axis_title="Date",
+ y_axis_title="Cases",
+ secondary_category="age",
+ )
+
+ # When
+ charts_interface._build_chart_figure(
+ chart_generation_payload=chart_generation_payload,
+ )
+
+ # Then
+ mock_generate_stacked_bar.assert_called_once_with(
+ chart_generation_payload=chart_generation_payload,
+ )
+
+ @mock.patch.object(PlotsInterface, "build_plots_data")
+ def test_build_chart_plots_data_delegates_call_to_plots_interface(
+ self,
+ spy_build_plots_data: mock.MagicMock,
+ dual_category_chart_request_params: DualCategoryChartRequestParams,
+ fake_plot_data,
+ ):
+ """
+ Given a valid dual category chart request
+ When `_build_chart_plots_data()` is called
+ Then a call is made to `PlotsInterface.build_plots_data`
+ """
+ # Given
+ fake_plot_data.latest_date = datetime.date(2024, 1, 1)
+ spy_build_plots_data.return_value = [fake_plot_data]
+ charts_interface = DualCategoryChartsInterface(
+ chart_request_params=dual_category_chart_request_params,
+ )
+
+ # When
+ plots_data = charts_interface._build_chart_plots_data()
+
+ # Then
+ spy_build_plots_data.assert_called_once()
+ assert plots_data == spy_build_plots_data.return_value
+
+ def test_set_latest_date_from_plots_data_sets_last_updated(
+ self,
+ dual_category_chart_request_params: DualCategoryChartRequestParams,
+ ):
+ """
+ Given plot data with valid latest dates
+ When `_set_latest_date_from_plots_data()` is called
+ Then `last_updated` is set to the latest date
+ """
+ # Given
+ charts_interface = DualCategoryChartsInterface(
+ chart_request_params=dual_category_chart_request_params,
+ )
+ latest_date = datetime.date(2024, 6, 1)
+ mocked_plots_data = [
+ mock.MagicMock(latest_date=datetime.date(2024, 1, 1)),
+ mock.MagicMock(latest_date=latest_date),
+ ]
+
+ # When
+ charts_interface._set_latest_date_from_plots_data(plots_data=mocked_plots_data)
+
+ # Then
+ assert charts_interface.last_updated == "2024-06-01"
+
+ @mock.patch.object(DualCategoryChartsInterface, "_build_chart_generation_payload")
+ @mock.patch.object(DualCategoryChartsInterface, "_build_chart_figure")
+ def test_generate_chart_output_returns_instance_of_chart_output(
+ self,
+ mock_build_chart_figure: mock.MagicMock,
+ mock_build_chart_generation_payload: mock.MagicMock,
+ dual_category_chart_request_params: DualCategoryChartRequestParams,
+ fake_plot_data,
+ ):
+ """
+ Given a valid dual category chart request
+ When `generate_chart_output()` is called
+ Then an instance of `ChartOutput` is returned
+ """
+ # Given
+ mock_build_chart_generation_payload.return_value = ChartGenerationPayload(
+ chart_width=320,
+ chart_height=200,
+ plots=[fake_plot_data],
+ x_axis_title="Date",
+ y_axis_title="Cases",
+ secondary_category="age",
+ )
+ charts_interface = DualCategoryChartsInterface(
+ chart_request_params=dual_category_chart_request_params,
+ )
+
+ # When
+ chart_output = charts_interface.generate_chart_output()
+
+ # Then
+ assert isinstance(chart_output, ChartOutput)
+ assert chart_output.figure == mock_build_chart_figure.return_value
diff --git a/tests/unit/metrics/interfaces/charts/single_category_charts/test_access.py b/tests/unit/metrics/interfaces/charts/single_category_charts/test_access.py
index 052a3173a1..2f236a8542 100644
--- a/tests/unit/metrics/interfaces/charts/single_category_charts/test_access.py
+++ b/tests/unit/metrics/interfaces/charts/single_category_charts/test_access.py
@@ -16,11 +16,11 @@
from metrics.domain.common.utils import ChartTypes
from metrics.interfaces.charts.single_category_charts.access import (
ChartsInterface,
- InvalidFileFormatError,
InvalidChartTypeCombinationError,
generate_chart_as_file,
generate_encoded_chart,
)
+from metrics.interfaces.charts.common.exceptions import InvalidFileFormatError
from metrics.interfaces.charts.common.chart_output import ChartOutput
from metrics.interfaces.plots.access import InvalidPlotParametersError
from tests.fakes.factories.metrics.core_time_series_factory import (