Skip to content

Commit d1a2182

Browse files
committed
Seeder: Add option for seeding shared across workers
1 parent f2ee71a commit d1a2182

4 files changed

Lines changed: 88 additions & 26 deletions

File tree

fwr13y/seeder/paddle.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,31 @@
2121

2222
class Seeder:
2323
def __init__(self, master_seed, ngpus, local_rank):
24-
24+
25+
self.master_seed_was_none = master_seed is None
26+
if master_seed is None and local_rank == 0:
27+
print('INFO: master_seed is None in seeder.init, random master_seed will be generated (different one for each worker).')
28+
2529
self.master_seed = (
2630
master_seed if master_seed is not None else generate_master_seed_randomly()
2731
)
2832
self.seed_gen = SeedGen(self.master_seed, ngpus, local_rank)
2933
self._ext_generators = []
34+
self._ext_generators_shared = []
3035

31-
def register_generator(self, gen):
32-
self._ext_generators.append(gen)
36+
def register_generator(self, gen, shared=False):
37+
if shared:
38+
if self.master_seed_was_none:
39+
raise Exception('master_seed was None during seeder.init, seeds shared among workers cannot be used.')
40+
self._ext_generators_shared.append(gen)
41+
else:
42+
self._ext_generators.append(gen)
3343

3444
def unregister_generator(self, gen):
35-
self._ext_generators.remove(gen)
45+
try:
46+
self._ext_generators.remove(gen)
47+
except ValueError:
48+
self._ext_generators_shared.remove(gen)
3649

3750
def reseed(self, task, epoch):
3851
seed = self.seed_gen(task, epoch)
@@ -43,6 +56,11 @@ def reseed(self, task, epoch):
4356
for generator in self._ext_generators:
4457
generator(seed)
4558

59+
if self._ext_generators_shared:
60+
shared_seed = self.seed_gen(task, epoch, shared_seed=True)
61+
for generator in self._ext_generators_shared:
62+
generator(shared_seed)
63+
4664

4765
_seeder_run = None
4866

@@ -62,11 +80,11 @@ def get_master_seed():
6280
return _seeder_run.master_seed
6381

6482

65-
def register_generator(gen):
83+
def register_generator(gen, shared=False):
6684
global _seeder_run
67-
_seeder_run._ext_generators.append(gen)
85+
_seeder_run.register_generator(gen, shared)
6886

6987

7088
def unregister_generator(gen):
7189
global _seeder_run
72-
_seeder_run._ext_generators.remove(gen)
90+
_seeder_run.unregister_generator(gen)

fwr13y/seeder/pyt.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,31 @@
2121

2222
class Seeder:
2323
def __init__(self, master_seed, ngpus, local_rank):
24-
24+
25+
self.master_seed_was_none = master_seed is None
26+
if master_seed is None and local_rank == 0:
27+
print('INFO: master_seed is None in seeder.init, random master_seed will be generated (different one for each worker).')
28+
2529
self.master_seed = (
2630
master_seed if master_seed is not None else generate_master_seed_randomly()
2731
)
2832
self.seed_gen = SeedGen(self.master_seed, ngpus, local_rank)
2933
self._ext_generators = []
34+
self._ext_generators_shared = []
3035

31-
def register_generator(self, gen):
32-
self._ext_generators.append(gen)
36+
def register_generator(self, gen, shared=False):
37+
if shared:
38+
if self.master_seed_was_none:
39+
raise Exception('master_seed was None during seeder.init, seeds shared among workers cannot be used.')
40+
self._ext_generators_shared.append(gen)
41+
else:
42+
self._ext_generators.append(gen)
3343

3444
def unregister_generator(self, gen):
35-
self._ext_generators.remove(gen)
45+
try:
46+
self._ext_generators.remove(gen)
47+
except ValueError:
48+
self._ext_generators_shared.remove(gen)
3649

3750
def reseed(self, task, epoch):
3851
seed = self.seed_gen(task, epoch)
@@ -43,6 +56,11 @@ def reseed(self, task, epoch):
4356
for generator in self._ext_generators:
4457
generator(seed)
4558

59+
if self._ext_generators_shared:
60+
shared_seed = self.seed_gen(task, epoch, shared_seed=True)
61+
for generator in self._ext_generators_shared:
62+
generator(shared_seed)
63+
4664

