2222 LinearInferenceSchedule ,
2323 LogInferenceSchedule ,
2424 PowerInferenceSchedule ,
25+ EntropicInferenceSchedule ,
2526)
2627from bionemo .moco .schedules .utils import TimeDirection
2728
@@ -215,46 +216,46 @@ def test_entropic_schedule(timesteps, device, direction):
215216
216217@pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
217218def test_entropic_schedule_reproducibility (device ):
218- """Checks that the the EntropicInferenceSchedule produce reproducible results when
219- a torch.Generator with a fixed seed is provided."""
219+ """Checks that the the EntropicInferenceSchedule produce reproducible results when a torch.Generator with a fixed seed is provided."""
220220 if device == "cuda" and not torch .cuda .is_available ():
221221 pytest .skip ("CUDA is not available" )
222222
223223 timesteps = 10
224224 dim = 2
225- predictor = lambda t , x : (2 * t - 1 ) * x
226- x_0_sampler = lambda bs : torch .randn (bs , dim , device = device )
227- x_1_sampler = lambda bs : torch .randn (bs , dim , device = device )
225+ predictor = lambda t , x : t * torch .sin (x )
228226
229227 gen1 = torch .Generator (device = device ).manual_seed (42 )
228+ sampler1 = lambda bs : torch .randn (bs , dim , device = device , generator = gen1 )
230229 scheduler1 = EntropicInferenceSchedule (
231230 predictor = predictor ,
232- x_0_sampler = x_0_sampler ,
233- x_1_sampler = x_1_sampler ,
231+ x_0_sampler = sampler1 ,
232+ x_1_sampler = sampler1 ,
234233 nsteps = timesteps ,
235234 device = device ,
236235 generator = gen1 ,
237236 )
238237 schedule1 = scheduler1 .generate_schedule ()
239238
240- # Run again with the same seed...
239+ # Run again with the same seed ---
241240 gen2 = torch .Generator (device = device ).manual_seed (42 )
241+ sampler2 = lambda bs : torch .randn (bs , dim , device = device , generator = gen2 )
242242 scheduler2 = EntropicInferenceSchedule (
243243 predictor = predictor ,
244- x_0_sampler = x_0_sampler ,
245- x_1_sampler = x_1_sampler ,
244+ x_0_sampler = sampler2 ,
245+ x_1_sampler = sampler2 ,
246246 nsteps = timesteps ,
247247 device = device ,
248248 generator = gen2 ,
249249 )
250250 schedule2 = scheduler2 .generate_schedule ()
251-
251+
252252 # Compare again with another seed
253253 gen3 = torch .Generator (device = device ).manual_seed (99 )
254+ sampler3 = lambda bs : torch .randn (bs , dim , device = device , generator = gen3 )
254255 scheduler3 = EntropicInferenceSchedule (
255256 predictor = predictor ,
256- x_0_sampler = x_0_sampler ,
257- x_1_sampler = x_1_sampler ,
257+ x_0_sampler = sampler3 ,
258+ x_1_sampler = sampler3 ,
258259 nsteps = timesteps ,
259260 device = device ,
260261 generator = gen3 ,
@@ -264,5 +265,5 @@ def test_entropic_schedule_reproducibility(device):
264265 # Schedules from identical seeds should be identical
265266 assert torch .allclose (schedule1 , schedule2 )
266267
267- # Schedule from a different seed should be different
268+ # Schedules from different seeds should be different
268269 assert not torch .allclose (schedule1 , schedule3 )
0 commit comments