Skip to content

Commit 57522f0

Browse files
committed
Format OBR forecast importer
1 parent e08bc72 commit 57522f0

2 files changed

Lines changed: 21 additions & 59 deletions

File tree

policyengine_uk/tests/test_import_obr_forecasts.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ def make_test_xlsx() -> bytes:
5252

5353
sheet_16 = make_sheet(
5454
{
55-
3: [
56-
make_inline_cell(
57-
"Q3", "Average weekly earnings growth (per cent)"
58-
)
59-
],
55+
3: [make_inline_cell("Q3", "Average weekly earnings growth (per cent)")],
6056
97: [
6157
make_inline_cell("B97", "2025"),
6258
make_number_cell("Q97", 5.17),
@@ -142,9 +138,7 @@ def test_extract_annual_series_from_xlsx():
142138

143139

144140
def test_release_inference_helpers():
145-
assert infer_release(
146-
"Economy_Detailed_forecast_tables_November_2025.xlsx"
147-
) == (
141+
assert infer_release("Economy_Detailed_forecast_tables_November_2025.xlsx") == (
148142
"November",
149143
2025,
150144
)
@@ -246,12 +240,9 @@ def test_update_yoy_growth_yaml_updates_forecast_window_only(tmp_path):
246240
assert "2025-01-01: 0.0280" in content
247241
assert "2026-01-01: 0.0240" in content
248242
assert (
249-
"OBR EFO March 2026 (detailed forecast tables, economy, Table 1.16)"
250-
in content
251-
)
252-
assert (
253-
"https://obr.uk/efo/economic-and-fiscal-outlook-march-2026/" in content
243+
"OBR EFO March 2026 (detailed forecast tables, economy, Table 1.16)" in content
254244
)
245+
assert "https://obr.uk/efo/economic-and-fiscal-outlook-march-2026/" in content
255246

256247

257248
def test_update_yoy_growth_yaml_keeps_existing_values_when_obr_has_blank_years(

policyengine_uk/utils/import_obr_forecasts.py

Lines changed: 17 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,7 @@ def read_url_bytes(url: str) -> bytes:
126126
return response.read()
127127

128128

129-
def load_source_bytes(
130-
url: str | None, file_path: str | None
131-
) -> tuple[str, bytes]:
129+
def load_source_bytes(url: str | None, file_path: str | None) -> tuple[str, bytes]:
132130
if bool(url) == bool(file_path):
133131
raise ValueError("Pass exactly one of --url or --file")
134132

@@ -157,13 +155,10 @@ def extract_economy_workbook_bytes(
157155
candidates = [
158156
name
159157
for name in archive.namelist()
160-
if name.lower().endswith(".xlsx")
161-
and "economy" in Path(name).name.lower()
158+
if name.lower().endswith(".xlsx") and "economy" in Path(name).name.lower()
162159
]
163160
if not candidates:
164-
raise ValueError(
165-
"Could not find an economy workbook in the source"
166-
)
161+
raise ValueError("Could not find an economy workbook in the source")
167162
candidates.sort()
168163
workbook_name = candidates[0]
169164
return Path(workbook_name).name, archive.read(workbook_name)
@@ -176,9 +171,7 @@ def _load_shared_strings(archive: ZipFile) -> list[str]:
176171
root = ET.fromstring(archive.read("xl/sharedStrings.xml"))
177172
values: list[str] = []
178173
for item in root.findall("main:si", WORKBOOK_NS):
179-
parts = [
180-
node.text or "" for node in item.iterfind(".//main:t", WORKBOOK_NS)
181-
]
174+
parts = [node.text or "" for node in item.iterfind(".//main:t", WORKBOOK_NS)]
182175
values.append("".join(parts))
183176
return values
184177

@@ -215,9 +208,7 @@ def _cell_value(cell: ET.Element, shared_strings: list[str]) -> str | None:
215208
return None
216209

217210

218-
def read_sheet_rows(
219-
xlsx_bytes: bytes, sheet_name: str
220-
) -> list[dict[str, str | None]]:
211+
def read_sheet_rows(xlsx_bytes: bytes, sheet_name: str) -> list[dict[str, str | None]]:
221212
with ZipFile(BytesIO(xlsx_bytes)) as archive:
222213
shared_strings = _load_shared_strings(archive)
223214
sheet_path = _sheet_paths(archive)[sheet_name]
@@ -234,33 +225,23 @@ def read_sheet_rows(
234225
return rows
235226

236227

237-
def find_series_column(
238-
rows: list[dict[str, str | None]], spec: SeriesSpec
239-
) -> str:
228+
def find_series_column(rows: list[dict[str, str | None]], spec: SeriesSpec) -> str:
240229
headers: dict[str, str] = {}
241230
for row in rows[:4]:
242231
for column, value in row.items():
243232
label = normalise_label(value)
244-
if label and (
245-
column not in headers or is_generic_header(headers[column])
246-
):
233+
if label and (column not in headers or is_generic_header(headers[column])):
247234
headers[column] = label
248235

249236
header = headers.get(spec.column, "")
250237
if spec.mode == "exact" and header in spec.needles:
251238
return spec.column
252-
if spec.mode == "contains" and any(
253-
needle in header for needle in spec.needles
254-
):
239+
if spec.mode == "contains" and any(needle in header for needle in spec.needles):
255240
return spec.column
256-
if spec.mode == "contains_all" and all(
257-
needle in header for needle in spec.needles
258-
):
241+
if spec.mode == "contains_all" and all(needle in header for needle in spec.needles):
259242
return spec.column
260243

261-
raise ValueError(
262-
f"Could not find a column for {spec.key} in sheet {spec.sheet}"
263-
)
244+
raise ValueError(f"Could not find a column for {spec.key} in sheet {spec.sheet}")
264245

265246

266247
def extract_annual_series_from_xlsx(
@@ -328,14 +309,10 @@ def replace_year_value(section: str, year: int, value: float) -> str:
328309

329310

330311
def replace_first_reference(section: str, title: str, href: str) -> str:
331-
title_pattern = re.compile(
332-
r"(^ - title:\s*).*$", flags=re.MULTILINE
333-
)
312+
title_pattern = re.compile(r"(^ - title:\s*).*$", flags=re.MULTILINE)
334313
href_pattern = re.compile(r"(^ href:\s*).*$", flags=re.MULTILINE)
335314

336-
updated, title_count = title_pattern.subn(
337-
rf"\g<1>{title}", section, count=1
338-
)
315+
updated, title_count = title_pattern.subn(rf"\g<1>{title}", section, count=1)
339316
if title_count == 0:
340317
raise ValueError("Could not find reference title in section")
341318

@@ -345,9 +322,7 @@ def replace_first_reference(section: str, title: str, href: str) -> str:
345322
return updated
346323

347324

348-
def replace_series_section(
349-
content: str, series_key: str, updated_section: str
350-
) -> str:
325+
def replace_series_section(content: str, series_key: str, updated_section: str) -> str:
351326
pattern = re.compile(
352327
rf"(^ {series_key}:\n.*?)(?=^ [a-z_]+:|\Z)",
353328
flags=re.MULTILINE | re.DOTALL,
@@ -383,9 +358,7 @@ def update_yoy_growth_yaml(
383358

384359
available_years = [
385360
target_year
386-
for target_year in range(
387-
forecast_start_year, forecast_end_year + 1
388-
)
361+
for target_year in range(forecast_start_year, forecast_end_year + 1)
389362
if target_year in series_values[spec.key]
390363
]
391364
if not available_years:
@@ -487,14 +460,12 @@ def main(argv: list[str] | None = None) -> int:
487460
month = args.release_month.capitalize()
488461
year = args.release_year
489462
elif args.release_month or args.release_year:
490-
raise ValueError(
491-
"Pass both --release-month and --release-year together"
492-
)
463+
raise ValueError("Pass both --release-month and --release-year together")
493464
else:
494465
month, year = infer_release(f"{source_name} {workbook_name}")
495466

496-
forecast_start_year = (
497-
args.forecast_start_year or infer_forecast_start_year(month, year)
467+
forecast_start_year = args.forecast_start_year or infer_forecast_start_year(
468+
month, year
498469
)
499470
print_summary(series_values, forecast_start_year, args.forecast_years)
500471

0 commit comments

Comments
 (0)