|
24 | 24 | from nemo.collections.llm.peft.lora import LoRA, LoRALinear |
25 | 25 | from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter |
26 | 26 | from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group |
27 | | -from nemo.lightning.megatron_parallel import ( |
28 | | - masked_token_loss, |
29 | | - masked_token_loss_context_parallel, |
30 | | -) |
| 27 | +from nemo.lightning.megatron_parallel import masked_token_loss |
31 | 28 | from torch import Tensor, nn |
32 | 29 |
|
33 | 30 | from bionemo.llm.model.biobert.model import BioBertConfig, BioBertOutput, MegatronBioBertModel |
@@ -102,17 +99,7 @@ def forward( |
102 | 99 | # TODO(@jstjohn) also handle different output keys, like the sequence loss. |
103 | 100 |
|
104 | 101 | cp_size = parallel_state.get_context_parallel_world_size() |
105 | | - if cp_size == 1: |
106 | | - # reduce the loss across the micro batch |
107 | | - loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"]) |
108 | | - else: |
109 | | - # reduce the loss across the micro batch. |
110 | | - # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this. |
111 | | - # This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and |
112 | | - # other necessary keys to the batch. Thanks! |
113 | | - loss_for_microbatch = masked_token_loss_context_parallel( |
114 | | - unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"] |
115 | | - ) |
| 102 | + loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size) |
116 | 103 |
|
117 | 104 | # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support |
118 | 105 | # reducing the loss across the data parallel group. |
|
0 commit comments