@@ -86,7 +86,7 @@ class MLPTrainer:
8686 # Hardware info:
8787 _kind_device :str = "CPU" # Device type used to train (CPU or GPU)
8888 _device_index :int = 0 # Device index (core index or GPU card index)
89-
89+ _model_index : int = 0
9090 # MLP input (controlling) variables.
9191 _controlling_vars :list [str ] = ["Density" ,
9292 "Energy" ]
@@ -99,6 +99,7 @@ class MLPTrainer:
9999 _Np_train :int = None
100100 _Np_test :int = None
101101 _Np_val :int = None
102+ _Np_batch :int = None
102103 _X_train_norm :np .ndarray = None
103104 _Y_train_norm :np .ndarray = None
104105 _X_test_norm :np .ndarray = None
@@ -337,6 +338,7 @@ def SetBatchExpo(self, batch_expo:int=DefaultProperties.batch_size_exponent):
337338 if batch_expo < 0 :
338339 raise Exception ("Mini-batch exponent should be higher than zero." )
339340 self ._batch_expo = batch_expo
341+ self ._Np_batch = int (2 ** self ._batch_expo )
340342 return
341343
342344 def SetHiddenLayers (self , layers_input :list [int ]= DefaultProperties .hidden_layer_architecture ):
@@ -379,7 +381,7 @@ def SetDecaySteps(self):
379381 """Set the number of steps in the exponential decay algorithm. The number of steps scale are proportioned based on the number of epochs,
380382 and training data size and mini batch size.
381383 """
382- self ._decay_steps = int (1e-3 * self ._n_epochs * self ._Np_train / ( 2 ** self ._batch_expo ) )
384+ self ._decay_steps = int (1e-3 * self ._n_epochs * self ._Np_train / self ._Np_batch )
383385 return
384386
385387 def RestartTraining (self ):
@@ -757,7 +759,7 @@ def Train_MLP(self):
757759 verbose = self ._verbose )
758760 self .history = self ._model .fit (self ._X_train_norm , self ._Y_train_norm , \
759761 epochs = self ._n_epochs , \
760- batch_size = 2 ** self ._batch_expo ,\
762+ batch_size = self ._Np_batch ,\
761763 verbose = self ._verbose , \
762764 validation_data = (self ._X_val_norm , self ._Y_val_norm ), \
763765 shuffle = True ,\
@@ -1041,7 +1043,7 @@ def PrepareValidationHistory(self):
10411043 return
10421044
10431045 def SetTrainBatches (self ):
1044- train_batches = tf .data .Dataset .from_tensor_slices ((self ._X_train_norm , self ._Y_train_norm )).batch (2 ** self ._batch_expo )
1046+ train_batches = tf .data .Dataset .from_tensor_slices ((self ._X_train_norm , self ._Y_train_norm )).batch (self ._Np_batch )
10451047 return train_batches
10461048
10471049
@@ -1055,7 +1057,6 @@ def LoopEpochs(self):
10551057 self .LoopBatches (train_batches = train_batches )
10561058
10571059 val_loss = self .ValidationLoss ()
1058-
10591060 if (self ._i_epoch + 1 ) % self .callback_every == 0 :
10601061 self .TestLoss ()
10611062 self .CustomCallback ()
@@ -1288,7 +1289,6 @@ def SetTrainBatches(self):
12881289 train_batches_domain = tf .data .Dataset .from_tensor_slices ((self ._X_train_norm , self ._Y_state_train_norm )).batch (2 ** self ._batch_expo )
12891290 domain_batches_list = [b for b in train_batches_domain ]
12901291
1291- batch_size_train = 2 ** self ._batch_expo
12921292
12931293 if self ._enable_boundary_loss :
12941294 # Collect projection array data.
@@ -1301,7 +1301,7 @@ def SetTrainBatches(self):
13011301 X_boundary_tf = tf .constant (self ._X_boundary_norm , dtype = self ._dt )
13021302
13031303 # Forumulate batches.
1304- batches_concat = tf .data .Dataset .from_tensor_slices ((X_boundary_tf , p_concatenated , Y_target_concatenated )).batch (batch_size_train )
1304+ batches_concat = tf .data .Dataset .from_tensor_slices ((X_boundary_tf , p_concatenated , Y_target_concatenated )).batch (self . _Np_batch )
13051305 batches_concat_list = [b for b in batches_concat ]
13061306
13071307 # Re-size boundary data batches to that of the domain batches such that both data can be evaluated simultaneously during training.
@@ -1652,7 +1652,7 @@ def __init__(self, Config_in:Config):
16521652 self ._Config = Config_in
16531653 self .alpha_expo = self ._Config .GetAlphaExpo ()
16541654 self .lr_decay = self ._Config .GetLRDecay ()
1655- self .batch_expo = self ._Config .GetBatchExpo ()
1655+ self .SetBatchExpo ( self ._Config .GetBatchExpo () )
16561656 self .activation_function = self ._Config .GetActivationFunction ()
16571657 self .architecture = []
16581658 for n in self ._Config .GetHiddenLayerArchitecture ():
0 commit comments