88
99import pandas as pd
1010import numpy as np
11+ import os
1112from policyengine_uk_data .storage import STORAGE_FOLDER
1213from policyengine_uk .data import UKSingleYearDataset
1314from policyengine_uk import Microsimulation
15+ from policyengine_uk_data .datasets .spi import (
16+ AGE_RANGES ,
17+ REGION_MAP ,
18+ SPI_RELEASE_NAME ,
19+ SPI_TAB_FILENAME ,
20+ )
1421from policyengine_uk_data .utils .stack import stack_datasets
1522from policyengine_uk_data .utils .subsample import subsample_dataset
1623
17- SPI_TAB_FOLDER = STORAGE_FOLDER / "spi_2020_21"
24+ SPI_TAB_FOLDER = STORAGE_FOLDER / SPI_RELEASE_NAME
1825SPI_RENAMES = dict (
1926 private_pension_income = "PENSION" ,
2027 self_employment_income = "PROFITS" ,
3744)
3845
3946
40- def generate_spi_table (spi : pd .DataFrame ):
def _spi_age_bounds(age_code) -> tuple[int, int]:
    """
    Look up the age bounds for an SPI age-range code.

    The code is coerced with ``int`` and used to index ``AGE_RANGES``. When
    the coercion fails (``TypeError`` / ``ValueError``) or the lookup raises
    ``KeyError``, the ``AGE_RANGES[-1]`` entry is used instead.

    Args:
        age_code: Raw age-range code taken from the SPI data.

    Returns:
        The ``AGE_RANGES`` entry for the code, or the ``AGE_RANGES[-1]``
        fallback entry.
    """
    try:
        bounds = AGE_RANGES[int(age_code)]
    except (TypeError, ValueError, KeyError):
        # Unparseable or unrecognised code: use the designated fallback entry.
        bounds = AGE_RANGES[-1]
    return bounds
52+
53+
54+ def generate_spi_table (
55+ spi : pd .DataFrame ,
56+ seed : int = 0 ,
57+ sample_size : int | None = 100_000 ,
58+ ):
4159 """
4260 Clean and transform SPI data for income imputation model training.
4361
@@ -47,29 +65,12 @@ def generate_spi_table(spi: pd.DataFrame):
4765 Returns:
4866 Cleaned DataFrame with age and region mappings applied.
4967 """
50- LOWER = np .array ([0 , 16 , 25 , 35 , 45 , 55 , 65 , 75 ])
51- UPPER = np .array ([16 , 25 , 35 , 45 , 55 , 65 , 75 , 80 ])
68+ rng = np .random .default_rng (seed )
5269 age_range = spi .AGERANGE
53- spi ["age" ] = LOWER [age_range ] + np .random .rand (len (spi )) * (
54- UPPER [age_range ] - LOWER [age_range ]
55- )
70+ bounds = np .array ([_spi_age_bounds (age ) for age in age_range ])
71+ spi ["age" ] = bounds [:, 0 ] + rng .random (len (spi )) * (bounds [:, 1 ] - bounds [:, 0 ])
5672
57- REGIONS = {
58- 1 : "NORTH_EAST" ,
59- 2 : "NORTH_WEST" ,
60- 3 : "YORKSHIRE" ,
61- 4 : "EAST_MIDLANDS" ,
62- 5 : "WEST_MIDLANDS" ,
63- 6 : "EAST_OF_ENGLAND" ,
64- 7 : "LONDON" ,
65- 8 : "SOUTH_EAST" ,
66- 9 : "SOUTH_WEST" ,
67- 10 : "WALES" ,
68- 11 : "SCOTLAND" ,
69- 12 : "NORTHERN_IRELAND" ,
70- }
71-
72- spi ["region" ] = np .array ([REGIONS .get (x , "LONDON" ) for x in spi .GORCODE ])
73+ spi ["region" ] = spi .GORCODE .map (REGION_MAP ).fillna ("UNKNOWN" )
7374
7475 spi ["gender" ] = np .where (spi .SEX == 1 , "MALE" , "FEMALE" )
7576
@@ -78,11 +79,17 @@ def generate_spi_table(spi: pd.DataFrame):
7879
7980 spi ["employment_income" ] = spi [["PAY" , "EPB" , "TAXTERM" ]].sum (axis = 1 )
8081
81- spi = pd .concat (
82- [
83- spi .sample (100_000 , weights = spi .person_weight , replace = True ),
84- ]
85- )
82+ if sample_size is not None :
83+ spi = pd .concat (
84+ [
85+ spi .sample (
86+ sample_size ,
87+ weights = spi .person_weight ,
88+ replace = True ,
89+ random_state = seed ,
90+ ),
91+ ]
92+ )
8693
8794 return spi
8895
@@ -119,7 +126,35 @@ def generate_spi_table(spi: pd.DataFrame):
119126IMPUTATIONS = INCOME_COMPONENTS + ["gift_aid" , "charitable_investment_gifts" ]
120127
121128
122- INCOME_MODEL_PATH = STORAGE_FOLDER / "income.pkl"
# Static description of what the cached income model is trained on: the SPI
# release name, the SPI tab filename and the imputed output variables. The
# variables are stored as a tuple of the IMPUTATIONS list.
129+ INCOME_MODEL_METADATA = {
130+ "spi_release_name" : SPI_RELEASE_NAME ,
131+ "spi_tab_filename" : SPI_TAB_FILENAME ,
132+ "imputations" : tuple (IMPUTATIONS ),
133+ }
# The cache file name embeds the SPI release name.
134+ INCOME_MODEL_PATH = STORAGE_FOLDER / f"income_{ SPI_RELEASE_NAME } .pkl"
# Training sample sizes: the default, and the smaller one selected by
# get_income_model_sample_size() when the TESTING environment variable is "1".
135+ INCOME_MODEL_SAMPLE_SIZE = 100_000
136+ TESTING_INCOME_MODEL_SAMPLE_SIZE = 10_000
137+
138+
def get_income_model_sample_size() -> int:
    """
    Choose the training sample size for the income model.

    Returns:
        ``TESTING_INCOME_MODEL_SAMPLE_SIZE`` when the ``TESTING`` environment
        variable equals ``"1"``; otherwise ``INCOME_MODEL_SAMPLE_SIZE``. An
        unset variable is treated as ``"0"``.
    """
    in_testing_mode = os.environ.get("TESTING", "0") == "1"
    return (
        TESTING_INCOME_MODEL_SAMPLE_SIZE
        if in_testing_mode
        else INCOME_MODEL_SAMPLE_SIZE
    )
