Skip to content

Commit 5e217a4

Browse files
authored
Merge pull request #96 from asogaard/training-on-upgrade-mc
Training on upgrade MC
2 parents e3a7ec6 + ee46e82 commit 5e217a4

5 files changed

Lines changed: 202 additions & 7 deletions

File tree

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import logging
2+
import numpy as np
3+
import pandas as pd
4+
from timer import timer
5+
6+
from pytorch_lightning import Trainer
7+
from pytorch_lightning.callbacks import EarlyStopping
8+
import torch
9+
from torch.optim.adam import Adam
10+
from torch.utils.data import dataloader
11+
12+
from gnn_reco.components.loss_functions import LogCoshLoss, VonMisesFisher2DLoss
13+
from gnn_reco.components.utils import fit_scaler
14+
from gnn_reco.data.constants import FEATURES, TRUTH
15+
from gnn_reco.data.utils import get_desired_event_numbers
16+
from gnn_reco.models import Model
17+
from gnn_reco.models.detector.icecube import IceCubeUpgrade
18+
from gnn_reco.models.gnn import DynEdge, ConvNet
19+
from gnn_reco.models.graph_builders import KNNGraphBuilder
20+
from gnn_reco.models.task.reconstruction import EnergyReconstruction
21+
from gnn_reco.models.training.callbacks import ProgressBar, PiecewiseLinearLR
22+
from gnn_reco.models.training.utils import get_predictions, make_train_validation_dataloader, save_results
23+
24+
# Configurations
25+
timer.set_level(logging.INFO)
26+
logging.basicConfig(level=logging.INFO)
27+
torch.multiprocessing.set_sharing_strategy('file_system')
28+
29+
# Constants
30+
features = FEATURES.UPGRADE
31+
truth = TRUTH.UPGRADE
32+
33+
# Main function definition
34+
def main():
35+
36+
print(f"features: {features}")
37+
print(f"truth: {truth}")
38+
39+
# Configuraiton
40+
db = '/groups/icecube/asogaard/temp/sqlite_test_upgrade/data_test/data/data_test.db'
41+
pulsemap = 'I3RecoPulseSeriesMapRFCleaned_mDOM'
42+
batch_size = 128
43+
num_workers = 10
44+
gpus = [0]
45+
target = 'energy'
46+
n_epochs = 30
47+
patience = 5
48+
archive = '/groups/icecube/asogaard/gnn/results'
49+
50+
# Common variables
51+
train_selection = get_desired_event_numbers(db, 1000000, fraction_nu_e=1.)
52+
53+
training_dataloader, validation_dataloader = make_train_validation_dataloader(
54+
db,
55+
train_selection,
56+
pulsemap,
57+
features,
58+
truth,
59+
batch_size=batch_size,
60+
num_workers=num_workers,
61+
)
62+
63+
# Building model
64+
detector = IceCubeUpgrade(
65+
graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
66+
)
67+
gnn = DynEdge(
68+
nb_inputs=detector.nb_outputs,
69+
)
70+
task = EnergyReconstruction(
71+
hidden_size=gnn.nb_outputs,
72+
target_label=target,
73+
loss_function=LogCoshLoss(
74+
transform_prediction_and_target=torch.log10,
75+
),
76+
)
77+
model = Model(
78+
detector=detector,
79+
gnn=gnn,
80+
tasks=[task],
81+
optimizer_class=Adam,
82+
optimizer_kwargs={'lr': 1e-03, 'eps': 1e-03},
83+
scheduler_class=PiecewiseLinearLR,
84+
scheduler_kwargs={
85+
'milestones': [0, len(training_dataloader) / 2, len(training_dataloader) * n_epochs],
86+
'factors': [1e-2, 1, 1e-02],
87+
},
88+
scheduler_config={
89+
'interval': 'step',
90+
},
91+
)
92+
93+
# Training model
94+
callbacks = [
95+
EarlyStopping(
96+
monitor='val_loss',
97+
patience=patience,
98+
),
99+
ProgressBar(),
100+
]
101+
102+
trainer = Trainer(
103+
gpus=gpus,
104+
max_epochs=n_epochs,
105+
callbacks=callbacks,
106+
log_every_n_steps=1,
107+
)
108+
109+
try:
110+
trainer.fit(model, training_dataloader, validation_dataloader)
111+
except KeyboardInterrupt:
112+
print("[ctrl+c] Exiting gracefully.")
113+
pass
114+
115+
# Saving predictions to file
116+
results = get_predictions(
117+
trainer,
118+
model,
119+
validation_dataloader,
120+
[target + '_pred'],
121+
[target, 'event_no'],
122+
)
123+
124+
save_results(db, 'test_upgrade_mDOM_energy', results, archive, model)
125+
126+
# Main function call
127+
if __name__ == "__main__":
128+
main()

