Skip to content

Commit 058a9d8

Browse files
committed
addressing PR review except arch change
1 parent c99d952 commit 058a9d8

1 file changed

Lines changed: 177 additions & 18 deletions

File tree

metrics/interfaces/management/commands/seed_random.py

Lines changed: 177 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
import time
3-
from collections.abc import Iterable
3+
from collections.abc import Callable, Iterable
44
from datetime import date, timedelta
55
from decimal import Decimal
66
from typing import override
@@ -22,6 +22,11 @@
2222
Topic,
2323
)
2424
from metrics.data.models.core_models.timeseries import CoreTimeSeries
25+
from validation import enums as validation_enums
26+
from validation.geography_code import (
27+
NATION_GEOGRAPHY_CODES,
28+
UNITED_KINGDOM_GEOGRAPHY_CODE,
29+
)
2530

2631
SCALE_CONFIGS = {
2732
"small": {"geographies": 5, "metrics": 10, "days": 30},
@@ -37,22 +42,26 @@ def add_arguments(self, parser: CommandParser) -> None:
3742
"--dataset",
3843
choices=["cms", "metrics", "both"],
3944
default="both",
45+
help="Which dataset to seed: CMS, metrics, or both.",
4046
)
4147
parser.add_argument(
4248
"--scale",
4349
choices=["small", "medium", "large"],
4450
default="small",
51+
help="Size of the random metrics dataset to generate.",
4552
)
4653
parser.add_argument(
4754
"--seed",
4855
type=int,
4956
required=False,
5057
default=None,
58+
help="Optional random seed for reproducible metric values.",
5159
)
5260
parser.add_argument(
5361
"--truncate-first",
5462
action="store_true",
5563
default=False,
64+
help="Clear existing metrics tables before seeding to avoid duplicates.",
5665
)
5766

5867
def handle(self, *args, **options) -> None:
@@ -82,13 +91,18 @@ def handle(self, *args, **options) -> None:
8291

8392
if should_seed_metrics:
8493
scale_config = SCALE_CONFIGS[scale]
94+
self.stderr.write("Seeding metrics dataset...")
8595
counts = self._seed_metrics_data(
8696
scale_config=scale_config,
8797
truncate_first=truncate_first,
98+
progress_callback=self.stderr.write,
8899
)
100+
self.stderr.write("Metrics dataset seeding complete.")
89101

90102
if should_seed_cms:
103+
self.stderr.write("Building CMS site data...")
91104
call_command("build_cms_site")
105+
self.stderr.write("CMS site build complete.")
92106

93107
runtime_seconds = time.perf_counter() - started_at
94108
self._print_summary(
@@ -101,35 +115,51 @@ def handle(self, *args, **options) -> None:
101115

102116
@classmethod
103117
def _seed_metrics_data(
104-
cls, *, scale_config: dict[str, int], truncate_first: bool
118+
cls,
119+
*,
120+
scale_config: dict[str, int],
121+
truncate_first: bool,
122+
progress_callback: Callable[[str], None] | None = None,
105123
) -> dict[str, int]:
106-
if truncate_first:
107-
cls._truncate_metrics_data()
124+
"""Seed supporting metric models and time series rows for the selected scale."""
125+
if progress_callback is not None:
126+
progress_callback("Preparing metric taxonomy and geography records...")
108127

109128
with transaction.atomic():
129+
if truncate_first:
130+
cls._truncate_metrics_data()
131+
132+
(
133+
theme_names,
134+
sub_theme_rows,
135+
topic_rows,
136+
) = cls._build_theme_hierarchy_records()
110137
themes = cls._bulk_create(
111138
Theme,
112-
[Theme(name=f"Theme {index + 1}") for index in range(3)],
139+
[Theme(name=name) for name in theme_names],
113140
)
141+
themes_by_name = {theme.name: theme for theme in themes}
114142

115143
sub_themes = cls._bulk_create(
116144
SubTheme,
117145
[
118-
SubTheme(
119-
name=f"SubTheme {index + 1}", theme=themes[index % len(themes)]
120-
)
121-
for index in range(6)
146+
SubTheme(name=name, theme=themes_by_name[theme_name])
147+
for name, theme_name in sub_theme_rows
122148
],
123149
)
150+
sub_themes_by_key = {
151+
(sub_theme.name, sub_theme.theme.name): sub_theme
152+
for sub_theme in sub_themes
153+
}
124154

125155
topics = cls._bulk_create(
126156
Topic,
127157
[
128158
Topic(
129-
name=f"Topic {index + 1}",
130-
sub_theme=sub_themes[index % len(sub_themes)],
159+
name=topic_name,
160+
sub_theme=sub_themes_by_key[(sub_theme_name, theme_name)],
131161
)
132-
for index in range(12)
162+
for topic_name, sub_theme_name, theme_name in topic_rows
133163
],
134164
)
135165

@@ -144,29 +174,47 @@ def _seed_metrics_data(
144174
],
145175
)
146176

