Skip to content

Commit d70f24a

Browse files
committed
Add penalization term to calibration loss for alpha/beta uncertainty (#419)
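When alpha_uncertainty / beta_uncertainty are provided in the experiment config, the calibration training loss now includes penalty terms that constrain the inferred alpha and beta to stay near their guess values. Each squared deviation from the guess is divided by the square of the corresponding uncertainty, so a smaller uncertainty gives a stronger constraint. Variables without a configured uncertainty default to an infinite uncertainty and therefore contribute no penalty. Made-with: Cursor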
1 parent c227338 commit d70f24a

2 files changed

Lines changed: 161 additions & 10 deletions

File tree

ml/Neural_Net_Classes.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,62 @@ def train_model(
156156
break
157157

158158

159+
def _calibration_penalty(
    c_normcal,
    o_normcal,
    c_guess,
    o_guess,
    c_norm,
    o_norm,
    alpha_uncertainty,
    beta_uncertainty,
):
    """Compute penalty that keeps inferred alpha/beta near their guess values.

    The inferred calibration is obtained by composing the guess calibration,
    normalization, and learned normalized calibration (see build_inferred_calibration).
    From those compositions:
        alpha_inferred = 1 / (c_guess * c_normcal)
        beta_inferred  = o_guess + c_guess*o_norm + c_guess*c_norm*o_normcal
                         - c_guess*c_normcal*o_norm

    The penalty is  sum(((alpha_I - alpha_G) / alpha_U)^2)
                  + sum(((beta_I  - beta_G ) / beta_U )^2)
    where alpha_G = 1/c_guess and beta_G = o_guess.

    Each deviation is divided by its uncertainty *before* squaring. This is
    mathematically the same as (deviation^2 / uncertainty^2), but it avoids
    squaring the numerator and denominator separately: in float32 those
    squares overflow to inf (or underflow to 0) for quantities far from unit
    scale, and the resulting inf/inf or 0/0 turns the whole loss into NaN.

    Dimensions with infinite uncertainty contribute zero penalty (and zero
    gradient), because a finite deviation divided by inf is exactly 0.
    An uncertainty of exactly zero is not special-cased: it yields an inf
    penalty for a non-zero deviation and NaN for a zero deviation.

    Args:
        c_normcal: learned normalized-calibration coefficients
        o_normcal: learned normalized-calibration offsets
        c_guess: guess calibration coefficients (alpha_guess = 1 / c_guess)
        o_guess: guess calibration offsets (beta_guess = o_guess)
        c_norm: normalization coefficients
        o_norm: normalization offsets
        alpha_uncertainty: uncertainty on alpha (inf = no penalty)
        beta_uncertainty: uncertainty on beta (inf = no penalty)

        All arguments are combined element-wise, so they must be tensors of
        mutually broadcastable shape (presumably one entry per variable --
        confirm against the caller).

    Returns:
        Zero-dimensional tensor: the sum of the alpha and beta penalties.
    """
    c_inferred = c_guess * c_normcal
    alpha_inferred = 1.0 / c_inferred
    alpha_guess = 1.0 / c_guess

    beta_inferred = (
        o_guess + c_guess * o_norm + c_guess * c_norm * o_normcal - c_inferred * o_norm
    )
    beta_guess = o_guess

    # Normalize first, square second: keeps intermediates O(1) so float32
    # cannot overflow/underflow into NaN for very large or very small units.
    penalty_alpha = torch.sum(((alpha_inferred - alpha_guess) / alpha_uncertainty) ** 2)
    penalty_beta = torch.sum(((beta_inferred - beta_guess) / beta_uncertainty) ** 2)
    return penalty_alpha + penalty_beta
197+
198+
159199
def train_calibration(
160200
model,
161201
exp_inputs,
162202
exp_targets,
203+
c_guess_input,
204+
o_guess_input,
205+
c_norm_input,
206+
o_norm_input,
207+
alpha_uncertainty_input,
208+
beta_uncertainty_input,
209+
c_guess_output,
210+
o_guess_output,
211+
c_norm_output,
212+
o_norm_output,
213+
alpha_uncertainty_output,
214+
beta_uncertainty_output,
163215
num_epochs=5000,
164216
lr=0.001,
165217
):
@@ -174,10 +226,26 @@ def train_calibration(
174226
calibrated_input = (1 / c_normcal_input) * (x - o_normcal_input)
175227
calibrated_output = c_normcal_output * model(calibrated_input) + o_normcal_output
176228
229+
A penalization term is added to the loss to keep the inferred alpha/beta
230+
close to their guess values (see _calibration_penalty). Dimensions with
231+
infinite uncertainty contribute zero penalty.
232+
177233
Args:
178234
model: frozen callable that maps exp_inputs -> predictions
179235
exp_inputs: experimental input tensor
180236
exp_targets: experimental target values (may contain NaN)
237+
c_guess_input: guess calibration coefficients for inputs
238+
o_guess_input: guess calibration offsets for inputs
239+
c_norm_input: normalization coefficients for inputs
240+
o_norm_input: normalization offsets for inputs
241+
alpha_uncertainty_input: uncertainty on alpha for inputs (inf = no penalty)
242+
beta_uncertainty_input: uncertainty on beta for inputs (inf = no penalty)
243+
c_guess_output: guess calibration coefficients for outputs
244+
o_guess_output: guess calibration offsets for outputs
245+
c_norm_output: normalization coefficients for outputs
246+
o_norm_output: normalization offsets for outputs
247+
alpha_uncertainty_output: uncertainty on alpha for outputs (inf = no penalty)
248+
beta_uncertainty_output: uncertainty on beta for outputs (inf = no penalty)
181249
num_epochs: number of training epochs
182250
lr: learning rate
183251
@@ -217,7 +285,30 @@ def train_calibration(
217285
base_predictions = model(calibrated_inputs)
218286
calibrated_outputs = c_normcal_output * base_predictions + o_normcal_output
219287

220-
loss = nan_mse_loss(exp_targets, calibrated_outputs)
288+
loss = (
289+
nan_mse_loss(exp_targets, calibrated_outputs)
290+
+ _calibration_penalty(
291+
c_normcal_input,
292+
o_normcal_input,
293+
c_guess_input,
294+
o_guess_input,
295+
c_norm_input,
296+
o_norm_input,
297+
alpha_uncertainty_input,
298+
beta_uncertainty_input,
299+
)
300+
+ _calibration_penalty(
301+
c_normcal_output,
302+
o_normcal_output,
303+
c_guess_output,
304+
o_guess_output,
305+
c_norm_output,
306+
o_norm_output,
307+
alpha_uncertainty_output,
308+
beta_uncertainty_output,
309+
)
310+
)
311+
221312
loss.backward()
222313
optimizer.step()
223314

ml/train_model.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -200,31 +200,47 @@ def build_guess_calibration(config_dict, input_variables, output_variables):
200200

201201
def _get_calibration(exp_name):
202202
if exp_name in depends_on_lookup:
203-
# Experimental variables is part of the "simulation_calibration" section
204203
entry = depends_on_lookup[exp_name]
205-
return entry["name"], entry["alpha_guess"], entry["beta_guess"]
204+
return (
205+
entry["name"],
206+
entry["alpha_guess"],
207+
entry["beta_guess"],
208+
entry.get("alpha_uncertainty", float("inf")),
209+
entry.get("beta_uncertainty", float("inf")),
210+
)
206211
else:
207-
# Experimental variable is not part of the "simulation_calibration" section
208-
# In this case, no calibration is needed ; the simulation variable is identical
209-
return exp_name, 1.0, 0.0
212+
# No calibration needed; the simulation variable is identical
213+
return exp_name, 1.0, 0.0, float("inf"), float("inf")
210214

211215
# Build the list of simulation variables
212216
sim_input_names = []
213217
alpha_input_list = []
214218
beta_input_list = []
219+
alpha_uncertainty_input_list = []
220+
beta_uncertainty_input_list = []
215221
for key in input_variables:
216-
sim_name, alpha, beta = _get_calibration(input_variables[key]["name"])
222+
sim_name, alpha, beta, alpha_u, beta_u = _get_calibration(
223+
input_variables[key]["name"]
224+
)
217225
sim_input_names.append(sim_name)
218226
alpha_input_list.append(alpha)
219227
beta_input_list.append(beta)
228+
alpha_uncertainty_input_list.append(alpha_u)
229+
beta_uncertainty_input_list.append(beta_u)
220230
sim_output_names = []
221231
alpha_output_list = []
222232
beta_output_list = []
233+
alpha_uncertainty_output_list = []
234+
beta_uncertainty_output_list = []
223235
for key in output_variables:
224-
sim_name, alpha, beta = _get_calibration(output_variables[key]["name"])
236+
sim_name, alpha, beta, alpha_u, beta_u = _get_calibration(
237+
output_variables[key]["name"]
238+
)
225239
sim_output_names.append(sim_name)
226240
alpha_output_list.append(alpha)
227241
beta_output_list.append(beta)
242+
alpha_uncertainty_output_list.append(alpha_u)
243+
beta_uncertainty_output_list.append(beta_u)
228244

229245
# Build the AffineInputTransforms for the guess calibration
230246
alpha_inputs = torch.tensor(alpha_input_list, dtype=torch.float)
@@ -240,11 +256,22 @@ def _get_calibration(exp_name):
240256
n_outputs, coefficient=1.0 / alpha_outputs, offset=beta_outputs
241257
)
242258

259+
uncertainty_inputs = {
260+
"alpha": torch.tensor(alpha_uncertainty_input_list, dtype=torch.float),
261+
"beta": torch.tensor(beta_uncertainty_input_list, dtype=torch.float),
262+
}
263+
uncertainty_outputs = {
264+
"alpha": torch.tensor(alpha_uncertainty_output_list, dtype=torch.float),
265+
"beta": torch.tensor(beta_uncertainty_output_list, dtype=torch.float),
266+
}
267+
243268
return (
244269
input_guess_calibration,
245270
output_guess_calibration,
246271
sim_input_names,
247272
sim_output_names,
273+
uncertainty_inputs,
274+
uncertainty_outputs,
248275
)
249276

250277

@@ -307,11 +334,18 @@ def train_calibration_phase(
307334
input_names,
308335
output_names,
309336
device,
337+
input_guess_calibration,
338+
output_guess_calibration,
339+
input_normalization,
340+
output_normalization,
341+
uncertainty_inputs,
342+
uncertainty_outputs,
310343
):
311344
"""Phase 2: Train calibration layers on experimental data.
312345
313346
Passes the frozen model to train_calibration(), which re-evaluates it at
314-
each iteration.
347+
each iteration. A penalization term constrains inferred alpha/beta toward
348+
their guess values, weighted by the provided uncertainties.
315349
316350
Returns an AffineInputTransform representing the learned calibration.
317351
"""
@@ -336,7 +370,25 @@ def predict_fn(x):
336370

337371
# Train calibration
338372
c_normcal_input, o_normcal_input, c_normcal_output, o_normcal_output = (
339-
train_calibration(predict_fn, exp_X, exp_y, num_epochs=5000, lr=0.001)
373+
train_calibration(
374+
predict_fn,
375+
exp_X,
376+
exp_y,
377+
c_guess_input=input_guess_calibration.coefficient.to(device),
378+
o_guess_input=input_guess_calibration.offset.to(device),
379+
c_norm_input=input_normalization.coefficient.to(device),
380+
o_norm_input=input_normalization.offset.to(device),
381+
alpha_uncertainty_input=uncertainty_inputs["alpha"].to(device),
382+
beta_uncertainty_input=uncertainty_inputs["beta"].to(device),
383+
c_guess_output=output_guess_calibration.coefficient.to(device),
384+
o_guess_output=output_guess_calibration.offset.to(device),
385+
c_norm_output=output_normalization.coefficient.to(device),
386+
o_norm_output=output_normalization.offset.to(device),
387+
alpha_uncertainty_output=uncertainty_outputs["alpha"].to(device),
388+
beta_uncertainty_output=uncertainty_outputs["beta"].to(device),
389+
num_epochs=5000,
390+
lr=0.001,
391+
)
340392
)
341393

342394
# Build calibration transforms
@@ -564,6 +616,8 @@ def register_model_to_mlflow(model, model_type, experiment, config_dict):
564616
output_guess_calibration,
565617
sim_input_names,
566618
sim_output_names,
619+
uncertainty_inputs,
620+
uncertainty_outputs,
567621
) = build_guess_calibration(config_dict, input_variables, output_variables)
568622

569623
# Convert experimental data to simulation variable space
@@ -649,6 +703,12 @@ def register_model_to_mlflow(model, model_type, experiment, config_dict):
649703
sim_input_names,
650704
sim_output_names,
651705
device,
706+
input_guess_calibration=input_guess_calibration,
707+
output_guess_calibration=output_guess_calibration,
708+
input_normalization=input_normalization,
709+
output_normalization=output_normalization,
710+
uncertainty_inputs=uncertainty_inputs,
711+
uncertainty_outputs=uncertainty_outputs,
652712
)
653713
)
654714

0 commit comments

Comments
 (0)