2121
2222class Seeder :
2323 def __init__ (self , master_seed , ngpus , local_rank ):
24-
24+
25+ self .master_seed_was_none = master_seed is None
26+ if master_seed is None and local_rank == 0 :
27+ print ('INFO: master_seed is None in seeder.init, random master_seed will be generated (different one for each worker).' )
28+
2529 self .master_seed = (
2630 master_seed if master_seed is not None else generate_master_seed_randomly ()
2731 )
2832 self .seed_gen = SeedGen (self .master_seed , ngpus , local_rank )
2933 self ._ext_generators = []
34+ self ._ext_generators_shared = []
3035
31- def register_generator (self , gen ):
32- self ._ext_generators .append (gen )
36+ def register_generator (self , gen , shared = False ):
37+ if shared :
38+ if self .master_seed_was_none :
39+ raise Exception ('master_seed was None during seeder.init, seeds shared among workers cannot be used.' )
40+ self ._ext_generators_shared .append (gen )
41+ else :
42+ self ._ext_generators .append (gen )
3343
3444 def unregister_generator (self , gen ):
35- self ._ext_generators .remove (gen )
45+ try :
46+ self ._ext_generators .remove (gen )
47+ except ValueError :
48+ self ._ext_generators_shared .remove (gen )
3649
3750 def reseed (self , task , epoch ):
3851 seed = self .seed_gen (task , epoch )
@@ -43,6 +56,11 @@ def reseed(self, task, epoch):
4356 for generator in self ._ext_generators :
4457 generator (seed )
4558
59+ if self ._ext_generators_shared :
60+ shared_seed = self .seed_gen (task , epoch , shared_seed = True )
61+ for generator in self ._ext_generators_shared :
62+ generator (shared_seed )
63+
4664
4765_seeder_run = None
4866
@@ -62,15 +80,14 @@ def get_master_seed():
6280 return _seeder_run .master_seed
6381
6482
65- def register_generator (gen ):
83+ def register_generator (gen , shared = False ):
6684 global _seeder_run
67- _seeder_run ._ext_generators . append (gen )
85+ _seeder_run .register_generator (gen , shared )
6886
6987
7088def unregister_generator (gen ):
7189 global _seeder_run
72- _seeder_run ._ext_generators .remove (gen )
73-
90+ _seeder_run .unregister_generator (gen )
7491
7592class SeederCB (tf .keras .callbacks .Callback ):
7693 def on_epoch_begin (self , epoch , logs = None ):
0 commit comments