misc/badges/pylint.svg

Lines changed: 2 additions & 2 deletions
Loading

src/gnn_reco/data/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_desired_event_numbers(db_path, desired_size, fraction_noise=0, fraction_
2323
tot_event_nos = pd.read_sql(total_query,con)
2424
if len(tot_event_nos) < desired_size:
2525
desired_size = len(tot_event_nos)
26-
numbers_desired = [x * desired_size for x in fracs]
26+
numbers_desired = [int(x * desired_size) for x in fracs]
2727
print('Only {} events in database, using this number instead.'.format(len(tot_event_nos)))
2828

2929
list_of_dataframes = []
@@ -43,9 +43,9 @@ def get_desired_event_numbers(db_path, desired_size, fraction_noise=0, fraction_
4343
numbers_desired = [int(new_x * (len(tmp_dataframe)/number)) for new_x in numbers_desired]
4444
restart_trigger = True
4545
list_of_dataframes = []
46-
break
46+
break
4747

48-
list_of_dataframes.append(dataframe)
48+
list_of_dataframes.append(dataframe)
4949
retrieved_event_nos_pd = pd.concat(list_of_dataframes)
5050
event_no_list = retrieved_event_nos_pd.sample(frac=1, replace=False, random_state=rng).values.ravel().tolist()
5151

src/gnn_reco/models/detector/icecube.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _forward(self, data: Data) -> Data:
3030
"""
3131

3232
# Check(s)
33-
assert self.nb_inputs == 7
33+
#assert self.nb_inputs == 7
3434

3535
# Preprocessing
3636
data.x[:,0] /= 100. # dom_x
@@ -50,3 +50,65 @@ def _forward(self, data: Data) -> Data:
5050

5151
class IceCubeDeepCore(IceCube86):
5252
"""`Detector` class for IceCube-DeepCore."""
53+
54+
55+
class IceCubeUpgrade(IceCubeDeepCore):
56+
"""`Detector` class for IceCube-Upgrade."""
57+
58+
# Implementing abstract class attribute
59+
features = FEATURES.UPGRADE
60+
61+
def _forward(self, data: Data) -> Data:
62+
"""Ingests data, builds graph (connectivity/adjacency), and preprocesses features.
63+
64+
Assuming the following features, in this order (see self._features):
65+
dom_x
66+
dom_y
67+
dom_z
68+
dom_times
69+
charge
70+
rde
71+
pmt_area
72+
string
73+
pmt_number
74+
dom_number
75+
pmt_dir_x
76+
pmt_dir_y
77+
pmt_dir_z
78+
dom_type
79+
80+
Args:
81+
data (Data): Input graph data.
82+
83+
Returns:
84+
Data: Connected and preprocessed graph data.
85+
"""
86+
87+
# Check(s)
88+
#assert self.nb_inputs == 14
89+
90+
# Run IceCube/DeepCore preprocessing on first 7 features
91+
#data = super()._forward(data)
92+
93+
# Preprocessing
94+
data.x[:,0] /= 100. # dom_x
95+
data.x[:,1] /= 100. # dom_y
96+
data.x[:,2] += 350. # dom_z
97+
data.x[:,2] /= 100.
98+
data.x[:,3] /= 1.05e+04 # dom_time
99+
data.x[:,3] -= 1.
100+
data.x[:,3] *= 20.
101+
data.x[:,4] /= 1. # charge
102+
#data.x[:,5] -= 1.25 # rde
103+
#data.x[:,5] /= 0.25
104+
data.x[:,6] /= 0.05 # pmt_area
105+
data.x[:,7] -= 90 # string
106+
data.x[:,8] /= 20. # pmt_number
107+
data.x[:,9] -= 60. # dom_number
108+
data.x[:,9] /= 60.
109+
#data.x[:,10] /= 1. # pmt_dir_x
110+
#data.x[:,11] /= 1. # pmt_dir_y
111+
#data.x[:,12] /= 1. # pmt_dir_z
112+
data.x[:,13] /= 130. # dom_type
113+
114+
return data

src/gnn_reco/models/training/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@ def make_dataloader(
3434
selection=selection,
3535
)
3636

37+
def collate_fn(graphs):
38+
# Remove graphs with less than two DOM hits. Should not occur in "production."
39+
graphs = [g for g in graphs if g.n_pulses > 1]
40+
return Batch.from_data_list(graphs)
41+
3742
dataloader = DataLoader(
3843
dataset,
3944
batch_size=batch_size,
4045
shuffle=shuffle,
4146
num_workers=num_workers,
42-
collate_fn=Batch.from_data_list,
47+
collate_fn=collate_fn,
4348
persistent_workers=persistent_workers,
4449
prefetch_factor=2,
4550
)

0 commit comments

Comments
 (0)