Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 91 additions & 66 deletions ml/Neural_Net_Classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,66 @@ def mask_grad_hook(grad):
return mse_loss


def _run_training_loop(
    parameters,
    forward_fn,
    train_inputs,
    train_targets,
    val_inputs=None,
    val_targets=None,
    num_epochs=1500,
    lr=0.001,
    patience_lr=100,
    patience_early=150,
    factor=0.5,
    threshold=1e-4,
    label="",
):
    """
    Shared full-batch training loop used by the model and calibration trainers.

    Each epoch runs one Adam step on ``nan_mse_loss(train_targets,
    forward_fn(train_inputs))``. The "monitored" loss drives both the
    ``ReduceLROnPlateau`` scheduler and the ``EarlyStopping`` helper: it is
    the validation loss when ``val_inputs`` is given, otherwise the training
    loss. Progress is printed roughly ten times over the run.

    Args:
        parameters: Iterable of tensors handed to ``optim.Adam``.
        forward_fn: Callable mapping an input batch to predictions (an
            ``nn.Module`` or a plain function/closure).
        train_inputs: Inputs passed to ``forward_fn`` for the training step.
        train_targets: Targets compared against the training predictions.
        val_inputs: Optional validation inputs. When not ``None``, a
            gradient-free forward pass is evaluated every epoch.
        val_targets: Targets for ``val_inputs``; only read when
            ``val_inputs`` is not ``None``.
        num_epochs: Maximum number of epochs to run.
        lr: Initial learning rate for Adam.
        patience_lr: ``patience`` passed to ``ReduceLROnPlateau``.
        patience_early: ``patience`` passed to ``EarlyStopping``.
        factor: ``factor`` passed to ``ReduceLROnPlateau``.
        threshold: ``threshold`` passed to ``ReduceLROnPlateau``.
        label: Optional tag prepended (followed by a space) to every
            printed message.

    Returns:
        None. ``parameters`` are updated in place by the optimizer.
    """
    prefix = f"{label} " if label else ""
    optimizer = optim.Adam(parameters, lr=lr)
    scheduler = ReduceLROnPlateau(
        optimizer, "min", factor=factor, patience=patience_lr, threshold=threshold
    )
    early_stopper = EarlyStopping(patience=patience_early)

    has_validation = val_inputs is not None
    # Integer logging interval: a float interval (num_epochs / 10) makes the
    # modulus test fire irregularly whenever num_epochs is not a multiple of
    # 10. Clamp to 1 so short runs still report every epoch.
    log_every = max(1, num_epochs // 10)

    for epoch in range(num_epochs):
        optimizer.zero_grad()

        outputs = forward_fn(train_inputs)
        loss = nan_mse_loss(train_targets, outputs)
        loss.backward()
        optimizer.step()

        current_loss = loss.item()

        if has_validation:
            # Validation pass must not build a graph or touch gradients.
            with torch.no_grad():
                val_outputs = forward_fn(val_inputs)
                val_loss = nan_mse_loss(val_targets, val_outputs)
            monitor_loss = val_loss.item()
        else:
            monitor_loss = current_loss

        scheduler.step(monitor_loss)

        if (epoch + 1) % log_every == 0:
            if has_validation:
                print(
                    f"{prefix}Epoch [{epoch + 1}/{num_epochs}], Loss:{current_loss:.6f}, Val Loss:{monitor_loss:.6f}"
                )
            else:
                print(
                    f"{prefix}Epoch [{epoch + 1}/{num_epochs}], Loss:{current_loss:.6f}"
                )

        early_stopper(monitor_loss)
        if early_stopper.early_stop:
            # NOTE: this message reports the zero-based epoch index, unlike
            # the one-based progress lines above.
            print(
                f"{prefix}Early stopping at epoch {epoch} with loss {monitor_loss:.6f}"
            )
            break


class CombinedNN(nn.Module):
"""
5 layer neural network
Expand Down Expand Up @@ -117,43 +177,20 @@ def train_model(
val_targets,
num_epochs=1500,
):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = ReduceLROnPlateau(
optimizer,
"min",
_run_training_loop(
parameters=list(self.parameters()),
forward_fn=self,
train_inputs=train_inputs,
train_targets=train_targets,
val_inputs=val_inputs,
val_targets=val_targets,
num_epochs=num_epochs,
lr=self.learning_rate,
patience_lr=self.patience_LRreduction,
patience_early=self.patience_earlystopping,
factor=self.factor,
patience=self.patience_LRreduction,
threshold=self.threshold,
)
early_stopper = EarlyStopping(patience=self.patience_earlystopping)

for epoch in range(num_epochs):
optimizer.zero_grad()

outputs = self(train_inputs)
loss = nan_mse_loss(train_targets, outputs)
loss.backward()

optimizer.step()

current_loss = loss.item()
scheduler.step(current_loss)

with torch.no_grad():
val_outputs = self(val_inputs)
val_loss = nan_mse_loss(val_targets, val_outputs)

if (epoch + 1) % (num_epochs / 10) == 0:
print(
f"Epoch [{epoch + 1}/{num_epochs}], Loss:{loss.item():.6f}, Val Loss:{val_loss.item():.6f}"
)

early_stopper(val_loss.item())
if early_stopper.early_stop:
print(
f"Early stopping triggered at epoch {epoch} with val loss {val_loss.item():.6f}"
)
break


def train_calibration(
Expand Down Expand Up @@ -202,39 +239,27 @@ def train_calibration(
torch.zeros(n_outputs, dtype=exp_inputs.dtype, device=device)
)

optimizer = optim.Adam(
[c_normcal_input, o_normcal_input, c_normcal_output, o_normcal_output], lr=lr
)
scheduler = ReduceLROnPlateau(
optimizer, "min", factor=0.5, patience=200, threshold=1e-4
)
early_stopper = EarlyStopping(patience=500)

for epoch in range(num_epochs):
optimizer.zero_grad()

calibrated_inputs = (1.0 / c_normcal_input) * (exp_inputs - o_normcal_input)
def calibrated_forward(x):
calibrated_inputs = (1.0 / c_normcal_input) * (x - o_normcal_input)
base_predictions = model(calibrated_inputs)
calibrated_outputs = c_normcal_output * base_predictions + o_normcal_output

loss = nan_mse_loss(exp_targets, calibrated_outputs)
loss.backward()
optimizer.step()

current_loss = loss.item()
scheduler.step(current_loss)

if (epoch + 1) % (num_epochs / 10) == 0:
print(
f"Calibration Epoch [{epoch + 1}/{num_epochs}], Loss:{current_loss:.6f}"
)

early_stopper(current_loss)
if early_stopper.early_stop:
print(
f"Calibration early stopping at epoch {epoch} with loss {current_loss:.6f}"
)
break
return c_normcal_output * base_predictions + o_normcal_output

_run_training_loop(
parameters=[
c_normcal_input,
o_normcal_input,
c_normcal_output,
o_normcal_output,
],
forward_fn=calibrated_forward,
train_inputs=exp_inputs,
train_targets=exp_targets,
num_epochs=num_epochs,
lr=lr,
patience_lr=200,
patience_early=500,
label="Calibration",
)

return (
c_normcal_input.detach(),
Expand Down