Skip to content

Commit 61627c4

Browse files
authored
Merge pull request #25 from Leengit/optimal_batchsize
Compute and use optimal batch size
2 parents 0102358 + c8a87ba commit 61627c4

3 files changed

Lines changed: 133 additions & 20 deletions

File tree

superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@
105105
<longflag>batchsize</longflag>
106106
<label>Batch Size</label>
107107
<channel>input</channel>
108-
<default>32</default>
109-
<description>Training batch size</description>
108+
<default>-1</default>
109+
<description>Training batch size (-1 indicates optimal batch size)</description>
110110
</integer>
111111
<integer>
112112
<name>epochs</name>

superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import os
2+
import time
3+
from typing import Optional
24

5+
import h5py
36
import tensorflow as tf
47
from SuperpixelClassificationBase import SuperpixelClassificationBase
58

@@ -29,18 +32,13 @@ def on_predict_batch_end(self, batch, logs=None):
2932

3033

3134
class SuperpixelClassificationTensorflow(SuperpixelClassificationBase):
35+
def __init__(self):
    """Create the classifier with no cached batch sizes.

    Both attributes start as ``None``; ``findOptimalBatchSize`` treats
    ``None`` as "not measured yet" and ``cacheOptimalBatchSize`` fills
    them in.
    """
    # NOTE(review): the base class initializer is not invoked here --
    # confirm SuperpixelClassificationBase needs no initialization.
    # Cached batch size for prediction (None until measured).
    self.prediction_optimal_batchsize: Optional[int] = None
    # Cached batch size for training (None until measured).
    self.training_optimal_batchsize: Optional[int] = None
