Skip to content

Commit daea062

Browse files
committed
refactor: Merge MoECrossEntropyLoss into modalities core
1 parent a801e99 commit daea062

5 files changed

Lines changed: 36 additions & 43 deletions

File tree

src/modalities/loss_functions.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,39 @@ def _parse_arguments(
8787
return labels, lm_logits
8888

8989

90+
class MoECrossEntropyLoss(Loss):
    """Cross entropy over a batch, plus any auxiliary losses found on the model's layers.

    The base term is a mean-reduced ``CrossEntropyLoss`` computed on the
    flattened predictions and targets. Every entry of ``model.layers`` is then
    inspected, and each ``aux_loss`` attribute that exists and is not ``None``
    is added to that term.
    """

    def __init__(
        self,
        target_key: str,
        prediction_key: str,
        model,
        tag: str = "MoECrossEntropyLoss",
    ):
        """Remember the batch keys and the model that carries the auxiliary losses.

        Args:
            target_key: Key handed to ``get_targets`` on the forward batch.
            prediction_key: Key handed to ``get_predictions`` on the forward batch.
            model: Object whose ``layers`` attribute offers ``values()``; the
                yielded layers may expose an ``aux_loss`` attribute.
            tag: Forwarded to the ``Loss`` base-class initializer.
        """
        super().__init__(tag)
        self.loss_fun = CrossEntropyLoss(reduction="mean")
        self.model = model
        self.prediction_key = prediction_key
        self.target_key = target_key

    def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
        """Return the cross entropy of the batch plus the layers' auxiliary losses.

        Args:
            forward_batch: Batch providing targets and predictions under the
                configured keys.

        Returns:
            The mean cross entropy with every available ``aux_loss`` added.
        """
        targets = forward_batch.get_targets(self.target_key)
        logits = forward_batch.get_predictions(self.prediction_key)

        # The targets have to live on the same device as the predictions.
        targets = targets.to(logits.device)

        # Collapse all leading dimensions: one row of scores per target entry.
        last_dim = logits.size(-1)
        flat_logits = logits.contiguous().view(-1, last_dim)
        flat_targets = targets.contiguous().long().view(-1)
        total = self.loss_fun(flat_logits, flat_targets)

        for layer in self.model.layers.values():
            aux = getattr(layer, "aux_loss", None)
            if aux is None:
                # Layer has no auxiliary loss (attribute missing or unset).
                continue
            # Cast to the dtype of the running total before accumulating.
            total = total + aux.to(total.dtype)

        return total
121+
122+
90123
def nce_loss(
91124
embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float
92125
) -> torch.Tensor:

src/modalities/models/moe/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from modalities.models.moe.loss_functions import MoECrossEntropyLoss
1+
from modalities.loss_functions import MoECrossEntropyLoss
22
from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig
33

44
__all__ = [

src/modalities/models/moe/loss_functions.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

src/modalities/registry/components.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@
8585
ProgressSubscriberFactory,
8686
ResultsSubscriberFactory,
8787
)
88-
from modalities.loss_functions import CLMCrossEntropyLoss
88+
from modalities.loss_functions import CLMCrossEntropyLoss, MoECrossEntropyLoss
8989
from modalities.models.coca.coca_model import CoCa, CoCaConfig
9090
from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn
9191
from modalities.models.components.layer_norms import (
@@ -99,7 +99,6 @@
9999
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig
100100
from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
101101
from modalities.models.model_factory import GPT2ModelFactory, ModelFactory
102-
from modalities.models.moe.loss_functions import MoECrossEntropyLoss
103102
from modalities.models.moe.qwen_model import QwenModel, QwenModelConfig
104103
from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory
105104
from modalities.models.parallelism.pipeline_parallelism_configs import (

tests/models/moe/test_loss_functions.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from torch.nn import CrossEntropyLoss
33

44
from modalities.batch import InferenceResultBatch
5-
from modalities.models.moe.loss_functions import MoECrossEntropyLoss
5+
from modalities.loss_functions import MoECrossEntropyLoss
66

77

88
class DummyLayer:

0 commit comments

Comments
 (0)