From 46c8fd7495cf33a315ae083ac2068efd9b22c27b Mon Sep 17 00:00:00 2001 From: RevathiJambunathan Date: Thu, 18 Jun 2026 09:42:33 -0700 Subject: [PATCH 1/2] de-duplicate training --- ml/Neural_Net_Classes.py | 152 ++++++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 66 deletions(-) diff --git a/ml/Neural_Net_Classes.py b/ml/Neural_Net_Classes.py index 1b2f3b6..677d9c5 100644 --- a/ml/Neural_Net_Classes.py +++ b/ml/Neural_Net_Classes.py @@ -53,6 +53,66 @@ def mask_grad_hook(grad): return mse_loss +def _run_training_loop( + parameters, + forward_fn, + train_inputs, + train_targets, + val_inputs=None, + val_targets=None, + num_epochs=1500, + lr=0.001, + patience_lr=100, + patience_early=150, + factor=0.5, + threshold=1e-4, + label="", +): + prefix = f"{label} " if label else "" + optimizer = optim.Adam(parameters, lr=lr) + scheduler = ReduceLROnPlateau( + optimizer, "min", factor=factor, patience=patience_lr, threshold=threshold + ) + early_stopper = EarlyStopping(patience=patience_early) + + for epoch in range(num_epochs): + optimizer.zero_grad() + + outputs = forward_fn(train_inputs) + loss = nan_mse_loss(train_targets, outputs) + loss.backward() + optimizer.step() + + current_loss = loss.item() + + if val_inputs is not None: + with torch.no_grad(): + val_outputs = forward_fn(val_inputs) + val_loss = nan_mse_loss(val_targets, val_outputs) + monitor_loss = val_loss.item() + else: + monitor_loss = current_loss + + scheduler.step(monitor_loss) + + if (epoch + 1) % (num_epochs / 10) == 0: + if val_inputs is not None: + print( + f"{prefix}Epoch [{epoch + 1}/{num_epochs}], Loss:{current_loss:.6f}, Val Loss:{monitor_loss:.6f}" + ) + else: + print( + f"{prefix}Epoch [{epoch + 1}/{num_epochs}], Loss:{current_loss:.6f}" + ) + + early_stopper(monitor_loss) + if early_stopper.early_stop: + print( + f"{prefix}Early stopping at epoch {epoch} with loss {monitor_loss:.6f}" + ) + break + + class CombinedNN(nn.Module): """ 5 layer neural network @@ -117,43 +177,20 @@ def train_model( val_targets, num_epochs=1500, ): - optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) - scheduler = ReduceLROnPlateau( - optimizer, - "min", + _run_training_loop( + parameters=list(self.parameters()), + forward_fn=self, + train_inputs=train_inputs, + train_targets=train_targets, + val_inputs=val_inputs, + val_targets=val_targets, + num_epochs=num_epochs, + lr=self.learning_rate, + patience_lr=self.patience_LRreduction, + patience_early=self.patience_earlystopping, factor=self.factor, - patience=self.patience_LRreduction, threshold=self.threshold, ) - early_stopper = EarlyStopping(patience=self.patience_earlystopping) - - for epoch in range(num_epochs): - optimizer.zero_grad() - - outputs = self(train_inputs) - loss = nan_mse_loss(train_targets, outputs) - loss.backward() - - optimizer.step() - - current_loss = loss.item() - scheduler.step(current_loss) - - with torch.no_grad(): - val_outputs = self(val_inputs) - val_loss = nan_mse_loss(val_targets, val_outputs) - - if (epoch + 1) % (num_epochs / 10) == 0: - print( - f"Epoch [{epoch + 1}/{num_epochs}], Loss:{loss.item():.6f}, Val Loss:{val_loss.item():.6f}" - ) - - early_stopper(val_loss.item()) - if early_stopper.early_stop: - print( - f"Early stopping triggered at epoch {epoch} with val loss {val_loss.item():.6f}" - ) - break def train_calibration( @@ -202,39 +239,22 @@ def train_calibration( torch.zeros(n_outputs, dtype=exp_inputs.dtype, device=device) ) - optimizer = optim.Adam( - [c_normcal_input, o_normcal_input, c_normcal_output, o_normcal_output], lr=lr - ) - scheduler = ReduceLROnPlateau( - optimizer, "min", factor=0.5, patience=200, threshold=1e-4 - ) - early_stopper = EarlyStopping(patience=500) - - for epoch in range(num_epochs): - optimizer.zero_grad() - - calibrated_inputs = (1.0 / c_normcal_input) * (exp_inputs - o_normcal_input) + def calibrated_forward(x): + calibrated_inputs = (1.0 / c_normcal_input) * (x - o_normcal_input) base_predictions = model(calibrated_inputs) - calibrated_outputs = c_normcal_output * base_predictions + o_normcal_output - - loss = nan_mse_loss(exp_targets, calibrated_outputs) - loss.backward() - optimizer.step() - - current_loss = loss.item() - scheduler.step(current_loss) - - if (epoch + 1) % (num_epochs / 10) == 0: - print( - f"Calibration Epoch [{epoch + 1}/{num_epochs}], Loss:{current_loss:.6f}" - ) - - early_stopper(current_loss) - if early_stopper.early_stop: - print( - f"Calibration early stopping at epoch {epoch} with loss {current_loss:.6f}" - ) - break + return c_normcal_output * base_predictions + o_normcal_output + + _run_training_loop( + parameters=[c_normcal_input, o_normcal_input, c_normcal_output, o_normcal_output], + forward_fn=calibrated_forward, + train_inputs=exp_inputs, + train_targets=exp_targets, + num_epochs=num_epochs, + lr=lr, + patience_lr=200, + patience_early=500, + label="Calibration", + ) return ( c_normcal_input.detach(), From 291e3ed54db242560943d647c3602a210b9f9944 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:56:17 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ml/Neural_Net_Classes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ml/Neural_Net_Classes.py b/ml/Neural_Net_Classes.py index 677d9c5..d3655f5 100644 --- a/ml/Neural_Net_Classes.py +++ b/ml/Neural_Net_Classes.py @@ -245,7 +245,12 @@ def calibrated_forward(x): return c_normcal_output * base_predictions + o_normcal_output _run_training_loop( - parameters=[c_normcal_input, o_normcal_input, c_normcal_output, o_normcal_output], + parameters=[ + c_normcal_input, + o_normcal_input, + c_normcal_output, + o_normcal_output, + ], forward_fn=calibrated_forward, train_inputs=exp_inputs, train_targets=exp_targets,