33
44import torch
55
6- from gnn_reco .components .loss_functions import VonMisesFisher2DLoss
6+ from gnn_reco .components .loss_functions import LogCoshLoss , VonMisesFisher2DLoss
77from gnn_reco .components .utils import fit_scaler
88from gnn_reco .data .constants import FEATURES , TRUTH
99from gnn_reco .data .utils import get_equal_proportion_neutrino_indices
10- from gnn_reco .legacy .reimplemented import LegacyVonMisesFisherLoss , LegacyAngularReconstruction
11- from gnn_reco .models import Model
10+ from gnn_reco .legacy .callbacks import PiecewiseLinearScheduler
11+ from gnn_reco .legacy .trainers import Trainer , Predictor
12+ from gnn_reco .legacy .model import Model
1213from gnn_reco .models .detector .icecube import IceCube86
1314from gnn_reco .models .gnn import DynEdge , ConvNet
1415from gnn_reco .models .graph_builders import KNNGraphBuilder
15- from gnn_reco .models .task .reconstruction import AngularReconstructionWithKappa
16- from gnn_reco .models .training .callbacks import PiecewiseLinearScheduler
17- from gnn_reco .models .training .trainers import Trainer , Predictor
16+ from gnn_reco .models .task .reconstruction import EnergyReconstruction
1817from gnn_reco .models .training .utils import make_train_validation_dataloader , save_results
1918
2019# Configurations
@@ -35,11 +34,11 @@ def main():
3534 # Configuraiton
3635 db = '/groups/icecube/leonbozi/datafromrasmus/GNNReco/data/databases/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3/data/dev_level7_noise_muon_nu_classification_pass2_fixedRetro_v3.db'
3736 pulsemap = 'SRTTWOfflinePulsesDC'
38- batch_size = 1024
37+ batch_size = 256
3938 num_workers = 10
40- device = 'cuda:1 '
41- target = 'zenith '
42- n_epochs = 30
39+ device = 'cuda:0 '
40+ target = 'energy '
41+ n_epochs = 5
4342 patience = 5
4443 archive = '/groups/icecube/asogaard/gnn/results'
4544
@@ -48,14 +47,14 @@ def main():
4847
4948 # Common variables
5049 train_selection , _ = get_equal_proportion_neutrino_indices (db )
51- train_selection = train_selection [0 :500000 ]
50+ train_selection = train_selection [0 :50000 ]
5251
5352 training_dataloader , validation_dataloader = make_train_validation_dataloader (
54- db ,
55- train_selection ,
56- pulsemap ,
57- features ,
58- truth ,
53+ db ,
54+ train_selection ,
55+ pulsemap ,
56+ features ,
57+ truth ,
5958 batch_size = batch_size ,
6059 num_workers = num_workers ,
6160 )
@@ -68,19 +67,13 @@ def main():
6867 gnn = DynEdge (
6968 nb_inputs = detector .nb_outputs ,
7069 )
71- task = LegacyAngularReconstruction (
72- hidden_size = gnn .nb_outputs ,
73- target_label = target ,
74- loss_function = LegacyVonMisesFisherLoss (
75- target_scaler = scalers [ 'truth' ][ target ]
70+ task = EnergyReconstruction (
71+ hidden_size = gnn .nb_outputs ,
72+ target_label = target ,
73+ loss_function = LogCoshLoss (
74+ transform_prediction_and_target = torch . log10 ,
7675 ),
77- target_scaler = scalers ['truth' ][target ],
7876 )
79- #task = AngularReconstructionWithKappa(
80- # hidden_size=gnn.nb_outputs,
81- # target_label=target,
82- # loss_function=VonMisesFisher2DLoss(),
83- #)
8477 model = Model (
8578 detector = detector ,
8679 gnn = gnn ,
@@ -93,9 +86,9 @@ def main():
9386
9487 trainer = Trainer (
9588 training_dataloader = training_dataloader ,
96- validation_dataloader = validation_dataloader ,
89+ validation_dataloader = validation_dataloader ,
9790 optimizer = optimizer ,
98- n_epochs = n_epochs ,
91+ n_epochs = n_epochs ,
9992 scheduler = scheduler ,
10093 patience = patience ,
10194 )
@@ -107,14 +100,14 @@ def main():
107100 pass
108101
109102 predictor = Predictor (
110- dataloader = validation_dataloader ,
111- target = target ,
112- device = device ,
113- output_column_names = [target + '_pred' , target + '_sigma' ],
103+ dataloader = validation_dataloader ,
104+ target = target ,
105+ device = device ,
106+ output_column_names = [target + '_pred' ],
114107 )
115108 model ._tasks [0 ].inference = True
116109 results = predictor (model )
117- save_results (db , 'dynedge_zenith ' , results ,archive , model )
110+ save_results (db , 'dynedge_energy ' , results ,archive , model )
118111
119112# Main function call
120113if __name__ == "__main__" :
0 commit comments