Skip to content

Commit 6d3dab2

Browse files
committed
ETS improved tests, tutorial, and Generator"
Signed-off-by: btrentini <brunoxtrentini@gmail.com>
1 parent 419fce5 commit 6d3dab2

3 files changed

Lines changed: 18 additions & 15 deletions

File tree

sub-packages/bionemo-moco/src/bionemo/moco/schedules/inference_time_schedules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,10 @@ def _hutchinson_divergence(self, t: Tensor, x: Tensor) -> Tensor:
545545

546546
# random vector from the Rademacher distribution
547547
if self.generator:
548-
epsilon = (torch.randint_like(x, 0, 2, generator=self.generator) * 2 - 1).to(x.dtype)
548+
epsilon = (torch.randint(0, 2, x.shape, generator=self.generator, device=x.device) * 2 - 1).to(x.dtype)
549549
else:
550550
epsilon = (torch.randint_like(x, 0, 2) * 2 - 1).to(x.dtype)
551+
551552

552553
v = self.predictor(t, x)
553554

sub-packages/bionemo-moco/src/bionemo/moco/schedules/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ class TimeDirection(Enum):
2222

2323
UNIFIED = "unified" # Noise(0) --> Data(1)
2424
DIFFUSION = "diffusion" # Noise(1) --> Data(0)
25+

sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_inference_schedules.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
LinearInferenceSchedule,
2323
LogInferenceSchedule,
2424
PowerInferenceSchedule,
25+
EntropicInferenceSchedule,
2526
)
2627
from bionemo.moco.schedules.utils import TimeDirection
2728

@@ -215,46 +216,46 @@ def test_entropic_schedule(timesteps, device, direction):
215216

216217
@pytest.mark.parametrize("device", ["cpu", "cuda"])
217218
def test_entropic_schedule_reproducibility(device):
218-
"""Checks that the the EntropicInferenceSchedule produce reproducible results when
219-
a torch.Generator with a fixed seed is provided."""
219+
"""Checks that the the EntropicInferenceSchedule produce reproducible results when a torch.Generator with a fixed seed is provided."""
220220
if device == "cuda" and not torch.cuda.is_available():
221221
pytest.skip("CUDA is not available")
222222

223223
timesteps = 10
224224
dim = 2
225-
predictor = lambda t, x: (2 * t - 1) * x
226-
x_0_sampler = lambda bs: torch.randn(bs, dim, device=device)
227-
x_1_sampler = lambda bs: torch.randn(bs, dim, device=device)
225+
predictor = lambda t, x: t * torch.sin(x)
228226

229227
gen1 = torch.Generator(device=device).manual_seed(42)
228+
sampler1 = lambda bs: torch.randn(bs, dim, device=device, generator=gen1)
230229
scheduler1 = EntropicInferenceSchedule(
231230
predictor=predictor,
232-
x_0_sampler=x_0_sampler,
233-
x_1_sampler=x_1_sampler,
231+
x_0_sampler=sampler1,
232+
x_1_sampler=sampler1,
234233
nsteps=timesteps,
235234
device=device,
236235
generator=gen1,
237236
)
238237
schedule1 = scheduler1.generate_schedule()
239238

240-
# Run again with the same seed...
239+
# Run again with the same seed ---
241240
gen2 = torch.Generator(device=device).manual_seed(42)
241+
sampler2 = lambda bs: torch.randn(bs, dim, device=device, generator=gen2)
242242
scheduler2 = EntropicInferenceSchedule(
243243
predictor=predictor,
244-
x_0_sampler=x_0_sampler,
245-
x_1_sampler=x_1_sampler,
244+
x_0_sampler=sampler2,
245+
x_1_sampler=sampler2,
246246
nsteps=timesteps,
247247
device=device,
248248
generator=gen2,
249249
)
250250
schedule2 = scheduler2.generate_schedule()
251-
251+
252252
# Compare again with another seed
253253
gen3 = torch.Generator(device=device).manual_seed(99)
254+
sampler3 = lambda bs: torch.randn(bs, dim, device=device, generator=gen3)
254255
scheduler3 = EntropicInferenceSchedule(
255256
predictor=predictor,
256-
x_0_sampler=x_0_sampler,
257-
x_1_sampler=x_1_sampler,
257+
x_0_sampler=sampler3,
258+
x_1_sampler=sampler3,
258259
nsteps=timesteps,
259260
device=device,
260261
generator=gen3,
@@ -264,5 +265,5 @@ def test_entropic_schedule_reproducibility(device):
264265
# Schedules from identical seeds should be identical
265266
assert torch.allclose(schedule1, schedule2)
266267

267-
# Schedule from a different seed should be different
268+
# Schedules from different seeds should be different
268269
assert not torch.allclose(schedule1, schedule3)

0 commit comments

Comments
 (0)