Skip to content

Commit ce8e19e

Browse files
Skip some targets
1 parent 3d8d6ce commit ce8e19e

2 files changed

Lines changed: 50 additions & 33 deletions

File tree

policyengine_uk_data/datasets/frs/local_areas/constituencies/calibrate.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def calibrate(
4646
# Weights - 650 x 100180
4747
original_weights = np.log(
4848
sim.calculate("household_weight", 2025).values / COUNT_CONSTITUENCIES
49-
+ np.random.random(len(sim.calculate("household_weight", 2025).values)) * 0.1
49+
+ np.random.random(len(sim.calculate("household_weight", 2025).values))
50+
* 0.01
5051
)
5152
weights = torch.tensor(
5253
np.ones((COUNT_CONSTITUENCIES, len(original_weights)))
@@ -90,7 +91,7 @@ def loss(w, validation: bool = False):
9091
else:
9192
mse_n = torch.mean((pred_n / (1 + y_national) - 1) ** 2)
9293

93-
return mse_c# + mse_n
94+
return mse_c + mse_n
9495

9596
def pct_close(w, t=0.1, constituency=True, national=True):
9697
# Return the percentage of metrics that are within t% of the target
@@ -124,7 +125,7 @@ def dropout_weights(weights, p):
124125
masked_weights[mask] = mean
125126
return masked_weights
126127

127-
optimizer = torch.optim.Adam([weights], lr=0.15)
128+
optimizer = torch.optim.Adam([weights], lr=1e-1)
128129

129130
desc = range(128) if os.environ.get("DATA_LITE") else range(epochs)
130131
final_weights = (torch.exp(weights) * r).detach().numpy()
@@ -134,10 +135,8 @@ def dropout_weights(weights, p):
134135
optimizer.zero_grad()
135136
weights_ = torch.exp(dropout_weights(weights, 0.05)) * r
136137
l = loss(weights_)
137-
l.backward()
138-
optimizer.step()
139-
c_close = pct_close(weights_, constituency=True, national=False)
140-
n_close = pct_close(weights_, constituency=False, national=True)
138+
c_close = pct_close(weights_, constituency=True, national=False, t=0.1)
139+
n_close = pct_close(weights_, constituency=False, national=True, t=0.1)
141140
if epoch % 1 == 0:
142141
if dropout_targets:
143142
validation_loss = loss(weights_, validation=True)
@@ -182,6 +181,8 @@ def dropout_weights(weights, p):
182181
f.create_dataset(
183182
"household_weight/2025", data=final_weights.sum(axis=0)
184183
)
184+
l.backward()
185+
optimizer.step()
185186

186187
return final_weights
187188

policyengine_uk_data/datasets/frs/local_areas/constituencies/loss.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -82,39 +82,60 @@ def create_constituency_target_matrix(
8282
employment_incomes.employment_income_lower_bound.sort_values().unique()
8383
) + [np.inf]
8484

85-
employment_incomes_all = employment_incomes.groupby("code")[["employment_income_count","employment_income_amount"]].sum().reset_index()
86-
85+
employment_incomes_all = (
86+
employment_incomes.groupby("code")[
87+
["employment_income_count", "employment_income_amount"]
88+
]
89+
.sum()
90+
.reset_index()
91+
)
8792

8893
hmrc_all_count_target = incomes["employment_income_count"].values
89-
ons_all_count_target = employment_incomes_all["employment_income_count"].values
94+
ons_all_count_target = employment_incomes_all[
95+
"employment_income_count"
96+
].values
9097
count_scaling_factors = hmrc_all_count_target / ons_all_count_target
9198

9299
hmrc_all_amount_target = incomes["employment_income_amount"].values
93-
ons_all_amount_target = employment_incomes_all["employment_income_amount"].values
100+
ons_all_amount_target = employment_incomes_all[
101+
"employment_income_amount"
102+
].values
94103
amount_scaling_factors = hmrc_all_amount_target / ons_all_amount_target
95104

96-
print(f"Average count scaling factor: {count_scaling_factors.mean():.1%}")
97-
print(f"Average count (HMRC): {hmrc_all_count_target.mean()/1e3:,.0f} (thousands)")
98-
print(f"Average count (ONS): {ons_all_count_target.mean()/1e3:,.0f} (thousands)")
99-
print(f"Average amount scaling factor: {amount_scaling_factors.mean():.1%}")
100-
print(f"Average amount (HMRC): {hmrc_all_amount_target.mean()/1e6:,.0f} (millions)")
101-
print(f"Average amount (ONS): {ons_all_amount_target.mean()/1e6:,.0f} (millions)")
102-
103105
for lower_bound, upper_bound in zip(bounds[:-1], bounds[1:]):
104-
continue
105-
if lower_bound <= 12_570:
106+
if (
107+
lower_bound <= 15_000
108+
): # Skip some targets with very small sample sizes
106109
continue
107110
if upper_bound >= 200_000:
108111
continue
109-
count_target = employment_incomes[
110-
(employment_incomes.employment_income_lower_bound == lower_bound)
111-
& (employment_incomes.employment_income_upper_bound == upper_bound)
112-
].employment_income_count.values * count_scaling_factors
112+
count_target = (
113+
employment_incomes[
114+
(
115+
employment_incomes.employment_income_lower_bound
116+
== lower_bound
117+
)
118+
& (
119+
employment_incomes.employment_income_upper_bound
120+
== upper_bound
121+
)
122+
].employment_income_count.values
123+
* count_scaling_factors
124+
)
113125

114-
amount_target = employment_incomes[
115-
(employment_incomes.employment_income_lower_bound == lower_bound)
116-
& (employment_incomes.employment_income_upper_bound == upper_bound)
117-
].employment_income_amount.values * amount_scaling_factors
126+
amount_target = (
127+
employment_incomes[
128+
(
129+
employment_incomes.employment_income_lower_bound
130+
== lower_bound
131+
)
132+
& (
133+
employment_incomes.employment_income_upper_bound
134+
== upper_bound
135+
)
136+
].employment_income_amount.values
137+
* amount_scaling_factors
138+
)
118139

119140
if count_target.mean() < 200:
120141
print(
@@ -135,11 +156,6 @@ def create_constituency_target_matrix(
135156
& (age >= 16)
136157
)
137158
band_str = f"{lower_bound}_{upper_bound}"
138-
matrix[f"hmrc/employment_income/count/{band_str}"] = sim.map_result(
139-
in_bound, "person", "household"
140-
)
141-
y[f"hmrc/employment_income/count/{band_str}"] = count_target
142-
143159
matrix[f"hmrc/employment_income/amount/{band_str}"] = sim.map_result(
144160
employment_income * in_bound, "person", "household"
145161
)

0 commit comments

Comments
 (0)