Skip to content

Commit e08bc72

Browse files
committed
Format OBR importer for black CI
1 parent 8faab7b commit e08bc72

2 files changed

Lines changed: 84 additions & 35 deletions

File tree

policyengine_uk/tests/test_import_obr_forecasts.py

Lines changed: 35 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,19 @@ def make_test_xlsx() -> bytes:
5252

5353
sheet_16 = make_sheet(
5454
{
55-
3: [make_inline_cell("Q3", "Average weekly earnings growth (per cent)")],
56-
97: [make_inline_cell("B97", "2025"), make_number_cell("Q97", 5.17)],
57-
98: [make_inline_cell("B98", "2026"), make_number_cell("Q98", 3.33)],
55+
3: [
56+
make_inline_cell(
57+
"Q3", "Average weekly earnings growth (per cent)"
58+
)
59+
],
60+
97: [
61+
make_inline_cell("B97", "2025"),
62+
make_number_cell("Q97", 5.17),
63+
],
64+
98: [
65+
make_inline_cell("B98", "2026"),
66+
make_number_cell("Q98", 3.33),
67+
],
5868
}
5969
)
6070
sheet_17 = make_sheet(
@@ -94,11 +104,18 @@ def make_test_xlsx() -> bytes:
94104
{
95105
3: [
96106
make_inline_cell(
97-
"D3", "House price index (per cent change on a year earlier)"
107+
"D3",
108+
"House price index (per cent change on a year earlier)",
98109
)
99110
],
100-
97: [make_inline_cell("B97", "2025"), make_number_cell("D97", 2.80)],
101-
98: [make_inline_cell("B98", "2026"), make_number_cell("D98", 2.40)],
111+
97: [
112+
make_inline_cell("B97", "2025"),
113+
make_number_cell("D97", 2.80),
114+
],
115+
98: [
116+
make_inline_cell("B98", "2026"),
117+
make_number_cell("D98", 2.40),
118+
],
102119
}
103120
)
104121

@@ -125,7 +142,9 @@ def test_extract_annual_series_from_xlsx():
125142

126143

