Skip to content

Commit e3a7ec6

Browse files
authored
Merge pull request #94 from asogaard/pytorch-lightning
Implementing pytorch-lightning
2 parents a69c5f6 + e84a713 commit e3a7ec6

18 files changed

Lines changed: 551 additions & 252 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -134,3 +134,4 @@ docs/**/*.rst
134134

135135
# Badges
136136
misc/badges
137+
lightning_logs/

envs/gnn_py38.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,6 @@ channels:
66
- conda-forge
77
dependencies:
88
- python=3.8
9-
- pytorch=1.9.1
9+
- pytorch=1.9.0
1010
- cudatoolkit=11.1
1111
- pyg
Lines changed: 128 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,128 @@
1+
import logging
2+
import numpy as np
3+
import pandas as pd
4+
from timer import timer
5+
6+
from pytorch_lightning import Trainer
7+
from pytorch_lightning.callbacks import EarlyStopping
8+
import torch
9+
from torch.optim.adam import Adam
10+
11+
from gnn_reco.components.loss_functions import LogCoshLoss, VonMisesFisher2DLoss
12+
from gnn_reco.components.utils import fit_scaler
13+
from gnn_reco.data.constants import FEATURES, TRUTH
14+
from gnn_reco.data.utils import get_equal_proportion_neutrino_indices
15+
from gnn_reco.models import Model
16+
from gnn_reco.models.detector.icecube import IceCube86
17+
from gnn_reco.models.gnn import DynEdge, ConvNet
18+
from gnn_reco.models.graph_builders import KNNGraphBuilder
19+
from gnn_reco.models.task.reconstruction import EnergyReconstruction
20+
from gnn_reco.models.training.callbacks import ProgressBar, PiecewiseLinearLR
21+
from gnn_reco.models.training.utils import get_predictions, make_train_validation_dataloader, save_results
22+
23+
# Configurations
24+
timer.set_level(logging.INFO)
25+
logging.basicConfig(level=logging.INFO)
26+
torch.multiprocessing.set_sharing_strategy('file_system')
27+
28+
# Constants
29+
features = FEATURES.ICECUBE86
30+
truth = TRUTH.ICECUBE86
31+
32+
# Main function definition
33+
def main():
34+
35+
print(f"features: {features}")
36+
print(f"truth: {truth}")
37+
38+
# Configuration
39+
db = '/groups/icecube/leonbozi/datafromrasmus/GNNReco/data/databases/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3/data/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3.db'
40+
pulsemap = 'SRTTWOfflinePulsesDC'
41+
batch_size = 256
42+
num_workers = 10
43+
gpus = [0]
44+
target = 'energy'
45+
n_epochs = 5
46+
patience = 5
47+
archive = '/groups/icecube/asogaard/gnn/results'
48+
49+
# Common variables
50+
train_selection, _ = get_equal_proportion_neutrino_indices(db)
51+
train_selection = train_selection[0:50000]
52+
53+
training_dataloader, validation_dataloader = make_train_validation_dataloader(
54+
db,
55+
train_selection,
56+
pulsemap,
57+
features,
58+
truth,
59+
batch_size=batch_size,
60+
num_workers=num_workers,
61+
)
62+
63+
# Building model
64+
detector = IceCube86(
65+
graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
66+
)
67+
gnn = DynEdge(
68+
nb_inputs=detector.nb_outputs,
69+
)
70+
task = EnergyReconstruction(
71+
hidden_size=gnn.nb_outputs,
72+
target_label=target,
73+
loss_function=LogCoshLoss(
74+
transform_prediction_and_target=torch.log10,
75+
),
76+
)
77+
model = Model(
78+
detector=detector,
79+
gnn=gnn,
80+
tasks=[task],
81+
optimizer_class=Adam,
82+
optimizer_kwargs={'lr': 1e-03, 'eps': 1e-03},
83+
scheduler_class=PiecewiseLinearLR,
84+
scheduler_kwargs={
85+
'milestones': [0, len(training_dataloader) / 2, len(training_dataloader) * n_epochs],
86+
'factors': [1e-2, 1, 1e-02],
87+
},
88+
scheduler_config={
89+
'interval': 'step',
90+
},
91+
)
92+
93+
# Training model
94+
callbacks = [
95+
EarlyStopping(
96+
monitor='val_loss',
97+
patience=patience,
98+
),
99+
ProgressBar(),
100+
]
101+
102+
trainer = Trainer(
103+
gpus=gpus,
104+
max_epochs=n_epochs,
105+
callbacks=callbacks,
106+
log_every_n_steps=1,
107+
)
108+
109+
try:
110+
trainer.fit(model, training_dataloader, validation_dataloader)
111+
except KeyboardInterrupt:
112+
print("[ctrl+c] Exiting gracefully.")
113+
pass
114+
115+
# Saving predictions to file
116+
results = get_predictions(
117+
trainer,
118+
model,
119+
validation_dataloader,
120+
[target + '_pred'],
121+
[target, 'event_no'],
122+
)
123+
124+
save_results(db, 'dynedge_energy_pytorch_lightning', results,archive, model)
125+
126+
# Main function call
127+
if __name__ == "__main__":
128+
main()

