@@ -91,24 +91,35 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwa
9191 "in the model's predictions as outputs are not ordered and batch indices do not track input order."
9292 )
9393
94+ @staticmethod
95+ def _assert_initialized ():
96+ """Asserts that the environment is initialized."""
97+ if not (
98+ torch .distributed .is_available () and torch .distributed .is_initialized () and parallel_state .is_initialized ()
99+ ):
100+ raise RuntimeError ("This function is only defined within an initialized megatron parallel environment." )
101+
94102 @property
95103 def data_parallel_world_size (self ) -> int :
96104 """Returns the data parallel world size."""
105+ self ._assert_initialized ()
97106 return torch .distributed .get_world_size (parallel_state .get_data_parallel_group (with_context_parallel = False ))
98107
99108 @property
100109 def data_parallel_rank (self ) -> int :
101110 """Returns the data parallel rank."""
111+ self ._assert_initialized ()
102112 return torch .distributed .get_rank (parallel_state .get_data_parallel_group (with_context_parallel = False ))
103113
104114 @property
105115 def should_write_predictions (self ) -> bool :
106116 """Returns the context parallel rank."""
107117 # TODO: handle expert parallelism and other kinds of parallelism
118+ self ._assert_initialized ()
119+ if not parallel_state .is_pipeline_last_stage ():
120+ return False
108121 return self .save_all_model_parallel_ranks or (
109- parallel_state .is_pipeline_last_stage ()
110- and parallel_state .get_tensor_model_parallel_rank () == 0
111- and parallel_state .get_context_parallel_rank () == 0
122+ parallel_state .get_tensor_model_parallel_rank () == 0 and parallel_state .get_context_parallel_rank () == 0
112123 )
113124
114125 @override
0 commit comments