22import torch
33import yaml
44import matplotlib .pyplot as plt
5- import torch .nn .functional as F
6-
7- from monai .data import Dataset , DataLoader , CacheDataset , PersistentDataset
5+ from monai .data import DataLoader , CacheDataset
86from tqdm import tqdm
97
108from dataset import get_dataset
11- from transforms import get_train_transforms
9+ from transforms import get_transforms
1210from unet import build_model
1311
1412
1513torch .backends .cudnn .benchmark = True
1614
17- # -----------------------------
18- # CONFIG
19- # -----------------------------
15+
2016
2117def load_config ():
2218 with open ("config.yaml" ) as f :
2319 return yaml .safe_load (f )
2420
2521
26- # -----------------------------
27- # TRAIN
28- # -----------------------------
2922
3023def main ():
3124
@@ -35,30 +28,39 @@ def main():
3528
3629 print ("Using device:" , device )
3730
38- data = get_dataset (cfg ["data_dir" ])
31+ all_data = get_dataset (cfg ["data_dir" ])
32+ val_data , train_data = all_data [:2 ], all_data [2 :]
3933
40- transforms = get_train_transforms (
41- cfg ["patch_size" ],
42- )
34+ train_transforms = get_transforms (cfg ["patch_size" ], cfg ["train_num_samples" ])
35+ val_transforms = get_transforms (cfg ["patch_size" ], cfg ["val_num_samples" ])
4336
44- print ("Preparing dataset ..." )
45- dataset = PersistentDataset (
46- data = data ,
47- transform = transforms ,
48- cache_dir = cfg ["cache_dir" ]
37+ print ("Caching train dataset..." )
38+ train_dataset = CacheDataset (
39+ data = train_data ,
40+ transform = train_transforms ,
41+ cache_rate = 1.0 , # Change this to reduce memory footprint
42+ num_workers = cfg ["num_workers" ],
43+ )
44+ loader = DataLoader (
45+ train_dataset ,
46+ batch_size = cfg ["batch_size" ],
47+ shuffle = True ,
48+ num_workers = cfg ["num_workers" ],
49+ pin_memory = True ,
50+ persistent_workers = True
4951 )
5052
51- print ("Caching dataset..." )
52- dataset = CacheDataset (
53- data = dataset ,
53+ print ("Caching val dataset..." )
54+ val_dataset = CacheDataset (
55+ data = val_data ,
56+ transform = val_transforms ,
5457 cache_rate = 1.0 ,
55- num_workers = 8 ,
58+ num_workers = cfg [ "num_workers" ] ,
5659 )
57-
58- loader = DataLoader (
59- dataset ,
60+ val_loader = DataLoader (
61+ val_dataset ,
6062 batch_size = cfg ["batch_size" ],
61- shuffle = True ,
63+ shuffle = False ,
6264 num_workers = cfg ["num_workers" ],
6365 pin_memory = True ,
6466 persistent_workers = True
@@ -77,17 +79,18 @@ def main():
7779 T_max = cfg ["epochs" ]
7880 )
7981
82+ scaler = torch .amp .GradScaler ("cuda" )
8083 l1_loss = torch .nn .L1Loss ()
8184
82- scaler = torch .amp .GradScaler ("cuda" )
83-
84- os .makedirs ("outputs/checkpoints" , exist_ok = True )
85- os .makedirs ("outputs/logs" , exist_ok = True )
86- os .makedirs ("outputs/plots" , exist_ok = True )
85+ out = cfg ["output_dir" ]
86+ os .makedirs (f"{ out } /checkpoints" , exist_ok = True )
87+ os .makedirs (f"{ out } /logs" , exist_ok = True )
88+ os .makedirs (f"{ out } /plots" , exist_ok = True )
8789
88- best_loss = float ("inf" )
90+ best_val_loss = float ("inf" )
8991
90- loss_history = []
92+ train_loss_history = []
93+ val_loss_history = []
9194
9295 print ("Starting training..." )
9396
@@ -101,9 +104,10 @@ def main():
101104
102105 for batch in pbar :
103106
104- x = batch ["input" ].to (device )
105- y = batch ["ct" ].to (device )
106-
107+ x = batch ["input" ].to (device )
108+ y = batch ["ct" ].to (device )
109+ mask = batch ["prediction_mask" ].bool ().to (device )
110+ y [~ mask ] = 0 # don't bother trying to predict the bed
107111 optimizer .zero_grad ()
108112
109113 with torch .amp .autocast ("cuda" ):
@@ -120,41 +124,60 @@ def main():
120124
121125 pbar .set_description (f"loss { loss .item ():.4f} " )
122126
123- avg_loss = epoch_loss / len (loader )
127+ avg_train_loss = epoch_loss / len (loader )
124128
125- print ( "Epoch" , epoch , "Loss" , avg_loss )
129+ scheduler . step ( )
126130
127- loss_history .append (avg_loss )
131+ # validation
132+ model .eval ()
133+ val_loss = 0
134+ with torch .no_grad ():
135+ for batch in val_loader :
136+ x = batch ["input" ].to (device )
137+ y = batch ["ct" ].to (device )
138+ mask = batch ["prediction_mask" ].bool ().to (device )
139+ y [~ mask ] = 0 # don't bother trying to predict the bed
128140
129- scheduler .step ()
141+ with torch .amp .autocast ("cuda" ):
142+ pred = model (x )
143+ loss = l1_loss (pred , y )
144+ val_loss += loss .item ()
145+ avg_val_loss = val_loss / len (val_loader )
146+
147+ print (f"Epoch { epoch } train={ avg_train_loss :.4f} val={ avg_val_loss :.4f} " )
148+
149+ train_loss_history .append (avg_train_loss )
150+ val_loss_history .append (avg_val_loss )
130151
131- # best checkpoint
132- if avg_loss < best_loss :
152+ # best checkpoint (by val)
153+ if avg_val_loss < best_val_loss :
133154
134- best_loss = avg_loss
155+ best_val_loss = avg_val_loss
135156
136157 torch .save (
137158 model .state_dict (),
138- "outputs /checkpoints/best_model.pth"
159+ f" { out } /checkpoints/best_model.pth"
139160 )
140161
141162 # last checkpoint
142163 torch .save (
143164 model .state_dict (),
144- "outputs /checkpoints/last_model.pth"
165+ f" { out } /checkpoints/last_model.pth"
145166 )
146167
147168 # log
148- with open ("outputs /logs/train_log.txt" ,"a" ) as f :
149- f .write (f"{ epoch } ,{ avg_loss } \n " )
169+ with open (f" { out } /logs/train_log.txt" , "a" ) as f :
170+ f .write (f"{ epoch } ,{ avg_train_loss } , { avg_val_loss } \n " )
150171
151172 # plot loss
152173 plt .figure ()
153- plt .plot (loss_history )
174+ plt .plot (train_loss_history , label = "train" )
175+ plt .plot (val_loss_history , label = "val" )
154176 plt .xlabel ("Epoch" )
155177 plt .ylabel ("Loss" )
156- plt .title ("Training Loss" )
157- plt .savefig ("outputs/plots/loss_curve.png" )
178+ plt .title ("Train / Val Loss" )
179+ plt .legend ()
180+ plt .savefig (f"{ out } /plots/loss_curve.png" )
158181 plt .close ()
159182
160183
0 commit comments