Skip to content

Commit 4dec162

Browse files
committed
Handle the case where a user wants all ranks
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent e0db7d0 commit 4dec162

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

sub-packages/bionemo-llm/src/bionemo/llm/utils/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def data_parallel_rank(self) -> int:
8989
@property
9090
def should_write_predictions(self) -> bool:
9191
"""Returns the context parallel rank."""
92-
return (
92+
return self.save_all_model_parallel_ranks or (
9393
parallel_state.is_pipeline_last_stage()
9494
and parallel_state.get_tensor_model_parallel_rank() == 0
9595
and parallel_state.get_context_parallel_rank() == 0

0 commit comments

Comments
 (0)