147-
geography_type = GeographyType.objects.create(name="Nation")
177+
geography_seed_values = cls._build_geography_seed_values(
178+
count=scale_config["geographies"]
179+
)
180+
geography_type_names = {
181+
record["geography_type"] for record in geography_seed_values
182+
}
183+
geography_types = cls._bulk_create(
184+
GeographyType,
185+
[GeographyType(name=name) for name in sorted(geography_type_names)],
186+
)
187+
geography_types_by_name = {
188+
geography_type.name: geography_type
189+
for geography_type in geography_types
190+
}
148191

149192
geographies = cls._bulk_create(
150193
Geography,
151194
[
152195
Geography(
153-
name=f"Area {index + 1}",
154-
geography_code=f"RND{index + 1:04d}",
155-
geography_type=geography_type,
196+
name=record["name"],
197+
geography_code=record["geography_code"],
198+
geography_type=geography_types_by_name[
199+
record["geography_type"]
200+
],
156201
)
157-
for index in range(scale_config["geographies"])
202+
for record in geography_seed_values
158203
],
159204
)
160205

161206
stratum = Stratum.objects.create(name="All")
162207
age = Age.objects.create(name="All ages")
163208

209+
if progress_callback is not None:
210+
progress_callback("Generating Core/API time series rows...")
164211
core_count, api_count = cls._seed_time_series_rows(
165212
metrics=metrics,
166213
geographies=geographies,
167214
stratum=stratum,
168215
age=age,
169216
days=scale_config["days"],
217+
progress_callback=progress_callback,
170218
)
171219