143+
144+
def get_income_model_metadata() -> dict:
    """
    Build the metadata describing the current income-model configuration.

    Returns:
        A new dict holding every entry of ``INCOME_MODEL_METADATA`` plus a
        ``"sample_size"`` entry set to ``get_income_model_sample_size()``.
        The module-level ``INCOME_MODEL_METADATA`` dict is not modified.
    """
    metadata = dict(INCOME_MODEL_METADATA)
    # Resolved at call time so it follows the current TESTING setting.
    metadata["sample_size"] = get_income_model_sample_size()
    return metadata
150+
151+
def _income_model_matches_current_release(model) -> bool:
    """
    Check whether a cached income model is still valid for the current setup.

    Args:
        model: Loaded model object. Its ``metadata`` attribute (an empty dict
            if absent) and the ``imputed_variables`` attribute of its
            ``model`` attribute (an empty list if absent) are inspected.

    Returns:
        True only if the stored metadata equals ``get_income_model_metadata()``
        and the set of imputed variables equals the set of ``IMPUTATIONS``.
    """
    stored_metadata = getattr(model, "metadata", {})
    if stored_metadata == get_income_model_metadata():
        # Only inspect the trained outputs once the metadata check has passed.
        trained_outputs = getattr(model.model, "imputed_variables", [])
        return set(trained_outputs) == set(IMPUTATIONS)
    return False
123158
124159
125160def save_imputation_models ():
@@ -132,8 +167,9 @@ def save_imputation_models():
132167 from policyengine_uk_data .utils import QRF
133168
134169 income = QRF ()
135- spi = pd .read_csv (SPI_TAB_FOLDER / "put2021uk.tab" , delimiter = "\t " )
136- spi = generate_spi_table (spi )
170+ income .metadata = get_income_model_metadata ()
171+ spi = pd .read_csv (SPI_TAB_FOLDER / SPI_TAB_FILENAME , delimiter = "\t " )
172+ spi = generate_spi_table (spi , sample_size = get_income_model_sample_size ())
137173 spi = spi [PREDICTORS + IMPUTATIONS ]
138174 income .fit (spi [PREDICTORS ], spi [IMPUTATIONS ])
139175 income .save (INCOME_MODEL_PATH )
@@ -144,10 +180,9 @@ def create_income_model(overwrite_existing: bool = False):
144180 """
145181 Create or load income imputation model.
146182
147- If a cached model exists and its trained output columns don't match the
148- current ``IMPUTATIONS`` list, the cache is discarded and the model is
149- retrained. This handles the case where ``IMPUTATIONS`` is extended in
150- code but an older pickle is still on disk.
183+ If a cached model exists and its training metadata or output columns don't
184+ match the current SPI release and ``IMPUTATIONS`` list, the cache is
185+ discarded and the model is retrained.
151186
152187 Args:
153188 overwrite_existing: Whether to retrain model if it exists.
@@ -159,10 +194,9 @@ def create_income_model(overwrite_existing: bool = False):
159194
160195 if INCOME_MODEL_PATH .exists () and not overwrite_existing :
161196 cached = QRF (file_path = INCOME_MODEL_PATH )
162- cached_outputs = set (getattr (cached .model , "imputed_variables" , []))
163- if cached_outputs == set (IMPUTATIONS ):
197+ if _income_model_matches_current_release (cached ):
164198 return cached
165- # Cached model was trained against a different output set; retrain .
199+ # Cached model was trained against a different SPI release or output set.
166200 return save_imputation_models ()
167201
168202
0 commit comments