@@ -125,7 +125,20 @@ def generate_spi_table(
125125IMPUTATIONS = INCOME_COMPONENTS + ["gift_aid" , "charitable_investment_gifts" ]
126126
127127
128- INCOME_MODEL_PATH = STORAGE_FOLDER / "income.pkl"
128+ INCOME_MODEL_METADATA = {
129+ "spi_release_name" : SPI_RELEASE_NAME ,
130+ "spi_tab_filename" : SPI_TAB_FILENAME ,
131+ "imputations" : tuple (IMPUTATIONS ),
132+ }
133+ INCOME_MODEL_PATH = STORAGE_FOLDER / f"income_{ SPI_RELEASE_NAME } .pkl"
134+
135+
136+ def _income_model_matches_current_release (model ) -> bool :
137+ if getattr (model , "metadata" , {}) != INCOME_MODEL_METADATA :
138+ return False
139+
140+ cached_outputs = set (getattr (model .model , "imputed_variables" , []))
141+ return cached_outputs == set (IMPUTATIONS )
129142
130143
131144def save_imputation_models ():
@@ -138,6 +151,7 @@ def save_imputation_models():
138151 from policyengine_uk_data .utils import QRF
139152
140153 income = QRF ()
154+ income .metadata = INCOME_MODEL_METADATA
141155 spi = pd .read_csv (SPI_TAB_FOLDER / SPI_TAB_FILENAME , delimiter = "\t " )
142156 spi = generate_spi_table (spi )
143157 spi = spi [PREDICTORS + IMPUTATIONS ]
@@ -150,10 +164,9 @@ def create_income_model(overwrite_existing: bool = False):
150164 """
151165 Create or load income imputation model.
152166
153- If a cached model exists and its trained output columns don't match the
154- current ``IMPUTATIONS`` list, the cache is discarded and the model is
155- retrained. This handles the case where ``IMPUTATIONS`` is extended in
156- code but an older pickle is still on disk.
167+ If a cached model exists and its training metadata or output columns don't
168+ match the current SPI release and ``IMPUTATIONS`` list, the cache is
169+ discarded and the model is retrained.
157170
158171 Args:
159172 overwrite_existing: Whether to retrain model if it exists.
@@ -165,10 +178,9 @@ def create_income_model(overwrite_existing: bool = False):
165178
166179 if INCOME_MODEL_PATH .exists () and not overwrite_existing :
167180 cached = QRF (file_path = INCOME_MODEL_PATH )
168- cached_outputs = set (getattr (cached .model , "imputed_variables" , []))
169- if cached_outputs == set (IMPUTATIONS ):
181+ if _income_model_matches_current_release (cached ):
170182 return cached
171- # Cached model was trained against a different output set; retrain .
183+ # Cached model was trained against a different SPI release or output set.
172184 return save_imputation_models ()
173185
174186
0 commit comments