examples/test_model_training_sqlite.py

Lines changed: 27 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -3,18 +3,17 @@
33

44
import torch
55

6-
from gnn_reco.components.loss_functions import VonMisesFisher2DLoss
6+
from gnn_reco.components.loss_functions import LogCoshLoss, VonMisesFisher2DLoss
77
from gnn_reco.components.utils import fit_scaler
88
from gnn_reco.data.constants import FEATURES, TRUTH
99
from gnn_reco.data.utils import get_equal_proportion_neutrino_indices
10-
from gnn_reco.legacy.reimplemented import LegacyVonMisesFisherLoss, LegacyAngularReconstruction
11-
from gnn_reco.models import Model
10+
from gnn_reco.legacy.callbacks import PiecewiseLinearScheduler
11+
from gnn_reco.legacy.trainers import Trainer, Predictor
12+
from gnn_reco.legacy.model import Model
1213
from gnn_reco.models.detector.icecube import IceCube86
1314
from gnn_reco.models.gnn import DynEdge, ConvNet
1415
from gnn_reco.models.graph_builders import KNNGraphBuilder
15-
from gnn_reco.models.task.reconstruction import AngularReconstructionWithKappa
16-
from gnn_reco.models.training.callbacks import PiecewiseLinearScheduler
17-
from gnn_reco.models.training.trainers import Trainer, Predictor
16+
from gnn_reco.models.task.reconstruction import EnergyReconstruction
1817
from gnn_reco.models.training.utils import make_train_validation_dataloader, save_results
1918

2019
# Configurations
@@ -35,11 +34,11 @@ def main():
3534
# Configuration
3635
db = '/groups/icecube/leonbozi/datafromrasmus/GNNReco/data/databases/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3/data/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3.db'
3736
pulsemap = 'SRTTWOfflinePulsesDC'
38-
batch_size = 1024
37+
batch_size = 256
3938
num_workers = 10
40-
device = 'cuda:1'
41-
target = 'zenith'
42-
n_epochs = 30
39+
device = 'cuda:0'
40+
target = 'energy'
41+
n_epochs = 5
4342
patience = 5
4443
archive = '/groups/icecube/asogaard/gnn/results'
4544

@@ -48,14 +47,14 @@ def main():
4847

4948
# Common variables
5049
train_selection, _ = get_equal_proportion_neutrino_indices(db)
51-
train_selection = train_selection[0:500000]
50+
train_selection = train_selection[0:50000]
5251

