Skip to content

Commit e26c8fc

Browse files
committed
Pressure test the write per batch and subdirs path for a lot of files
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 0a09609 commit e26c8fc

2 files changed

Lines changed: 37 additions & 6 deletions

File tree

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def parse_args():
105105
default=None,
106106
help="Output dir that will contain the generated text produced by the Evo2 model. If not provided, the output will be logged.",
107107
)
108+
ap.add_argument(
109+
"--files-per-subdir",
110+
type=int,
111+
help="Number of files to write to each subdirectory. If provided, subdirectories with N files each will be created. Ignored unless --write-interval is 'batch'.",
112+
)
108113
ap.add_argument(
109114
"--full-fp8",
110115
action="store_true",
@@ -374,6 +379,7 @@ def predict(
374379
hybrid_override_pattern: str | None = None,
375380
num_layers: int | None = None,
376381
seq_len_interpolation_factor: int | None = None,
382+
files_per_subdir: int | None = None,
377383
):
378384
"""Inference workflow for Evo2.
379385
@@ -422,6 +428,8 @@ def predict(
422428
write_interval=write_interval,
423429
batch_dim_key_defaults={"token_logits": 0},
424430
seq_dim_key_defaults={"token_logits": 1},
431+
files_per_subdir=files_per_subdir,
432+
save_all_model_parallel_ranks=False, # only write one copy of predictions.
425433
)
426434
],
427435
plugins=nl.MegatronMixedPrecision(
@@ -536,6 +544,8 @@ def main():
536544
hybrid_override_pattern=args.hybrid_override_pattern,
537545
seq_len_interpolation_factor=args.seq_len_interpolation_factor,
538546
num_layers=args.num_layers,
547+
files_per_subdir=args.files_per_subdir,
548+
write_interval=args.write_interval,
539549
)
540550

541551

sub-packages/bionemo-llm/src/bionemo/llm/utils/callbacks.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import logging
1818
import os
19-
from typing import Any, Literal, Sequence
19+
from typing import Any, Literal, Sequence, override
2020

2121
import lightning.pytorch as pl
2222
import torch
@@ -45,23 +45,33 @@ def __init__(
4545
batch_dim_key_defaults: dict[str, int] | None = None,
4646
seq_dim_key_defaults: dict[str, int] | None = None,
4747
save_all_model_parallel_ranks: bool = False,
48+
files_per_subdir: int | None = None,
4849
):
4950
"""Initializes the callback.
5051
5152
Args:
5253
output_dir: The directory where predictions will be written.
53-
write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers.
54+
write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with
55+
multi-device trainers.
5456
batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
5557
seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
5658
save_all_model_parallel_ranks: Whether to save predictions for all model parallel ranks. Generally these
5759
will be redundant.
60+
files_per_subdir: Number of files to write to each subdirectory. If provided, subdirectories with N files
61+
each will be created. Ignored unless write_interval is 'batch'.
5862
"""
5963
super().__init__(write_interval)
6064
self.write_interval = write_interval
6165
self.output_dir = str(output_dir)
66+
self.base_dir = self.output_dir # start out like this, but output_dir will be updated if files_per_subdir>0
6267
self.batch_dim_key_defaults = batch_dim_key_defaults
6368
self.seq_dim_key_defaults = seq_dim_key_defaults
6469
self.save_all_model_parallel_ranks = save_all_model_parallel_ranks
70+
self.files_per_subdir = files_per_subdir
71+
# Initialize to infinity if files_per_subdir is provided so that we create a new subdirectory before writing
72+
# any files.
73+
self.num_files_written = float("inf") if files_per_subdir else 0
74+
self.num_subdirs_written = 0
6575

6676
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None: # noqa: D417
6777
"""Invoked when Trainer.fit, validate, test, or predict is called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.
@@ -77,9 +87,9 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwa
7787
)
7888

7989
@property
80-
def model_parallel_rank(self) -> int:
81-
"""Returns the model parallel rank."""
82-
return torch.distributed.get_rank(parallel_state.get_model_parallel_group())
90+
def data_parallel_world_size(self) -> int:
91+
"""Returns the data parallel world size."""
92+
return torch.distributed.get_world_size(parallel_state.get_data_parallel_group(with_context_parallel=False))
8393

8494
@property
8595
def data_parallel_rank(self) -> int:
@@ -96,12 +106,13 @@ def should_write_predictions(self) -> bool:
96106
and parallel_state.get_context_parallel_rank() == 0
97107
)
98108

109+
@override
99110
def write_on_batch_end(
100111
self,
101112
trainer: pl.Trainer,
102113
pl_module: pl.LightningModule,
103114
prediction: Any,
104-
batch_indices: Sequence[int],
115+
batch_indices: Sequence[int] | None,
105116
batch: Any,
106117
batch_idx: int,
107118
dataloader_idx: int,
@@ -123,6 +134,14 @@ def write_on_batch_end(
123134
# this will create N (num processes) files in `output_dir` each containing
124135
# the predictions of its respective rank
125136
if self.should_write_predictions:
137+
if (
138+
self.files_per_subdir is not None
139+
and (self.num_files_written * self.data_parallel_world_size) >= self.files_per_subdir
140+
):
141+
self.num_subdirs_written += 1
142+
self.output_dir = os.path.join(self.base_dir, f"subdir_{self.num_subdirs_written}")
143+
os.makedirs(self.output_dir, exist_ok=True)
144+
self.num_files_written = 0
126145
result_path = os.path.join(
127146
self.output_dir,
128147
f"predictions__rank_{trainer.global_rank}__dp_rank_{self.data_parallel_rank}__batch_{batch_idx}.pt",
@@ -137,7 +156,9 @@ def write_on_batch_end(
137156

138157
torch.save(prediction, result_path)
139158
logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
159+
self.num_files_written += 1
140160

161+
@override
141162
def write_on_epoch_end(
142163
self,
143164
trainer: pl.Trainer,

0 commit comments

Comments
 (0)