38+
3239
def trainModelDetails(self, record, annotationName, batchSize, epochs, itemsAndAnnot, prog,
3340
tempdir, trainingSplit):
3441
# print(f'Tensorflow trainModelDetails(batchSize={batchSize}, ...)')
35-
# generate split
36-
full_ds = tf.data.Dataset.from_tensor_slices((record['ds'], record['labelds']))
37-
full_ds = full_ds.shuffle(1000) # add seed=123 ?
38-
count = len(full_ds)
39-
train_size = int(count * trainingSplit)
40-
train_ds = full_ds.take(train_size).batch(batchSize)
41-
val_ds = full_ds.skip(train_size).batch(batchSize)
42-
print(batchSize, train_ds, val_ds)
43-
prog.progress(0.2)
4442
# make model
4543
num_classes = len(record['labels'])
4644
model = tf.keras.Sequential([
@@ -55,10 +53,22 @@ def trainModelDetails(self, record, annotationName, batchSize, epochs, itemsAndA
5553
# tf.keras.layers.Dropout(0.2),
5654
tf.keras.layers.Dense(128, activation='relu'),
5755
tf.keras.layers.Dense(num_classes)])
58-
prog.progress(0.4)
56+
prog.progress(0.2)
5957
model.compile(optimizer='adam',
6058
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
6159
metrics=['accuracy'])
60+
prog.progress(0.7)
61+
# generate split
62+
full_ds = tf.data.Dataset.from_tensor_slices((record['ds'], record['labelds']))
63+
full_ds = full_ds.shuffle(1000) # add seed=123 ?
64+
count = len(full_ds)
65+
train_size = int(count * trainingSplit)
66+
if batchSize < 1:
67+
batchSize = self.findOptimalBatchSize(model, full_ds, training=True)
68+
print(f'Optimal batch size for training = {batchSize}')
69+
train_ds = full_ds.take(train_size).batch(batchSize)
70+
val_ds = full_ds.skip(train_size).batch(batchSize)
71+
print(batchSize, train_ds, val_ds)
6272
prog.progress(0.9)
6373
prog.progress(1)
6474
prog.message('Training model')
@@ -75,8 +85,15 @@ def trainModelDetails(self, record, annotationName, batchSize, epochs, itemsAndA
7585
self.saveModel(model, modelPath)
7686
return history, modelPath
7787

78-
def predictLabelsForItemDetails(self, batchSize, ds, item, model, prog):
88+
def predictLabelsForItemDetails(
89+
self, batchSize, ds: h5py._hl.dataset.Dataset, item, model, prog,
90+
):
7991
# print(f'Tensorflow predictLabelsForItemDetails(batchSize={batchSize}, ...)')
92+
if batchSize < 1:
93+
batchSize = self.findOptimalBatchSize(
94+
model, tf.data.Dataset.from_tensor_slices(ds), training=False,
95+
)
96+
print(f'Optimal batch size for prediction = {batchSize}')
8097
predictions = model.predict(
8198
ds,
8299
batch_size=batchSize,
@@ -87,6 +104,43 @@ def predictLabelsForItemDetails(self, batchSize, ds, item, model, prog):
87104
catWeights = tf.nn.softmax(predictions)
88105
return catWeights, predictions
89106

107+
def findOptimalBatchSize(self, model, ds, training) -> int:
    """Return a batch size found by timing ``model.predict`` on growing batches.

    If a value was already cached for the requested mode it is returned
    immediately.  Otherwise the batch size starts at 2 and is doubled while
    it does not exceed ``2 * len(ds) - 1``.  For each candidate, the first
    ``batchSize`` elements of ``ds`` are run through ``model.predict`` as a
    single batch and the call is timed.  The search stops when

    * a doubling made the call take more than twice as long as the previous
      one (plus a small fixed allowance), or
    * ``tf.errors.OpError`` is raised (presumably resource exhaustion), or
    * the upper bound is exceeded.

    In every case the result is half of the candidate at which the search
    stopped (never less than 1); it is stored via
    ``cacheOptimalBatchSize`` and returned.

    Note that the probe uses ``model.predict`` even when ``training`` is
    true; ``training`` only selects which cache slot is read and written.

    :param model: object providing ``predict(dataset, batch_size=...)``.
    :param ds: dataset supporting ``len()``, ``take()`` and ``batch()``.
    :param training: truthy to use the training cache, falsy for prediction.
    :returns: the selected batch size.
    """
    cached = (
        self.training_optimal_batchsize if training
        else self.prediction_optimal_batchsize
    )
    if cached is not None:
        return cached

    maximum_batchSize: int = 2 * len(ds) - 1
    batchSize: int = 2
    # A value greater than 0.0 keeps small, imprecise timings for small
    # batch sizes from accidentally tripping the time check.
    add_seconds: float = 0.05
    # Infinity guarantees the first measurement can never trip the check.
    previous_time: float = float('inf')
    while batchSize <= maximum_batchSize:
        try:
            small_ds = ds.take(batchSize).batch(batchSize)
            # perf_counter is monotonic and high resolution, unlike the
            # wall clock, so system clock adjustments cannot skew the probe.
            start_time = time.perf_counter()
            model.predict(small_ds, batch_size=batchSize)
            elapsed_time = time.perf_counter() - start_time
        except tf.errors.OpError:
            # This candidate could not be run; fall back to the previous one.
            break
        if elapsed_time > 2 * previous_time + add_seconds:
            # Doubling the batch more than doubled the time: no longer a win.
            break
        previous_time = elapsed_time
        batchSize *= 2
    # Whichever way the loop ended, the current candidate is one doubling
    # too far (it failed, was too slow, or exceeded the bound).
    batchSize //= 2
    return self.cacheOptimalBatchSize(batchSize, model, training)
136+
137+
def cacheOptimalBatchSize(self, batchSize, model, training) -> int:
    """Store ``batchSize`` in the cache slot for the given mode and return it.

    :param batchSize: value to remember.
    :param model: accepted for signature symmetry; not consulted here.
    :param training: truthy selects ``training_optimal_batchsize``, falsy
        selects ``prediction_optimal_batchsize``.
    :returns: ``batchSize`` unchanged.
    """
    slot = 'training_optimal_batchsize' if training else 'prediction_optimal_batchsize'
    setattr(self, slot, batchSize)
    return batchSize
143+
90144
def loadModel(self, modelPath):
    """Load and return the Keras model stored at ``modelPath``."""
    restored = tf.keras.models.load_model(modelPath)
    return restored
92146

superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
2-
from typing import Any, Sequence
2+
import time
3+
from typing import Any, Optional, Sequence
34

45
import batchbald_redux as bbald
56
import batchbald_redux.consistent_mc_dropout
@@ -133,6 +134,10 @@ def mc_forward_impl(self, input: torch.Tensor) -> torch.Tensor:
133134

134135

135136
class SuperpixelClassificationTorch(SuperpixelClassificationBase):
137+
def __init__(self):
    """Create the classifier with no cached batch sizes.

    Both attributes start as ``None``; ``findOptimalBatchSize`` treats
    ``None`` as "not measured yet" and ``cacheOptimalBatchSize`` fills
    them in.
    """
    # NOTE(review): the base class initializer is not invoked here --
    # confirm SuperpixelClassificationBase needs no initialization.
    # Cached batch size for prediction (None until measured).
    self.prediction_optimal_batchsize: Optional[int] = None
    # Cached batch size for training (None until measured).
    self.training_optimal_batchsize: Optional[int] = None
140+
136141
def trainModelDetails(
137142
self,
138143
record,
@@ -144,6 +149,11 @@ def trainModelDetails(
144149
tempdir: str,
145150
trainingSplit: float,
146151
):
152+
# make model
153+
num_classes: int = len(record['labels'])
154+
model: torch.nn.Module = _BayesianTorchModel(num_classes)
155+
model.to(model.device)
156+
147157
# print(f'Torch trainModelDetails(batchSize={batchSize}, ...)')
148158
# Make a data set and a data loader for each of training and validation
149159
count: int = len(record['ds'])
@@ -167,17 +177,15 @@ def trainModelDetails(
167177
val_arg2 = torch.from_numpy(record['labelds'][val_indices])
168178
train_ds = torch.utils.data.TensorDataset(train_arg1, train_arg2)
169179
val_ds = torch.utils.data.TensorDataset(val_arg1, val_arg2)
180+
if batchSize < 1:
181+
batchSize = self.findOptimalBatchSize(model, train_ds, training=True)
182+
print(f'Optimal batch size for training (device = {model.device}) = {batchSize}')
170183
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batchSize)
171184
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batchSize)
172185
prog.progress(0.2)
173186

174-
# make model
175-
num_classes: int = len(record['labels'])
176-
model: torch.nn.Module = _BayesianTorchModel(num_classes)
177-
model.to(model.device)
178187
prog.message('Training model')
179188
prog.progress(0)
180-
181189
history = self.fitModel(
182190
model, train_dl, val_dl, epochs, callbacks=[_LogTorchProgress(prog, epochs)],
183191
)
@@ -300,7 +308,7 @@ def fitModel(
300308
return history
301309

302310
def predictLabelsForItemDetails(
303-
self, batchSize: int, ds_h5, item, model: torch.nn.Module, prog: ProgressHelper,
311+
self, batchSize: int, ds_h5, item, model: torch.nn.Module, prog: ProgressHelper,
304312
):
305313
# print(f'Torch predictLabelsForItemDetails(batchSize={batchSize}, ...)')
306314
num_superpixels: int = ds_h5.shape[0]
@@ -323,6 +331,9 @@ def predictLabelsForItemDetails(
323331
ds: torch.utils.data.TensorDataset = torch.utils.data.TensorDataset(
324332
torch.from_numpy(np.array(ds_h5).transpose((0, 3, 2, 1))),
325333
)
334+
if batchSize < 1:
335+
batchSize = self.findOptimalBatchSize(model, ds, training=False)
336+
print(f'Optimal batch size for prediction (device = {model.device}) = {batchSize}')
326337
dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(ds, batch_size=batchSize)
327338
predictions: NDArray[np.float_] = np.zeros((num_superpixels, bayesian_samples, num_classes))
328339
catWeights: NDArray[np.float_] = np.zeros((num_superpixels, bayesian_samples, num_classes))
@@ -350,6 +361,54 @@ def predictLabelsForItemDetails(
350361
# scale to units
351362
return catWeights, predictions
352363

364+
def findOptimalBatchSize(
    self, model: torch.nn.Module, ds: torch.utils.data.TensorDataset, training: bool,
) -> int:
    """Return a batch size found by timing forward passes on growing batches.

    If a value was already cached for the requested mode it is returned
    immediately.  Otherwise the batch size starts at 2 and is doubled while
    it does not exceed ``2 * ds.tensors[0].shape[0] - 1``.  For each
    candidate, the first batch of a ``DataLoader`` over ``ds`` is moved to
    ``model.device`` and passed through
    ``model(inputs, model.bayesian_samples)`` under ``torch.no_grad()``,
    and the step is timed.  The search stops when

    * a doubling made the step take more than twice as long as the previous
      one (plus a small fixed allowance), or
    * a ``RuntimeError`` whose message contains ``'out of memory'`` is
      raised, or
    * the upper bound is exceeded.

    In every case the result is half of the candidate at which the search
    stopped (never less than 1); it is stored via
    ``cacheOptimalBatchSize`` and returned.

    The probe is a no-grad forward pass even when ``training`` is true;
    ``training`` only selects which cache slot is read and written.

    :param model: module exposing ``device`` and ``bayesian_samples``.
    :param ds: dataset whose first tensor holds the model inputs.
    :param training: True to use the training cache, False for prediction.
    :returns: the selected batch size.
    :raises RuntimeError: any runtime error other than out-of-memory.
    """
    cached = (
        self.training_optimal_batchsize if training
        else self.prediction_optimal_batchsize
    )
    if cached is not None:
        return cached

    maximum_batchSize: int = 2 * ds.tensors[0].shape[0] - 1
    batchSize: int = 2
    # A value greater than 0.0 keeps small, imprecise timings for small
    # batch sizes from accidentally tripping the time check.
    add_seconds: float = 0.05
    # Infinity guarantees the first measurement can never trip the check.
    previous_time: float = float('inf')
    while batchSize <= maximum_batchSize:
        try:
            dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
                ds, batch_size=batchSize,
            )
            # perf_counter is monotonic and high resolution, unlike the
            # wall clock, so system clock adjustments cannot skew the probe.
            start_time = time.perf_counter()
            with torch.no_grad():
                # Tell torch that we will be doing predictions.
                # NOTE(review): the model is left in eval mode afterwards;
                # callers are assumed to set the mode they need.
                model.eval()
                data: Sequence[torch.Tensor] = next(iter(dl))
                inputs: torch.Tensor = data[0].to(model.device)
                model(inputs, model.bayesian_samples)
                if inputs.is_cuda:
                    # CUDA kernels are launched asynchronously; wait for
                    # them so the timing covers the actual computation.
                    torch.cuda.synchronize(inputs.device)
            elapsed_time = time.perf_counter() - start_time
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Not a capacity problem: propagate with original traceback.
                raise
            # This candidate does not fit; fall back to the previous one.
            break
        if elapsed_time > 2 * previous_time + add_seconds:
            # Doubling the batch more than doubled the time: no longer a win.
            break
        previous_time = elapsed_time
        batchSize *= 2
    # Whichever way the loop ended, the current candidate is one doubling
    # too far (it failed, was too slow, or exceeded the bound).
    batchSize //= 2
    return self.cacheOptimalBatchSize(batchSize, model, training)
404+
405+
def cacheOptimalBatchSize(self, batchSize: int, model: torch.nn.Module, training: bool) -> int:
    """Store ``batchSize`` in the cache slot for the given mode and return it.

    :param batchSize: value to remember.
    :param model: accepted for signature symmetry; not consulted here.
    :param training: truthy selects ``training_optimal_batchsize``, falsy
        selects ``prediction_optimal_batchsize``.
    :returns: ``batchSize`` unchanged.
    """
    slot = 'training_optimal_batchsize' if training else 'prediction_optimal_batchsize'
    setattr(self, slot, batchSize)
    return batchSize
411+
353412
def loadModel(self, modelPath):
354413
model = torch.load(modelPath)
355414
model.eval()

0 commit comments

Comments
 (0)