2323
2424from bionemo .llm .lightning import batch_collator
2525from bionemo .llm .utils .callbacks import PredictionWriter
26+ from bionemo .testing import megatron_parallel_state_utils
2627
2728
2829# Fixture for temporary directory
@@ -68,17 +69,20 @@ def test_write_on_batch_end(mock_torch_save, temp_dir, mock_trainer, mock_module
6869 writer = PredictionWriter (output_dir = temp_dir , write_interval = "batch" )
6970
7071 batch_idx = 1
71- writer .write_on_batch_end (
72- trainer = mock_trainer ,
73- pl_module = mock_module ,
74- prediction = collated_prediction ,
75- batch_indices = [],
76- batch = None ,
77- batch_idx = batch_idx ,
78- dataloader_idx = 0 ,
72+ with megatron_parallel_state_utils .distributed_model_parallel_state ():
73+ writer .write_on_batch_end (
74+ trainer = mock_trainer ,
75+ pl_module = mock_module ,
76+ prediction = collated_prediction ,
77+ batch_indices = [],
78+ batch = None ,
79+ batch_idx = batch_idx ,
80+ dataloader_idx = 0 ,
81+ )
82+
83+ expected_path = os .path .join (
84+ temp_dir , f"predictions__rank_{ mock_trainer .global_rank } __dp_rank_0__batch_{ batch_idx } .pt"
7985 )
80-
81- expected_path = os .path .join (temp_dir , f"predictions__rank_{ mock_trainer .global_rank } __batch_{ batch_idx } .pt" )
8286 mock_torch_save .assert_called_once_with (collated_prediction , expected_path )
8387
8488
@@ -88,14 +92,15 @@ def test_write_on_epoch_end(
8892):
8993 writer = PredictionWriter (output_dir = temp_dir , write_interval = "epoch" )
9094
91- writer .write_on_epoch_end (
92- trainer = mock_trainer ,
93- pl_module = mock_module ,
94- predictions = sample_predictions ,
95- batch_indices = [],
96- )
95+ with megatron_parallel_state_utils .distributed_model_parallel_state ():
96+ writer .write_on_epoch_end (
97+ trainer = mock_trainer ,
98+ pl_module = mock_module ,
99+ predictions = sample_predictions ,
100+ batch_indices = [],
101+ )
97102
98- expected_path = os .path .join (temp_dir , f"predictions__rank_{ mock_trainer .global_rank } .pt" )
103+ expected_path = os .path .join (temp_dir , f"predictions__rank_{ mock_trainer .global_rank } __dp_rank_0 .pt" )
99104
100105 mock_torch_save .assert_called_once () # Ensure it's called exactly once
101106
0 commit comments