@@ -242,6 +242,9 @@ def get_dummy_inputs(self) -> Dict[str, Any]:
242242 """
243243 Returns dict of inputs to pass to the model forward pass.
244244
245+ Implementations must be deterministic: every call must return identical inputs (seed any random
246+ tensors and generators), since tests call this once per forward pass to compare outputs.
247+
245248 Returns:
246249 Dict[str, Any]: Input tensors/values for model.forward().
247250
@@ -292,9 +295,8 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
292295 f"Parameter shape mismatch for { param_name } . Original: { param_1 .shape } , loaded: { param_2 .shape } "
293296 )
294297
295- inputs_dict = self .get_dummy_inputs ()
296- image = model (** inputs_dict , return_dict = False )[0 ]
297- new_image = new_model (** inputs_dict , return_dict = False )[0 ]
298+ image = model (** self .get_dummy_inputs (), return_dict = False )[0 ]
299+ new_image = new_model (** self .get_dummy_inputs (), return_dict = False )[0 ]
298300
299301 assert_tensors_close (image , new_image , atol = atol , rtol = rtol , msg = "Models give different forward passes." )
300302
@@ -314,9 +316,8 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
314316
315317 new_model .to (torch_device )
316318
317- inputs_dict = self .get_dummy_inputs ()
318- image = model (** inputs_dict , return_dict = False )[0 ]
319- new_image = new_model (** inputs_dict , return_dict = False )[0 ]
319+ image = model (** self .get_dummy_inputs (), return_dict = False )[0 ]
320+ new_image = new_model (** self .get_dummy_inputs (), return_dict = False )[0 ]
320321
321322 assert_tensors_close (image , new_image , atol = atol , rtol = rtol , msg = "Models give different forward passes." )
322323
@@ -344,9 +345,8 @@ def test_determinism(self, atol=1e-5, rtol=0):
344345 model .to (torch_device )
345346 model .eval ()
346347
347- inputs_dict = self .get_dummy_inputs ()
348- first = model (** inputs_dict , return_dict = False )[0 ]
349- second = model (** inputs_dict , return_dict = False )[0 ]
348+ first = model (** self .get_dummy_inputs (), return_dict = False )[0 ]
349+ second = model (** self .get_dummy_inputs (), return_dict = False )[0 ]
350350
351351 first_flat = first .flatten ()
352352 second_flat = second .flatten ()
@@ -403,9 +403,8 @@ def recursive_check(tuple_object, dict_object):
403403 model .to (torch_device )
404404 model .eval ()
405405
406- inputs_dict = self .get_dummy_inputs ()
407- outputs_dict = model (** inputs_dict )
408- outputs_tuple = model (** inputs_dict , return_dict = False )
406+ outputs_dict = model (** self .get_dummy_inputs ())
407+ outputs_tuple = model (** self .get_dummy_inputs (), return_dict = False )
409408
410409 recursive_check (outputs_tuple , outputs_dict )
411410
@@ -509,11 +508,10 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,
509508 def test_sharded_checkpoints (self , tmp_path , atol = 1e-5 , rtol = 0 ):
510509 torch .manual_seed (0 )
511510 config = self .get_init_dict ()
512- inputs_dict = self .get_dummy_inputs ()
513511 model = self .model_class (** config ).eval ()
514512 model = model .to (torch_device )
515513
516- base_output = model (** inputs_dict , return_dict = False )[0 ]
514+ base_output = model (** self . get_dummy_inputs () , return_dict = False )[0 ]
517515
518516 model_size = compute_module_persistent_sizes (model )["" ]
519517 max_shard_size = int ((model_size * 0.75 ) / (2 ** 10 )) # Convert to KB as these test models are small
@@ -532,10 +530,7 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
532530 new_model = new_model .to (torch_device )
533531
534532 torch .manual_seed (0 )
535- # Re-create inputs only if they contain a generator (which needs to be reset)
536- if "generator" in inputs_dict :
537- inputs_dict = self .get_dummy_inputs ()
538- new_output = new_model (** inputs_dict , return_dict = False )[0 ]
533+ new_output = new_model (** self .get_dummy_inputs (), return_dict = False )[0 ]
539534
540535 assert_tensors_close (
541536 base_output , new_output , atol = atol , rtol = rtol , msg = "Output should match after sharded save/load"
@@ -546,11 +541,10 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
546541 def test_sharded_checkpoints_with_variant (self , tmp_path , atol = 1e-5 , rtol = 0 ):
547542 torch .manual_seed (0 )
548543 config = self .get_init_dict ()
549- inputs_dict = self .get_dummy_inputs ()
550544 model = self .model_class (** config ).eval ()
551545 model = model .to (torch_device )
552546
553- base_output = model (** inputs_dict , return_dict = False )[0 ]
547+ base_output = model (** self . get_dummy_inputs () , return_dict = False )[0 ]
554548
555549 model_size = compute_module_persistent_sizes (model )["" ]
556550 max_shard_size = int ((model_size * 0.75 ) / (2 ** 10 )) # Convert to KB as these test models are small
@@ -574,10 +568,7 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
574568 new_model = new_model .to (torch_device )
575569
576570 torch .manual_seed (0 )
577- # Re-create inputs only if they contain a generator (which needs to be reset)
578- if "generator" in inputs_dict :
579- inputs_dict = self .get_dummy_inputs ()
580- new_output = new_model (** inputs_dict , return_dict = False )[0 ]
571+ new_output = new_model (** self .get_dummy_inputs (), return_dict = False )[0 ]
581572
582573 assert_tensors_close (
583574 base_output , new_output , atol = atol , rtol = rtol , msg = "Output should match after variant sharded save/load"
@@ -589,11 +580,10 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt
589580
590581 torch .manual_seed (0 )
591582 config = self .get_init_dict ()
592- inputs_dict = self .get_dummy_inputs ()
593583 model = self .model_class (** config ).eval ()
594584 model = model .to (torch_device )
595585
596- base_output = model (** inputs_dict , return_dict = False )[0 ]
586+ base_output = model (** self . get_dummy_inputs () , return_dict = False )[0 ]
597587
598588 model_size = compute_module_persistent_sizes (model )["" ]
599589 max_shard_size = int ((model_size * 0.75 ) / (2 ** 10 )) # Convert to KB as these test models are small
@@ -627,10 +617,7 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt
627617 model_parallel = model_parallel .to (torch_device )
628618
629619 torch .manual_seed (0 )
630- # Re-create inputs only if they contain a generator (which needs to be reset)
631- if "generator" in inputs_dict :
632- inputs_dict = self .get_dummy_inputs ()
633- output_parallel = model_parallel (** inputs_dict , return_dict = False )[0 ]
620+ output_parallel = model_parallel (** self .get_dummy_inputs (), return_dict = False )[0 ]
634621
635622 assert_tensors_close (
636623 base_output , output_parallel , atol = atol , rtol = rtol , msg = "Output should match with parallel loading"
0 commit comments