1+ from contextlib import nullcontext
2+
13import torch
2- from policyengine_uk import Microsimulation
34import pandas as pd
45import numpy as np
56import h5py
@@ -39,59 +40,74 @@ def calibrate_local_areas(
3940 verbose: Whether to print progress
4041 area_name: Name of the area type for logging
4142 """
42- dataset = dataset .copy ()
43- matrix , y , r = matrix_fn (dataset )
43+ progress_tracker = ProcessingProgress () if verbose else None
44+
45+ def track_stage (stage_name : str ):
46+ if progress_tracker is None :
47+ return nullcontext ()
48+ return progress_tracker .track_stage (stage_name )
49+
50+ with track_stage (f"{ area_name } : copy dataset" ):
51+ dataset = dataset .copy ()
52+
53+ with track_stage (f"{ area_name } : build local target matrix" ):
54+ matrix , y , r = matrix_fn (dataset )
4455 m_c , y_c = matrix .copy (), y .copy ()
45- m_national , y_national = national_matrix_fn (dataset )
56+
57+ with track_stage (f"{ area_name } : build national target matrix" ):
58+ m_national , y_national = national_matrix_fn (dataset )
4659 m_n , y_n = m_national .copy (), y_national .copy ()
4760
48- # Weights - area_count x num_households
49- # Use country-aware initialization: divide each household's weight by the
50- # number of areas in its country, not the total area count. This ensures
51- # households start at approximately correct weight for their country's targets.
52- # The country_mask r[i,j]=1 iff household j is in same country as area i.
53- areas_per_household = r .sum (
54- axis = 0
55- ) # number of areas each household can contribute to
56- areas_per_household = np .maximum (areas_per_household , 1 ) # avoid division by zero
57- original_weights = np .log (
58- dataset .household .household_weight .values / areas_per_household
59- + np .random .random (len (dataset .household .household_weight .values )) * 0.01
60- )
61- weights = torch .tensor (
62- np .ones ((area_count , len (original_weights ))) * original_weights ,
63- dtype = torch .float32 ,
64- requires_grad = True ,
65- )
66-
67- # Set up validation targets if specified
68- validation_targets_local = (
69- matrix .columns .isin (excluded_training_targets )
70- if hasattr (matrix , "columns" )
71- else None
72- )
73- validation_targets_national = (
74- m_national .columns .isin (excluded_training_targets )
75- if hasattr (m_national , "columns" )
76- else None
77- )
78- dropout_targets = len (excluded_training_targets ) > 0
79-
80- # Convert to tensors
81- metrics = torch .tensor (
82- matrix .values if hasattr (matrix , "values" ) else matrix ,
83- dtype = torch .float32 ,
84- )
85- y = torch .tensor (y .values if hasattr (y , "values" ) else y , dtype = torch .float32 )
86- matrix_national = torch .tensor (
87- m_national .values if hasattr (m_national , "values" ) else m_national ,
88- dtype = torch .float32 ,
89- )
90- y_national = torch .tensor (
91- y_national .values if hasattr (y_national , "values" ) else y_national ,
92- dtype = torch .float32 ,
93- )
94- r = torch .tensor (r , dtype = torch .float32 )
61+ with track_stage (f"{ area_name } : prepare tensors and optimizer" ):
62+ # Weights - area_count x num_households
63+ # Use country-aware initialization: divide each household's weight by the
64+ # number of areas in its country, not the total area count. This ensures
65+ # households start at approximately correct weight for their country's targets.
66+ # The country_mask r[i,j]=1 iff household j is in same country as area i.
67+ areas_per_household = r .sum (
68+ axis = 0
69+ ) # number of areas each household can contribute to
70+ areas_per_household = np .maximum (
71+ areas_per_household , 1
72+ ) # avoid division by zero
73+ original_weights = np .log (
74+ dataset .household .household_weight .values / areas_per_household
75+ + np .random .random (len (dataset .household .household_weight .values )) * 0.01
76+ )
77+ weights = torch .tensor (
78+ np .ones ((area_count , len (original_weights ))) * original_weights ,
79+ dtype = torch .float32 ,
80+ requires_grad = True ,
81+ )
82+
83+ # Set up validation targets if specified
84+ validation_targets_local = (
85+ matrix .columns .isin (excluded_training_targets )
86+ if hasattr (matrix , "columns" )
87+ else None
88+ )
89+ validation_targets_national = (
90+ m_national .columns .isin (excluded_training_targets )
91+ if hasattr (m_national , "columns" )
92+ else None
93+ )
94+ dropout_targets = len (excluded_training_targets ) > 0
95+
96+ # Convert to tensors
97+ metrics = torch .tensor (
98+ matrix .values if hasattr (matrix , "values" ) else matrix ,
99+ dtype = torch .float32 ,
100+ )
101+ y = torch .tensor (y .values if hasattr (y , "values" ) else y , dtype = torch .float32 )
102+ matrix_national = torch .tensor (
103+ m_national .values if hasattr (m_national , "values" ) else m_national ,
104+ dtype = torch .float32 ,
105+ )
106+ y_national = torch .tensor (
107+ y_national .values if hasattr (y_national , "values" ) else y_national ,
108+ dtype = torch .float32 ,
109+ )
110+ r = torch .tensor (r , dtype = torch .float32 )
95111
96112 def sre (x , y ):
97113 one_way = ((1 + x ) / (1 + y ) - 1 ) ** 2
@@ -160,8 +176,6 @@ def dropout_weights(weights, p):
160176 final_weights = (torch .exp (weights ) * r ).detach ().numpy ()
161177 performance = pd .DataFrame ()
162178
163- progress_tracker = ProcessingProgress () if verbose else None
164-
165179 if verbose and progress_tracker :
166180 with progress_tracker .track_calibration (
167181 epochs , nested_progress
@@ -171,8 +185,8 @@ def dropout_weights(weights, p):
171185
172186 optimizer .zero_grad ()
173187 weights_ = torch .exp (dropout_weights (weights , 0.05 )) * r
174- l = loss (weights_ )
175- l .backward ()
188+ loss_value = loss (weights_ )
189+ loss_value .backward ()
176190 optimizer .step ()
177191
178192 local_close = pct_close (weights_ , local = True , national = False )
@@ -187,7 +201,9 @@ def dropout_weights(weights, p):
187201 )
188202 else :
189203 update_calibration (
190- epoch + 1 , loss_value = l .item (), calculating_loss = False
204+ epoch + 1 ,
205+ loss_value = loss_value .item (),
206+ calculating_loss = False ,
191207 )
192208
193209 if epoch % 10 == 0 :
@@ -225,8 +241,8 @@ def dropout_weights(weights, p):
225241 for epoch in range (epochs ):
226242 optimizer .zero_grad ()
227243 weights_ = torch .exp (dropout_weights (weights , 0.05 )) * r
228- l = loss (weights_ )
229- l .backward ()
244+ loss_value = loss (weights_ )
245+ loss_value .backward ()
230246 optimizer .step ()
231247
232248 local_close = pct_close (weights_ , local = True , national = False )
@@ -236,12 +252,12 @@ def dropout_weights(weights, p):
236252 if dropout_targets :
237253 validation_loss = loss (weights_ , validation = True )
238254 print (
239- f"Training loss: { l .item ():,.3f} , Validation loss: { validation_loss .item ():,.3f} , Epoch: { epoch } , "
255+ f"Training loss: { loss_value .item ():,.3f} , Validation loss: { validation_loss .item ():,.3f} , Epoch: { epoch } , "
240256 f"{ area_name } <10%: { local_close :.1%} , National<10%: { national_close :.1%} "
241257 )
242258 else :
243259 print (
244- f"Loss: { l .item ()} , Epoch: { epoch } , { area_name } <10%: { local_close :.1%} , National<10%: { national_close :.1%} "
260+ f"Loss: { loss_value .item ()} , Epoch: { epoch } , { area_name } <10%: { local_close :.1%} , National<10%: { national_close :.1%} "
245261 )
246262
247263 if epoch % 10 == 0 :
0 commit comments