@@ -160,3 +160,109 @@ def test_uniform_dt_padding_dilation(timesteps, device, direction, padding, dila
160160 assert schedule [0 ] > schedule [- 1 ]
161161 for i in range (padding ):
162162 assert schedule [- 1 * (i + 1 )] == 0
163+
164+ @pytest .mark .parametrize ("timesteps" , [10 , 20 ])
165+ @pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
166+ @pytest .mark .parametrize ("direction" , [TimeDirection .UNIFIED , TimeDirection .DIFFUSION ])
167+ def test_entropic_schedule (timesteps , device , direction ):
168+ """Test the EntropicInferenceSchedule for correctness.
169+ Using a tractable predictor function to ensure the scheduler
170+ produces a non-uniform schedule with the correct properties (shape, device, direction, bounds)."""
171+ if device == "cuda" and not torch .cuda .is_available ():
172+ pytest .skip ("CUDA is not available" )
173+
174+ # Dummy dim for the scheduler
175+ dim = 2
176+ # A simple time-dependent predictor. Divergence is D*(2t-1)
177+ # non-uniform entropy profile.
178+ predictor = lambda t , x : (2 * t - 1 ) * x
179+ x_0_sampler = lambda bs : torch .randn (bs , dim , device = device )
180+ x_1_sampler = lambda bs : torch .randn (bs , dim , device = device )
181+
182+ scheduler = EntropicInferenceSchedule (
183+ predictor = predictor ,
184+ x_0_sampler = x_0_sampler ,
185+ x_1_sampler = x_1_sampler ,
186+ nsteps = timesteps ,
187+ n_approx_entropy_points = 25 , # Fewer points for faster testing
188+ batch_size = 32 ,
189+ direction = direction ,
190+ device = device ,
191+ )
192+
193+ schedule = scheduler .generate_schedule ()
194+
195+ assert schedule .shape == (timesteps ,)
196+ assert schedule .device .type == device
197+
198+ # Check that values are within the correct [0, 1] bounds
199+ assert torch .all (schedule >= 0 ) and torch .all (schedule <= 1 )
200+
201+ # Check for correct ordering based on direction
202+ if direction == TimeDirection .UNIFIED :
203+ assert schedule [0 ] < schedule [- 1 ]
204+ assert torch .all (torch .diff (schedule ) >= 0 ) # Increase 0 to 1
205+ else :
206+ assert schedule [0 ] > schedule [- 1 ]
207+ assert torch .all (torch .diff (schedule ) <= 0 ) # Decrease 1 to 0
208+
209+ # Check that the schedule is non-uniform, confirming the entropic logic is active
210+ # Round to avoid float precision issues making all diffs unique
211+ diffs = torch .diff (torch .abs (schedule )).round (decimals = 5 )
212+ # Expect more than one unique step size, unlike a linear schedule
213+ if timesteps > 5 :
214+ assert len (torch .unique (diffs )) > 1
215+
216+ @pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
217+ def test_entropic_schedule_reproducibility (device ):
218+ """Checks that the the EntropicInferenceSchedule produce reproducible results when
219+ a torch.Generator with a fixed seed is provided."""
220+ if device == "cuda" and not torch .cuda .is_available ():
221+ pytest .skip ("CUDA is not available" )
222+
223+ timesteps = 10
224+ dim = 2
225+ predictor = lambda t , x : (2 * t - 1 ) * x
226+ x_0_sampler = lambda bs : torch .randn (bs , dim , device = device )
227+ x_1_sampler = lambda bs : torch .randn (bs , dim , device = device )
228+
229+ gen1 = torch .Generator (device = device ).manual_seed (42 )
230+ scheduler1 = EntropicInferenceSchedule (
231+ predictor = predictor ,
232+ x_0_sampler = x_0_sampler ,
233+ x_1_sampler = x_1_sampler ,
234+ nsteps = timesteps ,
235+ device = device ,
236+ generator = gen1 ,
237+ )
238+ schedule1 = scheduler1 .generate_schedule ()
239+
240+ # Run again with the same seed...
241+ gen2 = torch .Generator (device = device ).manual_seed (42 )
242+ scheduler2 = EntropicInferenceSchedule (
243+ predictor = predictor ,
244+ x_0_sampler = x_0_sampler ,
245+ x_1_sampler = x_1_sampler ,
246+ nsteps = timesteps ,
247+ device = device ,
248+ generator = gen2 ,
249+ )
250+ schedule2 = scheduler2 .generate_schedule ()
251+
252+ # Compare again with another seed
253+ gen3 = torch .Generator (device = device ).manual_seed (99 )
254+ scheduler3 = EntropicInferenceSchedule (
255+ predictor = predictor ,
256+ x_0_sampler = x_0_sampler ,
257+ x_1_sampler = x_1_sampler ,
258+ nsteps = timesteps ,
259+ device = device ,
260+ generator = gen3 ,
261+ )
262+ schedule3 = scheduler3 .generate_schedule ()
263+
264+ # Schedules from identical seeds should be identical
265+ assert torch .allclose (schedule1 , schedule2 )
266+
267+ # Schedule from a different seed should be different
268+ assert not torch .allclose (schedule1 , schedule3 )
0 commit comments