@@ -187,12 +187,15 @@ def es_pipeline_standard():
187187
188188def test_external_source_debug_sample_pipeline ():
189189 n_iters = 10
190+ prefetch_queue_depth = 2
190191 pipe_load = load_images_pipeline ()
191- pipe_standard = es_pipeline_standard ()
192- pipe_debug = es_pipeline_debug ()
192+ pipe_standard = es_pipeline_standard (prefetch_queue_depth = prefetch_queue_depth )
193+ pipe_debug = es_pipeline_debug (prefetch_queue_depth = prefetch_queue_depth )
193194 pipe_load .build ()
194195 pipe_debug .build ()
195- for _ in range (n_iters ):
196+ # Call feed_input `prefetch_queue_depth` extra times to avoid issues with
197+ # missing batches near the end of the epoch caused by prefetching
198+ for _ in range (n_iters + prefetch_queue_depth ):
196199 images , labels = pipe_load .run ()
197200 pipe_debug .feed_input ('input' , [np .array (t ) for t in images ])
198201 pipe_debug .feed_input ('labels' , np .array (labels .as_tensor ()))
@@ -209,12 +212,15 @@ def es_pipeline(source, batch):
209212
210213def _test_external_source_debug (source , batch ):
211214 n_iters = 8
212- pipe_debug = es_pipeline (source , batch , debug = True )
213- pipe_standard = es_pipeline (source , batch )
215+ prefetch_queue_depth = 2
216+ pipe_debug = es_pipeline (source , batch , prefetch_queue_depth = prefetch_queue_depth , debug = True )
217+ pipe_standard = es_pipeline (source , batch , prefetch_queue_depth = prefetch_queue_depth )
214218 pipe_debug .build ()
215219 pipe_standard .build ()
216220 if source is None :
217- for _ in range (n_iters ):
221+ # Call feed_input `prefetch_queue_depth` extra times to avoid issues with
222+ # missing batches near the end of the epoch caused by prefetching
223+ for _ in range (n_iters + prefetch_queue_depth ):
218224 x = np .random .rand (8 , 5 , 1 )
219225 pipe_debug .feed_input ('input' , x )
220226 pipe_standard .feed_input ('input' , x )
0 commit comments