Skip to content

Commit 986fba9

Browse files
committed
Add example script to train on Upgrade MC
1 parent 102265a commit 986fba9

1 file changed

Lines changed: 128 additions & 0 deletions

File tree

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import logging
2+
import numpy as np
3+
import pandas as pd
4+
from timer import timer
5+
6+
from pytorch_lightning import Trainer
7+
from pytorch_lightning.callbacks import EarlyStopping
8+
import torch
9+
from torch.optim.adam import Adam
10+
from torch.utils.data import dataloader
11+
12+
from gnn_reco.components.loss_functions import LogCoshLoss, VonMisesFisher2DLoss
13+
from gnn_reco.components.utils import fit_scaler
14+
from gnn_reco.data.constants import FEATURES, TRUTH
15+
from gnn_reco.data.utils import get_desired_event_numbers
16+
from gnn_reco.models import Model
17+
from gnn_reco.models.detector.icecube import IceCubeUpgrade
18+
from gnn_reco.models.gnn import DynEdge, ConvNet
19+
from gnn_reco.models.graph_builders import KNNGraphBuilder
20+
from gnn_reco.models.task.reconstruction import EnergyReconstruction
21+
from gnn_reco.models.training.callbacks import ProgressBar, PiecewiseLinearLR
22+
from gnn_reco.models.training.utils import get_predictions, make_train_validation_dataloader, save_results
23+
24+
# Configurations
25+
timer.set_level(logging.INFO)
26+
logging.basicConfig(level=logging.INFO)
27+
torch.multiprocessing.set_sharing_strategy('file_system')
28+
29+
# Constants
30+
features = FEATURES.UPGRADE
31+
truth = TRUTH.UPGRADE
32+
33+
# Main function definition
34+
def main():
35+
36+
print(f"features: {features}")
37+
print(f"truth: {truth}")
38+
39+
# Configuraiton
40+
db = '/groups/icecube/asogaard/temp/sqlite_test_upgrade/data_test/data/data_test.db'
41+
pulsemap = 'I3RecoPulseSeriesMapRFCleaned_mDOM'
42+
batch_size = 128
43+
num_workers = 10
44+
gpus = [0]
45+
target = 'energy'
46+
n_epochs = 30
47+
patience = 5
48+
archive = '/groups/icecube/asogaard/gnn/results'
49+
50+
# Common variables
51+
train_selection = get_desired_event_numbers(db, 1000000, fraction_nu_e=1.)
52+
53+
training_dataloader, validation_dataloader = make_train_validation_dataloader(
54+
db,
55+
train_selection,
56+
pulsemap,
57+
features,
58+
truth,
59+
batch_size=batch_size,
60+
num_workers=num_workers,
61+
)
62+
63+
# Building model
64+
detector = IceCubeUpgrade(
65+
graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
66+
)
67+
gnn = DynEdge(
68+
nb_inputs=detector.nb_outputs,
69+
)
70+
task = EnergyReconstruction(
71+
hidden_size=gnn.nb_outputs,
72+
target_label=target,
73+
loss_function=LogCoshLoss(
74+
transform_prediction_and_target=torch.log10,
75+
),
76+
)
77+
model = Model(
78+
detector=detector,
79+
gnn=gnn,
80+
tasks=[task],
81+
optimizer_class=Adam,
82+
optimizer_kwargs={'lr': 1e-03, 'eps': 1e-03},
83+
scheduler_class=PiecewiseLinearLR,
84+
scheduler_kwargs={
85+
'milestones': [0, len(training_dataloader) / 2, len(training_dataloader) * n_epochs],
86+
'factors': [1e-2, 1, 1e-02],
87+
},
88+
scheduler_config={
89+
'interval': 'step',
90+
},
91+
)
92+
93+
# Training model
94+
callbacks = [
95+
EarlyStopping(
96+
monitor='val_loss',
97+
patience=patience,
98+
),
99+
ProgressBar(),
100+
]
101+
102+
trainer = Trainer(
103+
gpus=gpus,
104+
max_epochs=n_epochs,
105+
callbacks=callbacks,
106+
log_every_n_steps=1,
107+
)
108+
109+
try:
110+
trainer.fit(model, training_dataloader, validation_dataloader)
111+
except KeyboardInterrupt:
112+
print("[ctrl+c] Exiting gracefully.")
113+
pass
114+
115+
# Saving predictions to file
116+
results = get_predictions(
117+
trainer,
118+
model,
119+
validation_dataloader,
120+
[target + '_pred'],
121+
[target, 'event_no'],
122+
)
123+
124+
save_results(db, 'test_upgrade_mDOM_energy', results, archive, model)
125+
126+
# Main function call
127+
if __name__ == "__main__":
128+
main()

0 commit comments

Comments
 (0)