Skip to content

Commit ea64325

Browse files
committed
Address PR feedback
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 825c2db commit ea64325

2 files changed

Lines changed: 23 additions & 10 deletions

File tree

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,10 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor | dict[str
224224
softmax_logprobs = torch.log_softmax(forward_out_gathered, dim=-1)
225225
softmax_logprobs = softmax_logprobs[:, :-1]
226226
input_ids = tokens_gathered[:, 1:]
227-
try:
228-
assert softmax_logprobs.shape[1] == input_ids.shape[1]
229-
except Exception as e:
230-
if torch.distributed.get_rank() == 0:
231-
breakpoint()
232-
torch.distributed.barrier()
233-
raise e
227+
if softmax_logprobs.shape[1] != input_ids.shape[1]:
228+
raise RuntimeError(
229+
f"Softmax logprobs shape {softmax_logprobs.shape} does not match input ids shape {input_ids.shape}"
230+
)
234231

235232
logprobs = torch.gather(
236233
softmax_logprobs, # Gather likelihoods...
@@ -404,6 +401,11 @@ def predict(
404401
"""
405402
if work_dir is None:
406403
work_dir = Path(tempfile.mkdtemp())
404+
if files_per_subdir is None and write_interval == "batch":
405+
logger.warning(
406+
"--files-per-subdir is not set with --write-interval batch, will write all predictions to a "
407+
"single directory. This may cause problems if you are predicting on a very large dataset."
408+
)
407409
sequence_parallel = tensor_parallel_size > 1 and not no_sequence_parallel
408410
output_dir.mkdir(parents=True, exist_ok=True) # Make sure the output directory exists, files will be written here.
409411
model_parallel_size = tensor_parallel_size * pipeline_model_parallel_size * context_parallel_size

sub-packages/bionemo-llm/src/bionemo/llm/utils/callbacks.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,24 +91,35 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwa
9191
"in the model's predictions as outputs are not ordered and batch indices do not track input order."
9292
)
9393

94+
@staticmethod
95+
def _assert_initialized():
96+
"""Asserts that the environment is initialized."""
97+
if not (
98+
torch.distributed.is_available() and torch.distributed.is_initialized() and parallel_state.is_initialized()
99+
):
100+
raise RuntimeError("This function is only defined within an initialized megatron parallel environment.")
101+
94102
@property
95103
def data_parallel_world_size(self) -> int:
96104
"""Returns the data parallel world size."""
105+
self._assert_initialized()
97106
return torch.distributed.get_world_size(parallel_state.get_data_parallel_group(with_context_parallel=False))
98107

99108
@property
100109
def data_parallel_rank(self) -> int:
101110
"""Returns the data parallel rank."""
111+
self._assert_initialized()
102112
return torch.distributed.get_rank(parallel_state.get_data_parallel_group(with_context_parallel=False))
103113

104114
@property
105115
def should_write_predictions(self) -> bool:
106116
"""Returns the context parallel rank."""
107117
# TODO: handle expert parallelism and other kinds of parallelism
118+
self._assert_initialized()
119+
if not parallel_state.is_pipeline_last_stage():
120+
return False
108121
return self.save_all_model_parallel_ranks or (
109-
parallel_state.is_pipeline_last_stage()
110-
and parallel_state.get_tensor_model_parallel_rank() == 0
111-
and parallel_state.get_context_parallel_rank() == 0
122+
parallel_state.get_tensor_model_parallel_rank() == 0 and parallel_state.get_context_parallel_rank() == 0
112123
)
113124

114125
@override

0 commit comments

Comments
 (0)