
Commit bec2380

Added batch size as a class variable
1 parent 6e06796 commit bec2380

1 file changed

Lines changed: 8 additions & 8 deletions
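
In short, the commit caches the mini-batch size, previously recomputed as 2**self._batch_expo at every call site, in a new class variable _Np_batch that SetBatchExpo updates together with _batch_expo. A minimal standalone sketch of that pattern (the class name and example values below are illustrative, not the repository's MLPTrainer):

class BatchSizeCache:
    # Mini-batch size exponent and the batch size derived from it.
    _batch_expo: int = None
    _Np_batch: int = None

    def SetBatchExpo(self, batch_expo: int = 6):
        if batch_expo < 0:
            raise Exception("Mini-batch exponent should be higher than zero.")
        self._batch_expo = batch_expo
        # Cache the batch size so other methods can use self._Np_batch directly.
        self._Np_batch = int(2**self._batch_expo)
        return

trainer = BatchSizeCache()
trainer.SetBatchExpo(6)
print(trainer._Np_batch)   # 64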

File tree

Manifold_Generation/MLP/Trainer_Base.py

@@ -86,7 +86,7 @@ class MLPTrainer:
     # Hardware info:
     _kind_device:str = "CPU" # Device type used to train (CPU or GPU)
     _device_index:int = 0 # Device index (core index or GPU card index)
-
+    _model_index:int = 0
     # MLP input (controlling) variables.
     _controlling_vars:list[str] = ["Density",
                                    "Energy"]
@@ -99,6 +99,7 @@ class MLPTrainer:
     _Np_train:int = None
     _Np_test:int = None
     _Np_val:int = None
+    _Np_batch:int = None
     _X_train_norm:np.ndarray = None
     _Y_train_norm:np.ndarray = None
     _X_test_norm:np.ndarray = None
@@ -337,6 +338,7 @@ def SetBatchExpo(self, batch_expo:int=DefaultProperties.batch_size_exponent):
         if batch_expo < 0:
             raise Exception("Mini-batch exponent should be higher than zero.")
         self._batch_expo = batch_expo
+        self._Np_batch = int(2**self._batch_expo)
         return
 
     def SetHiddenLayers(self, layers_input:list[int]=DefaultProperties.hidden_layer_architecture):
@@ -379,7 +381,7 @@ def SetDecaySteps(self):
         """Set the number of steps in the exponential decay algorithm. The number of steps scale are proportioned based on the number of epochs,
         and training data size and mini batch size.
         """
-        self._decay_steps = int(1e-3 * self._n_epochs * self._Np_train / (2**self._batch_expo))
+        self._decay_steps = int(1e-3 * self._n_epochs * self._Np_train / self._Np_batch)
         return
 
     def RestartTraining(self):
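
The replacement above only changes where the batch size is read from; the decay-steps formula itself is unchanged. A worked example with illustrative values (not the repository's defaults):

n_epochs = 1000          # illustrative value
Np_train = 100_000       # illustrative value
Np_batch = 2**6          # batch_expo = 6, so batch size 64
decay_steps = int(1e-3 * n_epochs * Np_train / Np_batch)
print(decay_steps)       # int(1562.5) = 1562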
@@ -757,7 +759,7 @@ def Train_MLP(self):
                                        verbose=self._verbose)
         self.history = self._model.fit(self._X_train_norm, self._Y_train_norm, \
                                        epochs=self._n_epochs, \
-                                       batch_size=2**self._batch_expo,\
+                                       batch_size=self._Np_batch,\
                                        verbose=self._verbose, \
                                        validation_data=(self._X_val_norm, self._Y_val_norm), \
                                        shuffle=True,\
@@ -1041,7 +1043,7 @@ def PrepareValidationHistory(self):
         return
 
     def SetTrainBatches(self):
-        train_batches = tf.data.Dataset.from_tensor_slices((self._X_train_norm, self._Y_train_norm)).batch(2**self._batch_expo)
+        train_batches = tf.data.Dataset.from_tensor_slices((self._X_train_norm, self._Y_train_norm)).batch(self._Np_batch)
         return train_batches
 
 
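The SetTrainBatches change above likewise only swaps the batch-size expression; the tf.data pipeline is unchanged. A self-contained sketch of the same call pattern on dummy arrays (shapes and sizes are illustrative):

import numpy as np
import tensorflow as tf

X_train_norm = np.random.rand(256, 2).astype(np.float32)   # stand-in for normalized inputs
Y_train_norm = np.random.rand(256, 1).astype(np.float32)   # stand-in for normalized labels
Np_batch = 2**5                                             # cached batch size (batch_expo = 5)

train_batches = tf.data.Dataset.from_tensor_slices((X_train_norm, Y_train_norm)).batch(Np_batch)
for x_batch, y_batch in train_batches.take(1):
    print(x_batch.shape, y_batch.shape)                     # (32, 2) (32, 1)
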
@@ -1055,7 +1057,6 @@ def LoopEpochs(self):
             self.LoopBatches(train_batches=train_batches)
 
             val_loss = self.ValidationLoss()
-
             if (self._i_epoch + 1) % self.callback_every == 0:
                 self.TestLoss()
                 self.CustomCallback()
@@ -1288,7 +1289,6 @@ def SetTrainBatches(self):
         train_batches_domain = tf.data.Dataset.from_tensor_slices((self._X_train_norm, self._Y_state_train_norm)).batch(2**self._batch_expo)
         domain_batches_list = [b for b in train_batches_domain]
 
-        batch_size_train = 2**self._batch_expo
 
         if self._enable_boundary_loss:
             # Collect projection array data.
@@ -1301,7 +1301,7 @@ def SetTrainBatches(self):
             X_boundary_tf = tf.constant(self._X_boundary_norm, dtype=self._dt)
 
             # Forumulate batches.
-            batches_concat = tf.data.Dataset.from_tensor_slices((X_boundary_tf, p_concatenated, Y_target_concatenated)).batch(batch_size_train)
+            batches_concat = tf.data.Dataset.from_tensor_slices((X_boundary_tf, p_concatenated, Y_target_concatenated)).batch(self._Np_batch)
             batches_concat_list = [b for b in batches_concat]
 
             # Re-size boundary data batches to that of the domain batches such that both data can be evaluated simultaneously during training.
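
The re-sizing logic referenced in the comment above lies outside this hunk. Purely as an assumption about what such re-sizing could look like, one way to make a shorter list of boundary batches line up one-to-one with the domain batches is to cycle it (a hypothetical sketch, not the repository's implementation):

from itertools import cycle, islice

domain_batches_list = ["d0", "d1", "d2", "d3", "d4"]    # stand-ins for domain batches
batches_concat_list = ["b0", "b1"]                      # stand-ins for boundary batches
# Repeat the boundary batches cyclically so there is one per domain batch.
boundary_resized = list(islice(cycle(batches_concat_list), len(domain_batches_list)))
print(boundary_resized)                                 # ['b0', 'b1', 'b0', 'b1', 'b0']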
@@ -1652,7 +1652,7 @@ def __init__(self, Config_in:Config):
         self._Config=Config_in
         self.alpha_expo = self._Config.GetAlphaExpo()
         self.lr_decay = self._Config.GetLRDecay()
-        self.batch_expo = self._Config.GetBatchExpo()
+        self.SetBatchExpo(self._Config.GetBatchExpo())
         self.activation_function = self._Config.GetActivationFunction()
         self.architecture = []
         for n in self._Config.GetHiddenLayerArchitecture():
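
Design note: routing the configured value through SetBatchExpo in the constructor, rather than assigning the exponent directly, means the negative-exponent check runs and _Np_batch is initialized consistently with _batch_expo from the start.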
