Skip to content

Commit d2c62c5

Browse files
authored
Fix ES debug mode test failing with missing batch (#3712)
* Fix ES debug mode test failing with missing batch Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
1 parent 83b3915 commit d2c62c5

1 file changed

Lines changed: 12 additions & 6 deletions

File tree

dali/test/python/test_pipeline_debug.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,15 @@ def es_pipeline_standard():
187187

188188
def test_external_source_debug_sample_pipeline():
189189
n_iters = 10
190+
prefetch_queue_depth = 2
190191
pipe_load = load_images_pipeline()
191-
pipe_standard = es_pipeline_standard()
192-
pipe_debug = es_pipeline_debug()
192+
pipe_standard = es_pipeline_standard(prefetch_queue_depth=prefetch_queue_depth)
193+
pipe_debug = es_pipeline_debug(prefetch_queue_depth=prefetch_queue_depth)
193194
pipe_load.build()
194195
pipe_debug.build()
195-
for _ in range(n_iters):
196+
# Call feed_input `prefetch_queue_depth` extra times to avoid issues with
197+
# missing batches near the end of the epoch caused by prefetching
198+
for _ in range(n_iters + prefetch_queue_depth):
196199
images, labels = pipe_load.run()
197200
pipe_debug.feed_input('input', [np.array(t) for t in images])
198201
pipe_debug.feed_input('labels', np.array(labels.as_tensor()))
@@ -209,12 +212,15 @@ def es_pipeline(source, batch):
209212

210213
def _test_external_source_debug(source, batch):
211214
n_iters = 8
212-
pipe_debug = es_pipeline(source, batch, debug=True)
213-
pipe_standard = es_pipeline(source, batch)
215+
prefetch_queue_depth = 2
216+
pipe_debug = es_pipeline(source, batch, prefetch_queue_depth=prefetch_queue_depth, debug=True)
217+
pipe_standard = es_pipeline(source, batch, prefetch_queue_depth=prefetch_queue_depth)
214218
pipe_debug.build()
215219
pipe_standard.build()
216220
if source is None:
217-
for _ in range(n_iters):
221+
# Call feed_input `prefetch_queue_depth` extra times to avoid issues with
222+
# missing batches near the end of the epoch caused by prefetching
223+
for _ in range(n_iters + prefetch_queue_depth):
218224
x = np.random.rand(8, 5, 1)
219225
pipe_debug.feed_input('input', x)
220226
pipe_standard.feed_input('input', x)

0 commit comments

Comments
 (0)