Skip to content

Commit 419fce5

Browse files
committed
Unit tests for Entropic Time Scheduler
Signed-off-by: btrentini <brunoxtrentini@gmail.com>
1 parent 7c8a824 commit 419fce5

1 file changed

Lines changed: 106 additions & 0 deletions

File tree

sub-packages/bionemo-moco/tests/bionemo/moco/schedules/test_inference_schedules.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,109 @@ def test_uniform_dt_padding_dilation(timesteps, device, direction, padding, dila
160160
assert schedule[0] > schedule[-1]
161161
for i in range(padding):
162162
assert schedule[-1 * (i + 1)] == 0
163+
164+
@pytest.mark.parametrize("timesteps", [10, 20])
165+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
166+
@pytest.mark.parametrize("direction", [TimeDirection.UNIFIED, TimeDirection.DIFFUSION])
167+
def test_entropic_schedule(timesteps, device, direction):
168+
"""Test the EntropicInferenceSchedule for correctness.
169+
Using a tractable predictor function to ensure the scheduler
170+
produces a non-uniform schedule with the correct properties (shape, device, direction, bounds)."""
171+
if device == "cuda" and not torch.cuda.is_available():
172+
pytest.skip("CUDA is not available")
173+
174+
# Dummy dim for the scheduler
175+
dim = 2
176+
# A simple time-dependent predictor. Divergence is D*(2t-1)
177+
# non-uniform entropy profile.
178+
predictor = lambda t, x: (2 * t - 1) * x
179+
x_0_sampler = lambda bs: torch.randn(bs, dim, device=device)
180+
x_1_sampler = lambda bs: torch.randn(bs, dim, device=device)
181+
182+
scheduler = EntropicInferenceSchedule(
183+
predictor=predictor,
184+
x_0_sampler=x_0_sampler,
185+
x_1_sampler=x_1_sampler,
186+
nsteps=timesteps,
187+
n_approx_entropy_points=25, # Fewer points for faster testing
188+
batch_size=32,
189+
direction=direction,
190+
device=device,
191+
)
192+
193+
schedule = scheduler.generate_schedule()
194+
195+
assert schedule.shape == (timesteps,)
196+
assert schedule.device.type == device
197+
198+
# Check that values are within the correct [0, 1] bounds
199+
assert torch.all(schedule >= 0) and torch.all(schedule <= 1)
200+
201+
# Check for correct ordering based on direction
202+
if direction == TimeDirection.UNIFIED:
203+
assert schedule[0] < schedule[-1]
204+
assert torch.all(torch.diff(schedule) >= 0) # Increase 0 to 1
205+
else:
206+
assert schedule[0] > schedule[-1]
207+
assert torch.all(torch.diff(schedule) <= 0) # Decrease 1 to 0
208+
209+
# Check that the schedule is non-uniform, confirming the entropic logic is active
210+
# Round to avoid float precision issues making all diffs unique
211+
diffs = torch.diff(torch.abs(schedule)).round(decimals=5)
212+
# Expect more than one unique step size, unlike a linear schedule
213+
if timesteps > 5:
214+
assert len(torch.unique(diffs)) > 1
215+
216+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
217+
def test_entropic_schedule_reproducibility(device):
218+
"""Checks that the the EntropicInferenceSchedule produce reproducible results when
219+
a torch.Generator with a fixed seed is provided."""
220+
if device == "cuda" and not torch.cuda.is_available():
221+
pytest.skip("CUDA is not available")
222+
223+
timesteps = 10
224+
dim = 2
225+
predictor = lambda t, x: (2 * t - 1) * x
226+
x_0_sampler = lambda bs: torch.randn(bs, dim, device=device)
227+
x_1_sampler = lambda bs: torch.randn(bs, dim, device=device)
228+
229+
gen1 = torch.Generator(device=device).manual_seed(42)
230+
scheduler1 = EntropicInferenceSchedule(
231+
predictor=predictor,
232+
x_0_sampler=x_0_sampler,
233+
x_1_sampler=x_1_sampler,
234+
nsteps=timesteps,
235+
device=device,
236+
generator=gen1,
237+
)
238+
schedule1 = scheduler1.generate_schedule()
239+
240+
# Run again with the same seed...
241+
gen2 = torch.Generator(device=device).manual_seed(42)
242+
scheduler2 = EntropicInferenceSchedule(
243+
predictor=predictor,
244+
x_0_sampler=x_0_sampler,
245+
x_1_sampler=x_1_sampler,
246+
nsteps=timesteps,
247+
device=device,
248+
generator=gen2,
249+
)
250+
schedule2 = scheduler2.generate_schedule()
251+
252+
# Compare again with another seed
253+
gen3 = torch.Generator(device=device).manual_seed(99)
254+
scheduler3 = EntropicInferenceSchedule(
255+
predictor=predictor,
256+
x_0_sampler=x_0_sampler,
257+
x_1_sampler=x_1_sampler,
258+
nsteps=timesteps,
259+
device=device,
260+
generator=gen3,
261+
)
262+
schedule3 = scheduler3.generate_schedule()
263+
264+
# Schedules from identical seeds should be identical
265+
assert torch.allclose(schedule1, schedule2)
266+
267+
# Schedule from a different seed should be different
268+
assert not torch.allclose(schedule1, schedule3)

0 commit comments

Comments
 (0)