Skip to content

Commit 9f8e289

Browse files
authored
Support current region scoping metadata (#364)
1 parent 54e42b7 commit 9f8e289

3 files changed

Lines changed: 91 additions & 16 deletions

File tree

changelog.d/364.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support current PolicyEngine region scoping metadata when reseeding regions.

scripts/seed_regions.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import argparse
2323
import time
24+
from typing import Any
2425

2526
from rich.progress import Progress, SpinnerColumn, TextColumn
2627
from seed_utils import console, get_session
@@ -36,6 +37,34 @@
3637
from policyengine_api.models.region import RegionType # noqa: E402
3738

3839

40+
def _region_scoping_strategy(pe_region: Any) -> Any | None:
41+
return getattr(pe_region, "scoping_strategy", None)
42+
43+
44+
def _region_requires_filter(pe_region: Any) -> bool:
45+
strategy = _region_scoping_strategy(pe_region)
46+
return bool(getattr(pe_region, "requires_filter", strategy is not None))
47+
48+
49+
def _region_filter_field(pe_region: Any) -> str | None:
50+
if hasattr(pe_region, "filter_field"):
51+
return pe_region.filter_field
52+
strategy = _region_scoping_strategy(pe_region)
53+
return getattr(strategy, "variable_name", None)
54+
55+
56+
def _region_filter_value(pe_region: Any) -> str | None:
57+
if hasattr(pe_region, "filter_value"):
58+
return pe_region.filter_value
59+
strategy = _region_scoping_strategy(pe_region)
60+
return getattr(strategy, "variable_value", None)
61+
62+
63+
def _region_filter_strategy(pe_region: Any) -> str | None:
64+
strategy = _region_scoping_strategy(pe_region)
65+
return getattr(strategy, "strategy_type", None)
66+
67+
3968
def _group_us_datasets(
4069
session: Session,
4170
us_model_id,
@@ -197,14 +226,10 @@ def seed_us_regions(
197226
code=pe_region.code,
198227
label=pe_region.label,
199228
region_type=RegionType(pe_region.region_type),
200-
requires_filter=pe_region.requires_filter,
201-
filter_field=pe_region.filter_field,
202-
filter_value=pe_region.filter_value,
203-
filter_strategy=(
204-
pe_region.scoping_strategy.strategy_type
205-
if pe_region.scoping_strategy
206-
else None
207-
),
229+
requires_filter=_region_requires_filter(pe_region),
230+
filter_field=_region_filter_field(pe_region),
231+
filter_value=_region_filter_value(pe_region),
232+
filter_strategy=_region_filter_strategy(pe_region),
208233
parent_code=pe_region.parent_code,
209234
state_code=pe_region.state_code,
210235
state_name=pe_region.state_name,
@@ -295,14 +320,10 @@ def seed_uk_regions(session: Session) -> tuple[int, int, int]:
295320
code=pe_region.code,
296321
label=pe_region.label,
297322
region_type=RegionType(pe_region.region_type),
298-
requires_filter=pe_region.requires_filter,
299-
filter_field=pe_region.filter_field,
300-
filter_value=pe_region.filter_value,
301-
filter_strategy=(
302-
pe_region.scoping_strategy.strategy_type
303-
if pe_region.scoping_strategy
304-
else None
305-
),
323+
requires_filter=_region_requires_filter(pe_region),
324+
filter_field=_region_filter_field(pe_region),
325+
filter_value=_region_filter_value(pe_region),
326+
filter_strategy=_region_filter_strategy(pe_region),
306327
parent_code=pe_region.parent_code,
307328
state_code=None,
308329
state_name=None,

tests/test_seed_regions.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from pathlib import Path
5+
from types import SimpleNamespace
6+
7+
SCRIPTS_DIR = Path(__file__).resolve().parents[1] / "scripts"
8+
sys.path.insert(0, str(SCRIPTS_DIR))
9+
10+
import seed_regions # noqa: E402
11+
12+
13+
def test_region_filter_helpers_support_scoping_strategy_schema() -> None:
14+
region = SimpleNamespace(
15+
scoping_strategy=SimpleNamespace(
16+
strategy_type="row_filter",
17+
variable_name="place_fips",
18+
variable_value="03000",
19+
)
20+
)
21+
22+
assert seed_regions._region_requires_filter(region) is True
23+
assert seed_regions._region_filter_field(region) == "place_fips"
24+
assert seed_regions._region_filter_value(region) == "03000"
25+
assert seed_regions._region_filter_strategy(region) == "row_filter"
26+
27+
28+
def test_region_filter_helpers_support_legacy_filter_fields() -> None:
29+
region = SimpleNamespace(
30+
requires_filter=True,
31+
filter_field="country",
32+
filter_value="ENGLAND",
33+
scoping_strategy=None,
34+
)
35+
36+
assert seed_regions._region_requires_filter(region) is True
37+
assert seed_regions._region_filter_field(region) == "country"
38+
assert seed_regions._region_filter_value(region) == "ENGLAND"
39+
assert seed_regions._region_filter_strategy(region) is None
40+
41+
42+
def test_current_policyengine_regions_do_not_require_legacy_filter_attrs() -> None:
43+
from policyengine.countries.us.regions import us_region_registry
44+
45+
place = next(
46+
region for region in us_region_registry.regions if region.region_type == "place"
47+
)
48+
49+
assert not hasattr(place, "filter_field")
50+
assert seed_regions._region_requires_filter(place) is True
51+
assert seed_regions._region_filter_field(place) == "place_fips"
52+
assert seed_regions._region_filter_value(place)
53+
assert seed_regions._region_filter_strategy(place) == "row_filter"

0 commit comments

Comments
 (0)