Commit 917c9c4

cspades, farhadrgh, dorotat-nv, trvachov, and nvdreidenbach committed
Fix masked token loss refactor from NeMo bump. (#855)
### Description

- Fixes the error caused by NVIDIA-NeMo/NeMo#12459, which refactored `masked_token_loss` and `masked_token_loss_context_parallel` into a single function with a `cp_size` argument that no longer divides the loss by the number of "valid" (i.e. non-masked) tokens, so it now returns a CP-reduced loss sum.
- Specifically, this breaks one of our golden value tests in `bionemo-llm`: `sub-packages/bionemo-llm/tests/bionemo/llm/model/test_loss.py::test_loss_equivalency_bionemo_vs_pytorch`. This commit fixes it with no behavior change to the LLM model `forward()`, i.e. we now perform the normalization on valid tokens on our side.

### Details

- Bump NeMo to a version greater than NVIDIA-NeMo/NeMo#12856 or matching #798.
- Update: Need to migrate to `inference_context` in NeMo: https://github.com/NVIDIA/NeMo/tree/cye/hyena-gpt-infer-context
- Bump Megatron to support new imports in the NeMo bump. Found a commit that bisects the new Megatron inference engine and the new NeMo imports to prevent breakage of our inference tests.
- Use a backend version of RoPE for the Amplify Megatron vs. PyTorch/HF parity test to avoid the CP process group requirement.
- `MaskedTokenLossReduction.forward()` return API changed.
- Added commentary for future devs to understand the code.

#### Appendix

- NeMo Fork Hotfix Patch: Safe import of a future module in Megatron to avoid upgrading.

```
get_gpt_heterogeneous_layer_spec, HAVE_GPT_HETEROGENEOUS = safe_import("megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs")
```

### Type of changes

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### Usage / Testing

- Tested against the commit specified in this PR: #798

```bash
cd 3rdparty/NeMo
git checkout c998e273f9cd23e36d7348fa27d0c2692efd87c8
cd ../..  # return to the repository root so the test path below resolves
pytest -s sub-packages/bionemo-llm/tests/bionemo/llm/model/test_loss.py::test_loss_equivalency_bionemo_vs_pytorch
```

---------

Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: cspades <cory0ye@gmail.com>
Signed-off-by: Timur Rvachov <trvachov@nvidia.com>
Signed-off-by: Danny <dreidenbach@nvidia.com>
Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Signed-off-by: nvdreidenbach <97637601+nvdreidenbach@users.noreply.github.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
Signed-off-by: polinabinder1 <pbinder@nvidia.com>
Signed-off-by: dorotat <dorotat@nvidia.com>
Signed-off-by: Truong Nguyen <tgnguyen@nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Timur Rvachov <120140748+trvachov@users.noreply.github.com>
Signed-off-by: Steven <skothenhill@nvidia.com>
Co-authored-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
Co-authored-by: Farhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
Co-authored-by: Dorota Toczydlowska <115542912+dorotat-nv@users.noreply.github.com>
Co-authored-by: Timur Rvachov <120140748+trvachov@users.noreply.github.com>
Co-authored-by: nvdreidenbach <97637601+nvdreidenbach@users.noreply.github.com>
Co-authored-by: Steven Kothen-Hill <148821680+skothenhill-nv@users.noreply.github.com>
Co-authored-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: polinabinder1 <pbinder@nvidia.com>
Co-authored-by: Truong Nguyen <tgnguyen@nvidia.com>
Co-authored-by: jomitchellnv <148147880+jomitchellnv@users.noreply.github.com>
Co-authored-by: lvojtku <lvojtku@nvidia.com>
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
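The normalization change described in the commit message can be illustrated with a small framework-free sketch. It uses plain Python lists instead of tensors, and the helper names `masked_loss_sum` and `masked_loss_mean` are hypothetical, chosen only for this illustration. They are not the NeMo or BioNeMo API. The first helper mirrors a loss that is only summed over non-masked tokens, the second adds the division by the number of valid tokens that the caller now has to perform itself.

```python
def masked_loss_sum(token_losses, loss_mask):
    """Sum the per-token losses over valid (mask == 1) positions only."""
    return sum(loss * mask for loss, mask in zip(token_losses, loss_mask))


def masked_loss_mean(token_losses, loss_mask):
    """Normalize the masked sum by the number of valid tokens.

    Hypothetical helper: raising on an all-masked input is a choice made for
    this sketch, not the behavior of the code in this commit.
    """
    num_valid = sum(loss_mask)
    if num_valid == 0:
        raise ValueError("no valid tokens to average over")
    return masked_loss_sum(token_losses, loss_mask) / num_valid


token_losses = [2.0, 4.0, 6.0, 8.0]
loss_mask = [1, 0, 1, 1]  # the second token is masked out

loss_sum = masked_loss_sum(token_losses, loss_mask)    # 2 + 6 + 8
loss_mean = masked_loss_mean(token_losses, loss_mask)  # (2 + 6 + 8) / 3
print(loss_sum, loss_mean)
```

A caller that previously received the mean and now receives the sum would see its reported loss scale with the number of valid tokens, which is the kind of mismatch a golden value test detects.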
1 parent ca6e41f commit 917c9c4

11 files changed

Lines changed: 107 additions & 23 deletions

File tree

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -195,7 +195,7 @@ uv pip install --no-build-isolation \
   -r /requirements-test.txt
 
 # Install back ngcsdk, as a WAR for the protobuf version conflict with nemo_toolkit.
-uv pip install ngcsdk==3.63.0 # Remove when https://nvidia.slack.com/archives/CEX3JC6SF/p1744898511311379 is fixed.
+uv pip install ngcsdk==3.64.3 # Temporary fix for changed filename, see https://nvidia.slack.com/archives/C074Z808N05/p1746231345981209
 
 # Install nvidia-pytriton which seems to cause a conflict with pyzmq versions
 uv pip install nvidia-pytriton # Temporary dependency until this gets added to requirements_nlp.txt in NeMo.

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.5
+2.6

docs/docs/models/geneformer.md

Lines changed: 1 addition & 1 deletion
@@ -207,4 +207,4 @@ The 106M parameter variant of Geneformer achieves over 50 TFLOPS per GPU during
 
 ![GPU Performance (TFLOPS) Comparison Between Geneformer Model Variants on A100 GPUs](../assets/images/geneformer/model_tflops_per_gpu_chart_geneformer.png)
 
-Performance will increase if the `num_dataset_workers` and the `micro_batch_size` are set appropriately. For the above metrics, we set `num_dataset_workers=8`. For the 10m model, set `micro_batch_size=120` and for the 106m model set the `micro_batch_size=16`. This will enable you to achieve similar performance results.
+Performance will increase if the `num_dataset_workers` and the `micro_batch_size` are set appropriately. For the above metrics, we set `num_dataset_workers=8`. For the 10m model, set `micro_batch_size=120` and for the 106m model set the `micro_batch_size=16`. This will enable you to achieve similar performance results.

docs/docs/user-guide/appendix/releasenotes-fw.md

Lines changed: 18 additions & 0 deletions
@@ -1,5 +1,17 @@
 # Release Notes
 
+## BioNeMo Framework v2.6
+
+### New Features
+
+* Adds support for AMPLIFY [doi:10.1101/2024.09.23.614603](https://doi.org/10.1101/2024.09.23.614603) pre-training and inference, offering a 70% speedup over the xformers-based attention backend with similar final perplexity values at 1M pre-training steps. (4.23 for 120M, 3.05 for 350M). The model is fully compatible with existing weights on HuggingFace.
+* Adds alpha support for [LoRA fine-tuning to for ESM2 models](https://nvidia.github.io/bionemo-framework/models/ESM-2/#lora-fine-tuning-performace). Inference and fine-tuning are enabled along with resumption from a checkpoint.
+
+### Updates & Improvements
+
+* Blackwell support, tested on B200 systems.
+* Fixed Grace CPU support, released ARM compatible container.
+
 ## BioNeMo Framework v2.5
 
 ### New Features
@@ -12,6 +24,9 @@
 * Upgrade bionemo-moco to v0.0.2
 * Brev.dev launchable tutorials
 
+#### Known Issues
+* Partial test failures on ARM CPUs.
+
 ## BioNeMo Framework v2.4.1
 
 ### Updates & Improvements
@@ -23,6 +38,9 @@
 * Draft implementation of Evo2 with support for Hyena operators
 * bionemo-moco v0.0.1 released for building diffusion-like generative models.
 
+### Known Issues
+* Partial test failures on ARM CPUs.
+
 ### Updates & Improvements
 
 * ESM2 fine-tuning script with CLI (finetune_esm2) that supports sequence-level/token-level classification/regression using a CSV dataset.

sub-packages/bionemo-amplify/tests/bionemo/amplify/test_hf_rotary.py

Lines changed: 15 additions & 3 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import torch
-from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
+from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from transformers import AutoConfig
 
@@ -47,8 +47,20 @@ def test_rope_embeddings():
         seq_len_interpolation_factor=nemo_config.seq_len_interpolation_factor,
     )
     rotary_pos_emb = rotary_pos_layer(q.shape[1])
-    q_post_nemo = apply_rotary_pos_emb(q.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu()
-    k_post_nemo = apply_rotary_pos_emb(k.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu()
+    # Note: Use the backend implementation of the RoPE to avoid
+    # getting or instantiating a CP process group.
+    q_post_nemo = _apply_rotary_pos_emb_bshd(
+        q.transpose(0, 1).cuda(),
+        rotary_pos_emb.cuda(),
+        rotary_interleaved=nemo_config.rotary_interleaved,
+        multi_latent_attention=nemo_config.multi_latent_attention,
+    ).cpu()
+    k_post_nemo = _apply_rotary_pos_emb_bshd(
+        k.transpose(0, 1).cuda(),
+        rotary_pos_emb.cuda(),
+        rotary_interleaved=nemo_config.rotary_interleaved,
+        multi_latent_attention=nemo_config.multi_latent_attention,
+    ).cpu()
 
     torch.testing.assert_close(q_post, q_post_nemo.transpose(0, 1))
     torch.testing.assert_close(k_post, k_post_nemo.transpose(0, 1))
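For readers unfamiliar with what the parity test compares, here is a minimal pure-Python sketch of a non-interleaved rotary position embedding applied to a single head vector. It is an illustration only: the function name `rope_rotate`, the default `base=10000.0`, and the pairing of element `i` with element `i + dim/2` are assumptions for this sketch, not code taken from Megatron or HuggingFace.

```python
import math


def rope_rotate(vec, position, base=10000.0):
    """Apply a non-interleaved rotary embedding to one head vector.

    Element i is paired with element i + dim/2, and each pair is rotated by
    the angle position * base**(-2*i/dim). Hypothetical helper for
    illustration only.
    """
    dim = len(vec)
    if dim % 2 != 0:
        raise ValueError("head dimension must be even")
    half = dim // 2
    out = [0.0] * dim
    for i in range(half):
        angle = position * base ** (-2.0 * i / dim)
        cos, sin = math.cos(angle), math.sin(angle)
        x1, x2 = vec[i], vec[i + half]
        out[i] = x1 * cos - x2 * sin
        out[i + half] = x1 * sin + x2 * cos
    return out


def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.5, -0.75, 1.0]

# The attention score depends only on the relative offset between positions:
score_a = dot(rope_rotate(q, 3), rope_rotate(k, 1))  # offset 2
score_b = dot(rope_rotate(q, 5), rope_rotate(k, 3))  # offset 2
print(score_a, score_b)
```

Because each pair is rotated, the vector norm is unchanged and the query-key dot product depends only on the position difference, which is why two implementations can be compared element-wise after applying the embedding to the same inputs.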

sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py

Lines changed: 1 addition & 1 deletion
@@ -327,7 +327,7 @@ def test_main_runs(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inpu
     event_files = list(log_dir.rglob("events.out.tfevents*"))
     assert event_files, f"No TensorBoard event files found under {log_dir}"
     assert "val_ppl" in trainer.logged_metrics # validation logging on by default
-    assert "tflops_per_sec_per_gpu" in trainer.logged_metrics # ensuring that tflops logger can be added
+    assert "TFLOPS_per_GPU" in trainer.logged_metrics # ensuring that tflops logger can be added
     assert "train_step_timing in s" in trainer.logged_metrics
 
 

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 3 additions & 1 deletion
@@ -157,7 +157,9 @@ def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
             return forward_out
         # Reminder: the model's predictions for input i land at output i+1. To get everything to align, we prepend the
         # EOS token to the input sequences and take the outputs for all but the first token.
-        forward_out_tp_gathered = _gather_along_last_dim(forward_out)
+        forward_out_tp_gathered = _gather_along_last_dim(
+            forward_out, group=parallel_state.get_tensor_model_parallel_group()
+        )
         # else:
         # forward_out_tp_gathered = _collect_into_dim(forward_out, dim=-1)
         forward_out_gathered = _gather_along_cp_dim(forward_out_tp_gathered)

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def test_train_evo2_stops(tmp_path):
     )
 
     assert "reduced_train_loss" in trainer.logged_metrics # validation logging on by default
-    assert "tflops_per_sec_per_gpu" in trainer.logged_metrics # ensuring that tflops logger can be added
+    assert "TFLOPS_per_GPU" in trainer.logged_metrics # ensuring that tflops logger can be added
     assert "train_step_timing in s" in trainer.logged_metrics
 
 

sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py

Lines changed: 25 additions & 9 deletions
@@ -19,7 +19,10 @@
 from megatron.core import parallel_state, tensor_parallel
 from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
 from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
-from nemo.lightning.megatron_parallel import MegatronLossReduction, masked_token_loss
+from nemo.lightning.megatron_parallel import (
+    MegatronLossReduction,
+    masked_token_loss,
+)
 from torch import Tensor
 
 
@@ -175,14 +178,17 @@ def forward(
 
         # TODO(@jstjohn) also handle different output keys, like the sequence loss.
 
-        # compute loss
+        # Compute loss over "valid" tokens in the microbatch, i.e. the non-masked tokens.
+        # The loss is not normalized, only potentially reduced via torch.distributed.ReduceOp.SUM
+        # across the context parallel process group, so you need to divide by the number
+        # of non-masked tokens (loss_mask.sum()) to compute the mean reduced loss per token.
         cp_size = parallel_state.get_context_parallel_world_size()
-        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)
+        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size=cp_size)
+        num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
 
         # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
         # reducing the loss across the data parallel group.
         if self.validation_step and not self.val_drop_last:
-            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
             if loss_for_microbatch.isnan():
                 # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
                 # to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
@@ -191,9 +197,8 @@ def forward(
                     raise ValueError("Got NaN loss with non-empty input")
                 loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
             else:
-                loss_sum_for_microbatch = (
-                    num_valid_tokens_in_microbatch * loss_for_microbatch
-                ) # sum over all valid tokens
+                # The reduced loss is already the sum of all losses from masked_token_loss().
+                loss_sum_for_microbatch = loss_for_microbatch
 
             # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
             loss_sum_and_microbatch_size_all_gpu = torch.cat(
@@ -202,17 +207,28 @@ def forward(
                     Tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
                 ]
             )
+
+            # Reduce the loss sum across the data parallel group to get the total loss
+            # for all data parallel / distributed microbatches.
             torch.distributed.all_reduce(
                 loss_sum_and_microbatch_size_all_gpu,
                 group=parallel_state.get_data_parallel_group(),
                 op=torch.distributed.ReduceOp.SUM,
             )
+
+            # Return the loss tensor multiplied by the context parallel size,
+            # and the data & context parallel reduced loss sum.
             return loss_for_microbatch * cp_size, {
                 "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
             }
 
-        # average the losses across the data parallel group, but also return the unreduced loss
-        reduced_loss = average_losses_across_data_parallel_group([loss_for_microbatch])
+        # Return the loss tensor multiplied by the context parallel size, as well as
+        # the data-parallel averaged loss, i.e. the loss divided by the DP size.
+        # Normalize the loss by the number of "valid" tokens, because masked_token_loss
+        # no longer does this normalization, and BioNeMo losses expect this normalization.
+        reduced_loss = (
+            average_losses_across_data_parallel_group([loss_for_microbatch]) / num_valid_tokens_in_microbatch
+        )
         return loss_for_microbatch * cp_size, {"avg": reduced_loss}
 
 
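The validation branch in this diff packs the masked loss sum and the valid-token count into one tensor and SUM-reduces it across the data parallel group. The sketch below, in plain Python, illustrates why carrying both numbers matters when ranks hold different numbers of valid tokens. The helper `all_reduce_sum` is a hypothetical stand-in for an element-wise SUM all-reduce and the per-rank values are made up for the example.

```python
def all_reduce_sum(per_rank_pairs):
    """Simulate an element-wise SUM all-reduce over (loss_sum, num_valid) pairs.

    Hypothetical stand-in for a distributed all-reduce: every rank would end
    up holding these two totals.
    """
    total_loss = sum(pair[0] for pair in per_rank_pairs)
    total_tokens = sum(pair[1] for pair in per_rank_pairs)
    return total_loss, total_tokens


# Made-up example: rank 0 has 4 valid tokens, rank 1 has only 1.
per_rank = [(12.0, 4), (1.0, 1)]

total_loss, total_tokens = all_reduce_sum(per_rank)

# Dividing the reduced sum by the reduced count weights every token equally.
token_weighted_mean = total_loss / total_tokens

# Averaging per-rank means instead weights every rank equally.
mean_of_rank_means = sum(s / n for s, n in per_rank) / len(per_rank)

print(token_weighted_mean, mean_of_rank_means)
```

The two results differ whenever the ranks see different numbers of valid tokens, for example with a partial last validation batch, which is the situation the `val_drop_last` comment in the diff refers to.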

sub-packages/bionemo-llm/src/bionemo/llm/utils/callbacks.py

Lines changed: 40 additions & 4 deletions
@@ -29,7 +29,12 @@
 
 
 class PredictionWriter(BasePredictionWriter, pl.Callback):
-    """A callback that writes predictions to disk at specified intervals during training."""
+    """A callback that writes predictions to disk at specified intervals during training.
+
+    Logits, Embeddings, Hiddens, Input IDs, and Labels may all be saved to the disk depending on trainer configuration.
+    Batch Idxs are provided for each prediction in the same dictionary. These must be used to maintain order between
+    multi device predictions and single device predictions.
+    """
 
     def __init__(
         self,
@@ -42,15 +47,28 @@ def __init__(
 
         Args:
             output_dir: The directory where predictions will be written.
-            write_interval: The interval at which predictions will be written. (batch, epoch)
+            write_interval: The interval at which predictions will be written (batch, epoch). Epoch may not be used with multi-device trainers.
             batch_dim_key_defaults: The default batch dimension for each key, if different from the standard 0.
             seq_dim_key_defaults: The default sequence dimension for each key, if different from the standard 1.
         """
         super().__init__(write_interval)
+        self.write_interval = write_interval
         self.output_dir = str(output_dir)
         self.batch_dim_key_defaults = batch_dim_key_defaults
         self.seq_dim_key_defaults = seq_dim_key_defaults
 
+    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None: # noqa: D417
+        """Invoked with Trainer.fit, validate, test, and predict are called. Will immediately fail when 'write_interval' is 'epoch' and 'trainer.num_devices' > 1.
+
+        Args:
+            trainer: The Trainer instance.
+            pl_module: The LightningModule instance.
+        """
+        if trainer.num_devices > 1 and self.write_interval == "epoch":
+            raise ValueError(
+                "Multi-GPU predictions are not permitted as outputs are not ordered and batch indices are lost."
+            )
+
     def write_on_batch_end(
         self,
         trainer: pl.Trainer,
@@ -63,6 +81,9 @@ def write_on_batch_end(
     ) -> None:
         """Writes predictions to disk at the end of each batch.
 
+        Predictions files follow the naming pattern, where rank is the active GPU in which the predictions were made.
+        predictions__rank_{rank}__batch_{batch_idx}.pt
+
         Args:
             trainer: The Trainer instance.
             pl_module: The LightningModule instance.
@@ -77,7 +98,12 @@ def write_on_batch_end(
         result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt")
 
         # batch_indices is not captured due to a lightning bug when return_predictions = False
-        # we use input IDs in the prediction to map the result to input
+        # we use input IDs in the prediction to map the result to input.
+
+        # NOTE store the batch_idx so we do not need to rely on filenames for reconstruction of inputs. This is wrapped
+        # in a tensor and list container to ensure compatibility with batch_collator.
+        prediction["batch_idx"] = torch.tensor([batch_idx], dtype=torch.int64)
+
         torch.save(prediction, result_path)
         logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
 
@@ -90,14 +116,23 @@ def write_on_epoch_end(
     ) -> None:
         """Writes predictions to disk at the end of each epoch.
 
+        Writing all predictions on epoch end is memory intensive. It is recommended to use the batch writer instead for
+        large predictions.
+
+        Multi-device predictions will likely yield predictions in an order that is inconsistent with single device predictions and the input data.
+
         Args:
             trainer: The Trainer instance.
             pl_module: The LightningModule instance.
             predictions: The predictions made by the model.
             batch_indices: The indices of the batch.
+
+        Raises:
+            Multi-GPU predictions are output in an inconsistent order with multiple devices.
         """
         # this will create N (num processes) files in `output_dir` each containing
         # the predictions of it's respective rank
+
         result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")
 
         # collate multiple batches / ignore empty ones
@@ -106,13 +141,14 @@ def write_on_epoch_end(
             collate_kwargs["batch_dim_key_defaults"] = self.batch_dim_key_defaults
         if self.seq_dim_key_defaults is not None:
             collate_kwargs["seq_dim_key_defaults"] = self.seq_dim_key_defaults
+
         prediction = batch_collator([item for item in predictions if item is not None], **collate_kwargs)
 
         # batch_indices is not captured due to a lightning bug when return_predictions = False
         # we use input IDs in the prediction to map the result to input
-        torch.save(prediction, result_path)
         if isinstance(prediction, dict):
             keys = prediction.keys()
         else:
             keys = "tensor"
+        torch.save(prediction, result_path)
         logging.info(f"Inference predictions are stored in {result_path}\n{keys}")
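Since each per-batch prediction now carries its own `batch_idx`, a consumer can restore input order without parsing filenames. The following is a minimal sketch of that idea under stated assumptions: plain dictionaries with one-element lists stand in for the loaded prediction files, the `batch_idx` values are assumed to be unique within the list (for example, files written by a single rank), and `order_by_batch_idx` is a hypothetical helper, not part of BioNeMo.

```python
def order_by_batch_idx(predictions):
    """Return prediction dicts sorted by their stored 'batch_idx' entry.

    Each 'batch_idx' is a one-element container, mirroring how the writer
    wraps the index. Assumes the indices are unique within the input list.
    """
    return sorted(predictions, key=lambda pred: int(pred["batch_idx"][0]))


# Stand-ins for prediction files loaded in arbitrary (e.g. directory) order.
loaded = [
    {"batch_idx": [2], "payload": "third batch"},
    {"batch_idx": [0], "payload": "first batch"},
    {"batch_idx": [1], "payload": "second batch"},
]

ordered = order_by_batch_idx(loaded)
ordered_ids = [pred["batch_idx"][0] for pred in ordered]
print(ordered_ids)
```

With predictions from several ranks, the batch index alone may not identify a batch uniquely, so a real consumer would likely need to combine it with the rank or with the input IDs stored alongside it.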
