Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated from b68596 to d0fc65
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
from nemo.collections.llm.peft.lora import LoRA, LoRALinear
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.lightning.megatron_parallel import (
masked_token_loss,
masked_token_loss_context_parallel,
)
from nemo.lightning.megatron_parallel import masked_token_loss
from torch import Tensor, nn

from bionemo.llm.model.biobert.model import BioBertConfig, BioBertOutput, MegatronBioBertModel
Expand Down Expand Up @@ -102,17 +99,7 @@ def forward(
# TODO(@jstjohn) also handle different output keys, like the sequence loss.

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
# reduce the loss across the micro batch
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
else:
# reduce the loss across the micro batch.
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
# other necessary keys to the batch. Thanks!
loss_for_microbatch = masked_token_loss_context_parallel(
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
)
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)

# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
# reducing the loss across the data parallel group.
Expand Down
19 changes: 3 additions & 16 deletions sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@
from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.lightning.megatron_parallel import (
MegatronLossReduction,
masked_token_loss,
masked_token_loss_context_parallel,
)
from nemo.lightning.megatron_parallel import MegatronLossReduction, masked_token_loss
from torch import Tensor


Expand Down Expand Up @@ -181,17 +177,8 @@ def forward(

# compute loss
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
# reduce the loss across the micro batch per valid token
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
else:
# reduce the loss across the micro batch per valid token.
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
# other necessary keys to the batch. Thanks!
loss_for_microbatch = masked_token_loss_context_parallel(
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
)
# reduce the loss across the micro batch per valid token
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)

# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
# reducing the loss across the data parallel group.
Expand Down
Loading