@@ -166,7 +166,7 @@ def test_esm2_finetune_token_classifier(
166166 config_class = ESM2FineTuneTokenConfig ,
167167 lora_checkpoint_path = simple_ft_checkpoint ,
168168 )
169- prediction_path = tmp_path / "infer" / "predictions__rank_0 .pt"
169+ prediction_path = tmp_path / "infer" / "predictions__rank_0__dp_rank_0 .pt"
170170 # check that prediction_path loaded has classification_output key
171171 assert prediction_path .exists ()
172172 predictions = torch .load (prediction_path )
@@ -310,7 +310,7 @@ def test_esm2_finetune_regressor(
310310 config_class = ESM2FineTuneSeqConfig ,
311311 lora_checkpoint_path = simple_ft_checkpoint ,
312312 )
313- prediction_path = tmp_path / "infer" / "predictions__rank_0 .pt"
313+ prediction_path = tmp_path / "infer" / "predictions__rank_0__dp_rank_0 .pt"
314314 # check that prediction_path loaded has classification_output key
315315 assert prediction_path .exists ()
316316 predictions = torch .load (prediction_path )
@@ -456,7 +456,7 @@ def test_esm2_finetune_classifier(
456456 config_class = ESM2FineTuneSeqConfig ,
457457 lora_checkpoint_path = simple_ft_checkpoint ,
458458 )
459- prediction_path = tmp_path / "infer" / "predictions__rank_0 .pt"
459+ prediction_path = tmp_path / "infer" / "predictions__rank_0__dp_rank_0 .pt"
460460 # check that prediction_path loaded has classification_output key
461461 assert prediction_path .exists ()
462462 predictions = torch .load (prediction_path )
0 commit comments