11import os
2- from typing import Any , Sequence
2+ import time
3+ from typing import Any , Optional , Sequence
34
45import batchbald_redux as bbald
56import batchbald_redux .consistent_mc_dropout
@@ -133,6 +134,10 @@ def mc_forward_impl(self, input: torch.Tensor) -> torch.Tensor:
133134
134135
135136class SuperpixelClassificationTorch (SuperpixelClassificationBase ):
def __init__(self):
    """Initialise the per-instance cache of auto-detected batch sizes.

    Both attributes start as ``None`` ("not measured yet") and are filled in
    the first time ``findOptimalBatchSize`` runs for the corresponding mode.
    """
    # Let the base class perform its own initialisation; overriding __init__
    # without chaining would silently skip whatever the base class sets up.
    super().__init__()
    # Cached result of the batch-size probe used when training.
    self.training_optimal_batchsize: Optional[int] = None
    # Cached result of the batch-size probe used when predicting.
    self.prediction_optimal_batchsize: Optional[int] = None
140+
136141 def trainModelDetails (
137142 self ,
138143 record ,
@@ -144,6 +149,11 @@ def trainModelDetails(
144149 tempdir : str ,
145150 trainingSplit : float ,
146151 ):
152+ # make model
153+ num_classes : int = len (record ['labels' ])
154+ model : torch .nn .Module = _BayesianTorchModel (num_classes )
155+ model .to (model .device )
156+
147157 # print(f'Torch trainModelDetails(batchSize={batchSize}, ...)')
148158 # Make a data set and a data loader for each of training and validation
149159 count : int = len (record ['ds' ])
@@ -167,17 +177,15 @@ def trainModelDetails(
167177 val_arg2 = torch .from_numpy (record ['labelds' ][val_indices ])
168178 train_ds = torch .utils .data .TensorDataset (train_arg1 , train_arg2 )
169179 val_ds = torch .utils .data .TensorDataset (val_arg1 , val_arg2 )
180+ if batchSize < 1 :
181+ batchSize = self .findOptimalBatchSize (model , train_ds , training = True )
182+ print (f'Optimal batch size for training (device = { model .device } ) = { batchSize } ' )
170183 train_dl = torch .utils .data .DataLoader (train_ds , batch_size = batchSize )
171184 val_dl = torch .utils .data .DataLoader (val_ds , batch_size = batchSize )
172185 prog .progress (0.2 )
173186
174- # make model
175- num_classes : int = len (record ['labels' ])
176- model : torch .nn .Module = _BayesianTorchModel (num_classes )
177- model .to (model .device )
178187 prog .message ('Training model' )
179188 prog .progress (0 )
180-
181189 history = self .fitModel (
182190 model , train_dl , val_dl , epochs , callbacks = [_LogTorchProgress (prog , epochs )],
183191 )
@@ -300,7 +308,7 @@ def fitModel(
300308 return history
301309
302310 def predictLabelsForItemDetails (
303- self , batchSize : int , ds_h5 , item , model : torch .nn .Module , prog : ProgressHelper ,
311+ self , batchSize : int , ds_h5 , item , model : torch .nn .Module , prog : ProgressHelper ,
304312 ):
305313 # print(f'Torch predictLabelsForItemDetails(batchSize={batchSize}, ...)')
306314 num_superpixels : int = ds_h5 .shape [0 ]
@@ -323,6 +331,9 @@ def predictLabelsForItemDetails(
323331 ds : torch .utils .data .TensorDataset = torch .utils .data .TensorDataset (
324332 torch .from_numpy (np .array (ds_h5 ).transpose ((0 , 3 , 2 , 1 ))),
325333 )
334+ if batchSize < 1 :
335+ batchSize = self .findOptimalBatchSize (model , ds , training = False )
336+ print (f'Optimal batch size for prediction (device = { model .device } ) = { batchSize } ' )
326337 dl : torch .utils .data .DataLoader = torch .utils .data .DataLoader (ds , batch_size = batchSize )
327338 predictions : NDArray [np .float_ ] = np .zeros ((num_superpixels , bayesian_samples , num_classes ))
328339 catWeights : NDArray [np .float_ ] = np .zeros ((num_superpixels , bayesian_samples , num_classes ))
@@ -350,6 +361,54 @@ def predictLabelsForItemDetails(
350361 # scale to units
351362 return catWeights , predictions
352363
def findOptimalBatchSize(
    self, model: torch.nn.Module, ds: torch.utils.data.TensorDataset, training: bool,
) -> int:
    """Probe for an efficient power-of-two batch size and cache the answer.

    Starting at a batch size of 2, the batch size is doubled until one of the
    following happens:

    * a forward pass on one batch takes more than twice as long as the
      previous (half-sized) batch plus a small tolerance,
    * the device reports an out-of-memory error, or
    * the batch size would exceed ``2 * len(ds) - 1``.

    In every case the last batch size that did not trip the condition is
    stored via ``cacheOptimalBatchSize`` and returned.

    Note that the probe always runs the model in ``eval`` mode under
    ``torch.no_grad()``; ``training`` only selects which cached value is
    consulted and updated.

    :param model: the model to probe; it must expose ``device`` and
        ``bayesian_samples`` attributes and accept ``(inputs, samples)``.
    :param ds: data set whose first tensor supplies the probe inputs.
    :param training: True to use the training cache slot, False to use the
        prediction cache slot.
    :returns: the selected batch size.
    :raises RuntimeError: any runtime error other than out-of-memory raised
        while probing is propagated unchanged.
    """
    if training and self.training_optimal_batchsize is not None:
        return self.training_optimal_batchsize
    if not training and self.prediction_optimal_batchsize is not None:
        return self.prediction_optimal_batchsize
    # Find an optimal batch_size
    maximum_batchSize: int = 2 * ds.tensors[0].shape[0] - 1
    batchSize: int = 2
    # We are using a value greater than 0.0 for add_seconds so that small imprecise
    # timings for small batch sizes don't accidentally trip the time check.
    add_seconds: float = 0.05
    # Infinite so that the first measurement can never trip the time check.
    previous_time: float = float('inf')
    while batchSize <= maximum_batchSize:
        try:
            dl: torch.utils.data.DataLoader
            dl = torch.utils.data.DataLoader(ds, batch_size=batchSize)
            # perf_counter is monotonic and high resolution; time.time() can
            # jump when the system clock is adjusted.
            start_time = time.perf_counter()
            with torch.no_grad():
                model.eval()  # Tell torch that we will be doing predictions
                data: Sequence[torch.Tensor] = next(iter(dl))
                inputs: torch.Tensor = data[0]
                inputs = inputs.to(model.device)
                model(inputs, model.bayesian_samples)
                if inputs.is_cuda:
                    # CUDA kernels are launched asynchronously.  Without
                    # waiting for them the clock would only measure launch
                    # overhead, not the cost of the forward pass.
                    torch.cuda.synchronize(inputs.device)
            elapsed_time = time.perf_counter() - start_time
            if elapsed_time > 2 * previous_time + add_seconds:
                # Doubling the batch more than doubled the time: step back.
                batchSize //= 2
                return self.cacheOptimalBatchSize(batchSize, model, training)
            previous_time = elapsed_time
        except RuntimeError as e:
            if 'out of memory' in str(e):
                # This batch does not fit on the device: step back.
                batchSize //= 2
                return self.cacheOptimalBatchSize(batchSize, model, training)
            # Not a memory problem; re-raise with the original traceback.
            raise
        batchSize *= 2
    # Undo the last doubling; it was spurious
    batchSize //= 2
    return self.cacheOptimalBatchSize(batchSize, model, training)
404+
def cacheOptimalBatchSize(self, batchSize: int, model: torch.nn.Module, training: bool) -> int:
    """Remember ``batchSize`` for the given mode and hand it back.

    :param batchSize: the batch size to store.
    :param model: accepted for signature compatibility; not used here.
    :param training: True stores into the training slot, False stores into
        the prediction slot.
    :returns: ``batchSize``, unchanged, so callers can ``return`` the call.
    """
    slot = 'training_optimal_batchsize' if training else 'prediction_optimal_batchsize'
    setattr(self, slot, batchSize)
    return batchSize
411+
353412 def loadModel (self , modelPath ):
354413 model = torch .load (modelPath )
355414 model .eval ()
0 commit comments