172220
return {
@@ -181,6 +229,7 @@ def _seed_metrics_data(
181229

182230
@classmethod
183231
def _truncate_metrics_data(cls) -> None:
232+
"""Delete all seeded metrics-related rows in dependency-safe order."""
184233
APITimeSeries.objects.all().delete()
185234
CoreTimeSeries.objects.all().delete()
186235
Metric.objects.all().delete()
@@ -201,6 +250,7 @@ def _seed_time_series_rows(
201250
stratum: Stratum,
202251
age: Age,
203252
days: int,
253+
progress_callback: Callable[[str], None] | None = None,
204254
) -> tuple[int, int]:
205255
frequency = TimePeriod.Weekly.value
206256
today = date.today()
@@ -210,8 +260,11 @@ def _seed_time_series_rows(
210260
api_rows: list[APITimeSeries] = []
211261
core_count = 0
212262
api_count = 0
263+
total_metrics = len(metrics)
264+
total_row_count = total_metrics * len(geographies) * days
265+
log_interval = max(1, total_metrics // 10) if total_metrics else 1
213266

214-
for metric in metrics:
267+
for metric_index, metric in enumerate(metrics, start=1):
215268
topic = metric.topic
216269
sub_theme = topic.sub_theme
217270
theme = sub_theme.theme
@@ -281,6 +334,19 @@ def _seed_time_series_rows(
281334
api_count += len(api_rows)
282335
api_rows = []
283336

337+
if (
338+
progress_callback is not None
339+
and (
340+
metric_index == total_metrics
341+
or metric_index % log_interval == 0
342+
)
343+
):
344+
processed_row_count = metric_index * len(geographies) * days
345+
progress_callback(
346+
f"Processed {metric_index}/{total_metrics} metrics "
347+
f"({processed_row_count:,}/{total_row_count:,} row groups)."
348+
)
349+
284350
if core_rows:
285351
CoreTimeSeries.objects.bulk_create(core_rows, batch_size=batch_size)
286352
core_count += len(core_rows)
@@ -289,12 +355,105 @@ def _seed_time_series_rows(
289355
APITimeSeries.objects.bulk_create(api_rows, batch_size=batch_size)
290356
api_count += len(api_rows)
291357

358+
if progress_callback is not None:
359+
progress_callback(
360+
"Inserted "
361+
f"{core_count:,} CoreTimeSeries rows and "
362+
f"{api_count:,} APITimeSeries rows."
363+
)
364+
292365
return core_count, api_count
293366

294367
@staticmethod
295368
def _bulk_create(model, records: Iterable):
369+
"""Materialise and bulk insert a sequence of model instances."""
296370
return model.objects.bulk_create(list(records))
297371

372+
@classmethod
373+
def _build_theme_hierarchy_records(
374+
cls,
375+
) -> tuple[list[str], list[tuple[str, str]], list[tuple[str, str, str]]]:
376+
child_to_parent: dict[str, str] = {}
377+
normalised_to_child: dict[str, str] = {}
378+
parent_by_name = validation_enums.ParentTheme.__members__
379+
380+
for child_theme_group in validation_enums.ChildTheme:
381+
resolved_parent = (
382+
parent_by_name[child_theme_group.name].value
383+
if child_theme_group.name in parent_by_name
384+
else validation_enums.ParentTheme.INFECTIOUS_DISEASE.value
385+
)
386+
for sub_theme_name in child_theme_group.return_list():
387+
child_to_parent[sub_theme_name] = resolved_parent
388+
normalised_to_child[cls._normalise_key(sub_theme_name)] = (
389+
sub_theme_name
390+
)
391+
392+
topic_rows: list[tuple[str, str, str]] = []
393+
sub_theme_pairs: set[tuple[str, str]] = set()
394+
for topic_group in validation_enums.Topic:
395+
normalised_topic_group = cls._normalise_key(topic_group.name)
396+
sub_theme_name = normalised_to_child.get(normalised_topic_group)
397+
if sub_theme_name is None:
398+
continue
399+
400+
parent_theme_name = child_to_parent[sub_theme_name]
401+
sub_theme_pairs.add((sub_theme_name, parent_theme_name))
402+
for topic_value in topic_group.return_list():
403+
topic_rows.append((topic_value, sub_theme_name, parent_theme_name))
404+
405+
theme_names = sorted({parent_name for _, parent_name in sub_theme_pairs})
406+
sub_theme_rows = sorted(
407+
sub_theme_pairs,
408+
key=lambda value: (value[1], value[0]),
409+
)
410+
return theme_names, sub_theme_rows, topic_rows
411+
412+
@classmethod
413+
def _build_geography_seed_values(cls, *, count: int) -> list[dict[str, str]]:
414+
geographies: list[dict[str, str]] = [
415+
{
416+
"name": "United Kingdom",
417+
"geography_code": UNITED_KINGDOM_GEOGRAPHY_CODE,
418+
"geography_type": (
419+
validation_enums.GeographyType.UNITED_KINGDOM.value
420+
),
421+
}
422+
]
423+
424+
geographies.extend(
425+
{
426+
"name": name,
427+
"geography_code": code,
428+
"geography_type": validation_enums.GeographyType.NATION.value,
429+
}
430+
for name, code in NATION_GEOGRAPHY_CODES.items()
431+
)
432+
433+
if len(geographies) >= count:
434+
return geographies[:count]
435+
436+
extra_required = count - len(geographies)
437+
geographies.extend(
438+
{
439+
"name": cls._format_enum_name(ltla.name),
440+
"geography_code": ltla.value,
441+
"geography_type": (
442+
validation_enums.GeographyType.LOWER_TIER_LOCAL_AUTHORITY.value
443+
),
444+
}
445+
for ltla in list(validation_enums.LTLAs)[:extra_required]
446+
)
447+
return geographies[:count]
448+
449+
@staticmethod
450+
def _normalise_key(value: str) -> str:
451+
return value.lower().replace("-", "_")
452+
453+
@staticmethod
454+
def _format_enum_name(value: str) -> str:
455+
return value.replace("_", " ").title()
456+
298457
def _print_summary(
299458
self,
300459
*,

0 commit comments

Comments
 (0)