Skip to content

Commit 3ac72f6

Browse files
committed
remove call to context_parallel loss
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 607b19e commit 3ac72f6

2 files changed

Lines changed: 5 additions & 31 deletions

File tree

sub-packages/bionemo-geneformer/src/bionemo/geneformer/model/finetune_token_regressor.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@
2424
from nemo.collections.llm.peft.lora import LoRA, LoRALinear
2525
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
2626
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
27-
from nemo.lightning.megatron_parallel import (
28-
masked_token_loss,
29-
masked_token_loss_context_parallel,
30-
)
27+
from nemo.lightning.megatron_parallel import masked_token_loss
3128
from torch import Tensor, nn
3229

3330
from bionemo.llm.model.biobert.model import BioBertConfig, BioBertOutput, MegatronBioBertModel
@@ -102,17 +99,7 @@ def forward(
10299
# TODO(@jstjohn) also handle different output keys, like the sequence loss.
103100

104101
cp_size = parallel_state.get_context_parallel_world_size()
105-
if cp_size == 1:
106-
# reduce the loss across the micro batch
107-
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
108-
else:
109-
# reduce the loss across the micro batch.
110-
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
111-
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
112-
# other necessary keys to the batch. Thanks!
113-
loss_for_microbatch = masked_token_loss_context_parallel(
114-
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
115-
)
102+
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)
116103

117104
# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
118105
# reducing the loss across the data parallel group.

sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@
1919
from megatron.core import parallel_state, tensor_parallel
2020
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
2121
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
22-
from nemo.lightning.megatron_parallel import (
23-
MegatronLossReduction,
24-
masked_token_loss,
25-
masked_token_loss_context_parallel,
26-
)
22+
from nemo.lightning.megatron_parallel import MegatronLossReduction, masked_token_loss
2723
from torch import Tensor
2824

2925

@@ -181,17 +177,8 @@ def forward(
181177

182178
# compute loss
183179
cp_size = parallel_state.get_context_parallel_world_size()
184-
if cp_size == 1:
185-
# reduce the loss across the micro batch per valid token
186-
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
187-
else:
188-
# reduce the loss across the micro batch per valid token.
189-
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
190-
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
191-
# other necessary keys to the batch. Thanks!
192-
loss_for_microbatch = masked_token_loss_context_parallel(
193-
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
194-
)
180+
# reduce the loss across the micro batch per valid token
181+
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)
195182

196183
# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
197184
# reducing the loss across the data parallel group.

0 commit comments

Comments
 (0)