Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated from b68596 to 7d7a10
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
from nemo.collections.llm.peft.lora import LoRA, LoRALinear
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.lightning.megatron_parallel import (
masked_token_loss,
masked_token_loss_context_parallel,
)
from nemo.lightning.megatron_parallel import masked_token_loss
from torch import Tensor, nn

from bionemo.llm.model.biobert.model import BioBertConfig, BioBertOutput, MegatronBioBertModel
Expand Down Expand Up @@ -102,17 +99,7 @@ def forward(
# TODO(@jstjohn) also handle different output keys, like the sequence loss.

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
# reduce the loss across the micro batch
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
else:
# reduce the loss across the micro batch.
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
# other necessary keys to the batch. Thanks!
loss_for_microbatch = masked_token_loss_context_parallel(
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
)
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)

# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
# reducing the loss across the data parallel group.
Expand Down
19 changes: 3 additions & 16 deletions sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@
from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.lightning.megatron_parallel import (
MegatronLossReduction,
masked_token_loss,
masked_token_loss_context_parallel,
)
from nemo.lightning.megatron_parallel import MegatronLossReduction, masked_token_loss
from torch import Tensor


Expand Down Expand Up @@ -181,17 +177,8 @@ def forward(

# compute loss
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
# reduce the loss across the micro batch per valid token
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
else:
# reduce the loss across the micro batch per valid token.
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
# other necessary keys to the batch. Thanks!
loss_for_microbatch = masked_token_loss_context_parallel(
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
)
# reduce the loss across the micro batch per valid token
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)

# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
# reducing the loss across the data parallel group.
Expand Down