diff --git a/3rdparty/NeMo b/3rdparty/NeMo index b685967f95..d0fc65838b 160000 --- a/3rdparty/NeMo +++ b/3rdparty/NeMo @@ -1 +1 @@ -Subproject commit b685967f9512e1906e11fbd95048ff0fb05ff2fe +Subproject commit d0fc65838ba2bc6f2506da742d5af427acfd0747 diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/model/finetune_token_regressor.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/model/finetune_token_regressor.py index ee47fd7fc3..bdea3691a0 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/model/finetune_token_regressor.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/model/finetune_token_regressor.py @@ -24,10 +24,7 @@ from nemo.collections.llm.peft.lora import LoRA, LoRALinear from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group -from nemo.lightning.megatron_parallel import ( - masked_token_loss, - masked_token_loss_context_parallel, -) +from nemo.lightning.megatron_parallel import masked_token_loss from torch import Tensor, nn from bionemo.llm.model.biobert.model import BioBertConfig, BioBertOutput, MegatronBioBertModel @@ -102,17 +99,7 @@ def forward( # TODO(@jstjohn) also handle different output keys, like the sequence loss. cp_size = parallel_state.get_context_parallel_world_size() - if cp_size == 1: - # reduce the loss across the micro batch - loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"]) - else: - # reduce the loss across the micro batch. - # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this. - # This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and - # other necessary keys to the batch. Thanks! - loss_for_microbatch = masked_token_loss_context_parallel( - unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"] - ) + loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size) # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support # reducing the loss across the data parallel group. diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py index 272ab27829..e04619eda7 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py @@ -19,11 +19,7 @@ from megatron.core import parallel_state, tensor_parallel from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group -from nemo.lightning.megatron_parallel import ( - MegatronLossReduction, - masked_token_loss, - masked_token_loss_context_parallel, -) +from nemo.lightning.megatron_parallel import MegatronLossReduction, masked_token_loss from torch import Tensor @@ -181,17 +177,8 @@ def forward( # compute loss cp_size = parallel_state.get_context_parallel_world_size() - if cp_size == 1: - # reduce the loss across the micro batch per valid token - loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"]) - else: - # reduce the loss across the micro batch per valid token. - # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this. - # This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and - # other necessary keys to the batch. Thanks! - loss_for_microbatch = masked_token_loss_context_parallel( - unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"] - ) + # reduce the loss across the micro batch per valid token + loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size) # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support # reducing the loss across the data parallel group.