127144
def test_release_inference_helpers():
128-
assert infer_release("Economy_Detailed_forecast_tables_November_2025.xlsx") == (
145+
assert infer_release(
146+
"Economy_Detailed_forecast_tables_November_2025.xlsx"
147+
) == (
129148
"November",
130149
2025,
131150
)
@@ -137,8 +156,7 @@ def test_release_inference_helpers():
137156

138157
def test_update_yoy_growth_yaml_updates_forecast_window_only(tmp_path):
139158
yaml_path = tmp_path / "yoy_growth.yaml"
140-
yaml_path.write_text(
141-
"""obr:
159+
yaml_path.write_text("""obr:
142160
rpi:
143161
values:
144162
2024-01-01: 0.0300
@@ -209,8 +227,7 @@ def test_update_yoy_growth_yaml_updates_forecast_window_only(tmp_path):
209227
reference:
210228
- title: Old
211229
href: https://example.com/old
212-
"""
213-
)
230+
""")
214231

215232
update_yoy_growth_yaml(
216233
yaml_path=yaml_path,
@@ -229,17 +246,19 @@ def test_update_yoy_growth_yaml_updates_forecast_window_only(tmp_path):
229246
assert "2025-01-01: 0.0280" in content
230247
assert "2026-01-01: 0.0240" in content
231248
assert (
232-
"OBR EFO March 2026 (detailed forecast tables, economy, Table 1.16)" in content
249+
"OBR EFO March 2026 (detailed forecast tables, economy, Table 1.16)"
250+
in content
251+
)
252+
assert (
253+
"https://obr.uk/efo/economic-and-fiscal-outlook-march-2026/" in content
233254
)
234-
assert "https://obr.uk/efo/economic-and-fiscal-outlook-march-2026/" in content
235255

236256

237257
def test_update_yoy_growth_yaml_keeps_existing_values_when_obr_has_blank_years(
238258
tmp_path,
239259
):
240260
yaml_path = tmp_path / "yoy_growth.yaml"
241-
yaml_path.write_text(
242-
"""obr:
261+
yaml_path.write_text("""obr:
243262
mortgage_interest:
244263
values:
245264
2025-01-01: 0.0000
@@ -303,8 +322,7 @@ def test_update_yoy_growth_yaml_keeps_existing_values_when_obr_has_blank_years(
303322
reference:
304323
- title: OBR EFO November 2025 (detailed forecast tables, economy, Table 1.7)
305324
href: https://obr.uk/efo/economic-and-fiscal-outlook-november-2025/
306-
"""
307-
)
325+
""")
308326

309327
update_yoy_growth_yaml(
310328
yaml_path=yaml_path,

policyengine_uk/utils/import_obr_forecasts.py

Lines changed: 49 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,9 @@ def read_url_bytes(url: str) -> bytes:
126126
return response.read()
127127

128128

129-
def load_source_bytes(url: str | None, file_path: str | None) -> tuple[str, bytes]:
129+
def load_source_bytes(
130+
url: str | None, file_path: str | None
131+
) -> tuple[str, bytes]:
130132
if bool(url) == bool(file_path):
131133
raise ValueError("Pass exactly one of --url or --file")
132134

@@ -155,10 +157,13 @@ def extract_economy_workbook_bytes(
155157
candidates = [
156158
name
157159
for name in archive.namelist()
158-
if name.lower().endswith(".xlsx") and "economy" in Path(name).name.lower()
160+
if name.lower().endswith(".xlsx")
161+
and "economy" in Path(name).name.lower()
159162
]
160163
if not candidates:
161-
raise ValueError("Could not find an economy workbook in the source")
164+
raise ValueError(
165+
"Could not find an economy workbook in the source"
166+
)
162167
candidates.sort()
163168
workbook_name = candidates[0]
164169
return Path(workbook_name).name, archive.read(workbook_name)
@@ -171,7 +176,9 @@ def _load_shared_strings(archive: ZipFile) -> list[str]:
171176
root = ET.fromstring(archive.read("xl/sharedStrings.xml"))
172177
values: list[str] = []
173178
for item in root.findall("main:si", WORKBOOK_NS):
174-
parts = [node.text or "" for node in item.iterfind(".//main:t", WORKBOOK_NS)]
179+
parts = [
180+
node.text or "" for node in item.iterfind(".//main:t", WORKBOOK_NS)
181+
]
175182
values.append("".join(parts))
176183
return values
177184

@@ -208,7 +215,9 @@ def _cell_value(cell: ET.Element, shared_strings: list[str]) -> str | None:
208215
return None
209216

210217

211-
def read_sheet_rows(xlsx_bytes: bytes, sheet_name: str) -> list[dict[str, str | None]]:
218+
def read_sheet_rows(
219+
xlsx_bytes: bytes, sheet_name: str
220+
) -> list[dict[str, str | None]]:
212221
with ZipFile(BytesIO(xlsx_bytes)) as archive:
213222
shared_strings = _load_shared_strings(archive)
214223
sheet_path = _sheet_paths(archive)[sheet_name]
@@ -225,26 +234,38 @@ def read_sheet_rows(xlsx_bytes: bytes, sheet_name: str) -> list[dict[str, str |
225234
return rows
226235

227236

228-
def find_series_column(rows: list[dict[str, str | None]], spec: SeriesSpec) -> str:
237+
def find_series_column(
238+
rows: list[dict[str, str | None]], spec: SeriesSpec
239+
) -> str:
229240
headers: dict[str, str] = {}
230241
for row in rows[:4]:
231242
for column, value in row.items():
232243
label = normalise_label(value)
233-
if label and (column not in headers or is_generic_header(headers[column])):
244+
if label and (
245+
column not in headers or is_generic_header(headers[column])
246+
):
234247
headers[column] = label
235248

236249
header = headers.get(spec.column, "")
237250
if spec.mode == "exact" and header in spec.needles:
238251
return spec.column
239-
if spec.mode == "contains" and any(needle in header for needle in spec.needles):
252+
if spec.mode == "contains" and any(
253+
needle in header for needle in spec.needles
254+
):
240255
return spec.column
241-
if spec.mode == "contains_all" and all(needle in header for needle in spec.needles):
256+
if spec.mode == "contains_all" and all(
257+
needle in header for needle in spec.needles
258+
):
242259
return spec.column
243260

244-
raise ValueError(f"Could not find a column for {spec.key} in sheet {spec.sheet}")
261+
raise ValueError(
262+
f"Could not find a column for {spec.key} in sheet {spec.sheet}"
263+
)
245264

246265

247-
def extract_annual_series_from_xlsx(xlsx_bytes: bytes) -> dict[str, dict[int, float]]:
266+
def extract_annual_series_from_xlsx(
267+
xlsx_bytes: bytes,
268+
) -> dict[str, dict[int, float]]:
248269
rows_by_sheet = {
249270
sheet: read_sheet_rows(xlsx_bytes, sheet)
250271
for sheet in {spec.sheet for spec in SERIES_SPECS}
@@ -307,10 +328,14 @@ def replace_year_value(section: str, year: int, value: float) -> str:
307328

308329

309330
def replace_first_reference(section: str, title: str, href: str) -> str:
310-
title_pattern = re.compile(r"(^ - title:\s*).*$", flags=re.MULTILINE)
331+
title_pattern = re.compile(
332+
r"(^ - title:\s*).*$", flags=re.MULTILINE
333+
)
311334
href_pattern = re.compile(r"(^ href:\s*).*$", flags=re.MULTILINE)
312335

313-
updated, title_count = title_pattern.subn(rf"\g<1>{title}", section, count=1)
336+
updated, title_count = title_pattern.subn(
337+
rf"\g<1>{title}", section, count=1
338+
)
314339
if title_count == 0:
315340
raise ValueError("Could not find reference title in section")
316341

@@ -320,7 +345,9 @@ def replace_first_reference(section: str, title: str, href: str) -> str:
320345
return updated
321346

322347

323-
def replace_series_section(content: str, series_key: str, updated_section: str) -> str:
348+
def replace_series_section(
349+
content: str, series_key: str, updated_section: str
350+
) -> str:
324351
pattern = re.compile(
325352
rf"(^ {series_key}:\n.*?)(?=^ [a-z_]+:|\Z)",
326353
flags=re.MULTILINE | re.DOTALL,
@@ -356,7 +383,9 @@ def update_yoy_growth_yaml(
356383

357384
available_years = [
358385
target_year
359-
for target_year in range(forecast_start_year, forecast_end_year + 1)
386+
for target_year in range(
387+
forecast_start_year, forecast_end_year + 1
388+
)
360389
if target_year in series_values[spec.key]
361390
]
362391
if not available_years:
@@ -458,12 +487,14 @@ def main(argv: list[str] | None = None) -> int:
458487
month = args.release_month.capitalize()
459488
year = args.release_year
460489
elif args.release_month or args.release_year:
461-
raise ValueError("Pass both --release-month and --release-year together")
490+
raise ValueError(
491+
"Pass both --release-month and --release-year together"
492+
)
462493
else:
463494
month, year = infer_release(f"{source_name} {workbook_name}")
464495

465-
forecast_start_year = args.forecast_start_year or infer_forecast_start_year(
466-
month, year
496+
forecast_start_year = (
497+
args.forecast_start_year or infer_forecast_start_year(month, year)
467498
)
468499
print_summary(series_values, forecast_start_year, args.forecast_years)
469500

0 commit comments

Comments
 (0)