4765
_seeder_run = None
4866

@@ -62,11 +80,11 @@ def get_master_seed():
6280
return _seeder_run.master_seed
6381

6482

65-
def register_generator(gen):
83+
def register_generator(gen, shared=False):
6684
global _seeder_run
67-
_seeder_run._ext_generators.append(gen)
85+
_seeder_run.register_generator(gen, shared)
6886

6987

7088
def unregister_generator(gen):
7189
global _seeder_run
72-
_seeder_run._ext_generators.remove(gen)
90+
_seeder_run.unregister_generator(gen)

fwr13y/seeder/seed_gen.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,19 @@ def __init__(self, master_seed, ngpus, local_rank):
2929
self._used_seeds = set()
3030
self._rng = random.Random(0)
3131

32-
def __call__(self, task, epoch):
33-
seed = (
34-
self.master_seed + (epoch * self.ngpus + self.local_rank)
35-
) * self.ntasks + task
32+
def __call__(self, task, epoch, shared_seed=False):
33+
if shared_seed:
34+
# Use the same seed for every rank
35+
# The leading constant factor is meant to keep shared seeds from coinciding with the per-rank (non-shared) seeds
36+
seed = 2 * (
37+
self.master_seed + epoch
38+
) * self.ntasks + task
39+
else:
40+
# Use a different seed for every rank
41+
seed = (
42+
self.master_seed + (epoch * self.ngpus + self.local_rank)
43+
) * self.ntasks + task
44+
3645
if seed in self._used_seeds:
3746
print(
3847
"Warning!!! seed has been generated more than once!!!", file=sys.stderr

fwr13y/seeder/tf.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,31 @@
2121

2222
class Seeder:
2323
def __init__(self, master_seed, ngpus, local_rank):
24-
24+
25+
self.master_seed_was_none = master_seed is None
26+
if master_seed is None and local_rank == 0:
27+
print('INFO: master_seed is None in seeder.init, random master_seed will be generated (different one for each worker).')
28+
2529
self.master_seed = (
2630
master_seed if master_seed is not None else generate_master_seed_randomly()
2731
)
2832
self.seed_gen = SeedGen(self.master_seed, ngpus, local_rank)
2933
self._ext_generators = []
34+
self._ext_generators_shared = []
3035

31-
def register_generator(self, gen):
32-
self._ext_generators.append(gen)
36+
def register_generator(self, gen, shared=False):
37+
if shared:
38+
if self.master_seed_was_none:
39+
raise Exception('master_seed was None during seeder.init, seeds shared among workers cannot be used.')
40+
self._ext_generators_shared.append(gen)
41+
else:
42+
self._ext_generators.append(gen)
3343

3444
def unregister_generator(self, gen):
35-
self._ext_generators.remove(gen)
45+
try:
46+
self._ext_generators.remove(gen)
47+
except ValueError:
48+
self._ext_generators_shared.remove(gen)
3649

3750
def reseed(self, task, epoch):
3851
seed = self.seed_gen(task, epoch)
@@ -43,6 +56,11 @@ def reseed(self, task, epoch):
4356
for generator in self._ext_generators:
4457
generator(seed)
4558

59+
if self._ext_generators_shared:
60+
shared_seed = self.seed_gen(task, epoch, shared_seed=True)
61+
for generator in self._ext_generators_shared:
62+
generator(shared_seed)
63+
4664

4765
_seeder_run = None
4866

@@ -62,15 +80,14 @@ def get_master_seed():
6280
return _seeder_run.master_seed
6381

6482

65-
def register_generator(gen):
83+
def register_generator(gen, shared=False):
6684
global _seeder_run
67-
_seeder_run._ext_generators.append(gen)
85+
_seeder_run.register_generator(gen, shared)
6886

6987

7088
def unregister_generator(gen):
7189
global _seeder_run
72-
_seeder_run._ext_generators.remove(gen)
73-
90+
_seeder_run.unregister_generator(gen)
7491

7592
class SeederCB(tf.keras.callbacks.Callback):
7693
def on_epoch_begin(self, epoch, logs=None):

0 commit comments

Comments
 (0)