Skip to content

Commit f2198c9

Browse files
authored
Deterministic seeds in debug mode (#3589)
Add support for deterministic seed generation in debug mode when using pipeline seed. Currently seed generation differ between standard and debug mode, in the future we might standardize this. Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
1 parent cc26158 commit f2198c9

2 files changed

Lines changed: 43 additions & 0 deletions

File tree

dali/python/nvidia/dali/_debug_mode.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def __init__(self, exec_func, **kwargs):
134134
self._cur_subpipeline_id = -1
135135
self._exec_func = exec_func
136136

137+
import numpy as np
138+
seed = kwargs.get('seed', -1)
139+
if seed < 0:
140+
seed = np.random.randint(0, 2**32)
141+
self._seed_generator = np.random.default_rng(seed)
142+
137143
def __enter__(self):
138144
raise RuntimeError("Currently pipeline in debug mode works only with `pipeline_def` decorator."
139145
"Using `with` statement is not supported.")
@@ -259,6 +265,8 @@ def pipe():
259265
else:
260266
kwargs_preprocessed[key] = value
261267

268+
if 'seed' not in kwargs_preprocessed and op_wrapper.__name__ != '_arithm_op':
269+
kwargs_preprocessed['seed'] = self._seed_generator.integers(0, 2**32)
262270
res = op_wrapper(*inputs_preprocessed, **kwargs_preprocessed)
263271

264272
return tuple(res) if isinstance(res, list) else res

dali/test/python/test_pipeline_debug.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,38 @@ def _test_shape_pipeline(device):
340340
def test_shape_pipeline():
341341
for device in ['cpu', 'mixed']:
342342
yield _test_shape_pipeline, device
343+
344+
345+
@pipeline_def(batch_size=8, num_threads=3, device_id=0, seed=47, debug=True)
346+
def seed_pipeline():
347+
coin_flip = fn.random.coin_flip()
348+
normal = fn.random.normal()
349+
uniform = fn.random.uniform()
350+
batch_permutation = fn.batch_permutation()
351+
return coin_flip, normal, uniform, batch_permutation
352+
353+
354+
def test_seed_generation():
355+
pipe1 = seed_pipeline()
356+
pipe2 = seed_pipeline()
357+
compare_pipelines(pipe1, pipe2, 8, 10)
358+
359+
360+
@pipeline_def(batch_size=8, num_threads=3, device_id=0, seed=47, debug=True)
361+
def seed_rn50_pipeline_base():
362+
rng = fn.random.coin_flip(probability=0.5)
363+
jpegs, labels = fn.readers.file(
364+
file_root=file_root, shard_id=0, num_shards=2, random_shuffle=True)
365+
images = fn.decoders.image(jpegs, device='mixed', output_type=types.RGB)
366+
resized_images = fn.random_resized_crop(images, device="gpu", size=(224, 224))
367+
out_type = types.FLOAT16
368+
369+
output = fn.crop_mirror_normalize(resized_images.gpu(), mirror=rng, device="gpu", dtype=out_type, crop=(
370+
224, 224), mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
371+
return rng, jpegs, labels, images, resized_images, output
372+
373+
374+
def test_seed_generation_base():
375+
pipe1 = seed_rn50_pipeline_base()
376+
pipe2 = seed_rn50_pipeline_base()
377+
compare_pipelines(pipe1, pipe2, 8, 10)

0 commit comments

Comments
 (0)