Skip to content

Commit 1cb1016

Browse files
committed
Add dp rank info to callback names
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 97f73c3 commit 1cb1016

1 file changed

Lines changed: 22 additions & 17 deletions

File tree

sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_callbacks.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323

2424
from bionemo.llm.lightning import batch_collator
2525
from bionemo.llm.utils.callbacks import PredictionWriter
26+
from bionemo.testing import megatron_parallel_state_utils
2627

2728

2829
# Fixture for temporary directory
@@ -68,17 +69,20 @@ def test_write_on_batch_end(mock_torch_save, temp_dir, mock_trainer, mock_module
6869
writer = PredictionWriter(output_dir=temp_dir, write_interval="batch")
6970

7071
batch_idx = 1
71-
writer.write_on_batch_end(
72-
trainer=mock_trainer,
73-
pl_module=mock_module,
74-
prediction=collated_prediction,
75-
batch_indices=[],
76-
batch=None,
77-
batch_idx=batch_idx,
78-
dataloader_idx=0,
72+
with megatron_parallel_state_utils.distributed_model_parallel_state():
73+
writer.write_on_batch_end(
74+
trainer=mock_trainer,
75+
pl_module=mock_module,
76+
prediction=collated_prediction,
77+
batch_indices=[],
78+
batch=None,
79+
batch_idx=batch_idx,
80+
dataloader_idx=0,
81+
)
82+
83+
expected_path = os.path.join(
84+
temp_dir, f"predictions__rank_{mock_trainer.global_rank}__dp_rank_0__batch_{batch_idx}.pt"
7985
)
80-
81-
expected_path = os.path.join(temp_dir, f"predictions__rank_{mock_trainer.global_rank}__batch_{batch_idx}.pt")
8286
mock_torch_save.assert_called_once_with(collated_prediction, expected_path)
8387

8488

@@ -88,14 +92,15 @@ def test_write_on_epoch_end(
8892
):
8993
writer = PredictionWriter(output_dir=temp_dir, write_interval="epoch")
9094

91-
writer.write_on_epoch_end(
92-
trainer=mock_trainer,
93-
pl_module=mock_module,
94-
predictions=sample_predictions,
95-
batch_indices=[],
96-
)
95+
with megatron_parallel_state_utils.distributed_model_parallel_state():
96+
writer.write_on_epoch_end(
97+
trainer=mock_trainer,
98+
pl_module=mock_module,
99+
predictions=sample_predictions,
100+
batch_indices=[],
101+
)
97102

98-
expected_path = os.path.join(temp_dir, f"predictions__rank_{mock_trainer.global_rank}.pt")
103+
expected_path = os.path.join(temp_dir, f"predictions__rank_{mock_trainer.global_rank}__dp_rank_0.pt")
99104

100105
mock_torch_save.assert_called_once() # Ensure it's called exactly once
101106

0 commit comments

Comments (0)