Skip to content

Commit c8a87ba

Browse files
author
Lee Newberg
committed
ENH: Limit optimal batch size to data size. Optimize timing too.
1 parent 02293d5 commit c8a87ba

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,11 @@ def findOptimalBatchSize(self, model, ds, training) -> int:
110110
if not training and self.prediction_optimal_batchsize is not None:
111111
return self.prediction_optimal_batchsize
112112
# Find an optimal batch_size
113-
maximum_batchSize: int = len(ds)
113+
maximum_batchSize: int = 2 * len(ds) - 1
114114
batchSize: int = 2
115115
# We are using a value greater than 0.0 for add_seconds so that small imprecise
116116
# timings for small batch sizes don't accidentally trip the time check.
117-
add_seconds: float = 0.5
117+
add_seconds: float = 0.05
118118
previous_time: float = 1e100
119119
while batchSize <= maximum_batchSize:
120120
try:
@@ -134,7 +134,7 @@ def findOptimalBatchSize(self, model, ds, training) -> int:
134134
batchSize //= 2
135135
return self.cacheOptimalBatchSize(batchSize, model, training)
136136

137-
def cacheOptimalBatchSize(self, batchSize: int, model: torch.nn.Module, training: bool) -> int:
137+
def cacheOptimalBatchSize(self, batchSize, model, training) -> int:
138138
if training:
139139
self.training_optimal_batchsize = batchSize
140140
else:

superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,11 @@ def findOptimalBatchSize(
369369
if not training and self.prediction_optimal_batchsize is not None:
370370
return self.prediction_optimal_batchsize
371371
# Find an optimal batch_size
372-
maximum_batchSize: int = ds.tensors[0].shape[0]
372+
maximum_batchSize: int = 2 * ds.tensors[0].shape[0] - 1
373373
batchSize: int = 2
374374
# We are using a value greater than 0.0 for add_seconds so that small imprecise
375375
# timings for small batch sizes don't accidentally trip the time check.
376-
add_seconds: float = 1.0
376+
add_seconds: float = 0.05
377377
previous_time: float = 1e100
378378
while batchSize <= maximum_batchSize:
379379
try:

0 commit comments

Comments
 (0)