5352
training_dataloader, validation_dataloader = make_train_validation_dataloader(
54-
db,
55-
train_selection,
56-
pulsemap,
57-
features,
58-
truth,
53+
db,
54+
train_selection,
55+
pulsemap,
56+
features,
57+
truth,
5958
batch_size=batch_size,
6059
num_workers=num_workers,
6160
)
@@ -68,19 +67,13 @@ def main():
6867
gnn = DynEdge(
6968
nb_inputs=detector.nb_outputs,
7069
)
71-
task = LegacyAngularReconstruction(
72-
hidden_size=gnn.nb_outputs,
73-
target_label=target,
74-
loss_function=LegacyVonMisesFisherLoss(
75-
target_scaler=scalers['truth'][target]
70+
task = EnergyReconstruction(
71+
hidden_size=gnn.nb_outputs,
72+
target_label=target,
73+
loss_function=LogCoshLoss(
74+
transform_prediction_and_target=torch.log10,
7675
),
77-
target_scaler=scalers['truth'][target],
7876
)
79-
#task = AngularReconstructionWithKappa(
80-
# hidden_size=gnn.nb_outputs,
81-
# target_label=target,
82-
# loss_function=VonMisesFisher2DLoss(),
83-
#)
8477
model = Model(
8578
detector=detector,
8679
gnn=gnn,
@@ -93,9 +86,9 @@ def main():
9386

9487
trainer = Trainer(
9588
training_dataloader=training_dataloader,
96-
validation_dataloader=validation_dataloader,
89+
validation_dataloader=validation_dataloader,
9790
optimizer=optimizer,
98-
n_epochs=n_epochs,
91+
n_epochs=n_epochs,
9992
scheduler=scheduler,
10093
patience=patience,
10194
)
@@ -107,14 +100,14 @@ def main():
107100
pass
108101

109102
predictor = Predictor(
110-
dataloader=validation_dataloader,
111-
target=target,
112-
device=device,
113-
output_column_names=[target + '_pred', target + '_sigma'],
103+
dataloader=validation_dataloader,
104+
target=target,
105+
device=device,
106+
output_column_names=[target + '_pred'],
114107
)
115108
model._tasks[0].inference = True
116109
results = predictor(model)
117-
save_results(db, 'dynedge_zenith', results,archive, model)
110+
save_results(db, 'dynedge_energy', results,archive, model)
118111

119112
# Main function call
120113
if __name__ == "__main__":

misc/badges/pylint.svg

Lines changed: 3 additions & 3 deletions
Loading

setup.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,18 +29,19 @@ def install(package):
2929
'torch-cluster',
3030
'torch-spline-conv',
3131
'torch-geometric==2.0.1',
32+
'pytorch-lightning',
3233
'dill',
3334
]
3435

3536
# Ensure pytorch is already installed (see e.g. https://github.com/pyg-team/pytorch_geometric/issues/861#issuecomment-566424944)
3637
try:
3738
import torch # pyright: reportMissingImports=false
3839
except ImportError:
39-
install('torch>=1.9.0')
40+
install('torch==1.9.0')
4041

4142
setup(
4243
name='gnn_reco',
43-
version='0.1.1',
44+
version='0.1.1',
4445
description='A common library for using graph neural networks (GNNs) in netrino telescopes.',
4546
url='https://github.com/icecube/gnn-reco',
4647
author='The IceCube Collaboration',

src/gnn_reco/components/loss_functions.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,11 +22,19 @@ class LossFunction(_WeightedLoss):
2222
"""
2323
def __init__(
2424
self,
25-
transform_output: Optional[Callable] = None,
25+
transform_prediction_and_target: Optional[Callable] = None,
26+
transform_target: Optional[Callable] = None,
2627
**kwargs,
2728
):
2829
super().__init__(**kwargs)
29-
self._transform_output = transform_output if transform_output else lambda x: x
30+
31+
# Check(s)
32+
assert not((transform_prediction_and_target is not None) and (transform_target is not None)), \
33+
"Please specify at most one of `transform_prediction_and_target` and `transform_target`"
34+
35+
# Member variables
36+
self._transform_prediction = transform_prediction_and_target if transform_prediction_and_target else lambda x: x
37+
self._transform_target = transform_target if transform_target else self._transform_prediction
3038

3139
@final
3240
def forward(
@@ -47,8 +55,8 @@ def forward(
4755
Tensor: Loss, either averaged to a scalar (if `return_elements = False`)
4856
or elementwise terms with shape [N,] (if `return_elements = True`).
4957
"""
50-
prediction = self._transform_output(prediction)
51-
target = self._transform_output(target)
58+
prediction = self._transform_prediction(prediction)
59+
target = self._transform_target(target)
5260

5361
elements = self._forward(prediction, target)
5462
assert elements.size(dim=0) == target.size(dim=0), \

0 commit comments

Comments
 (0)