diff --git a/tests/pytorch/engine/test_inputs_maker.py b/tests/pytorch/engine/test_inputs_maker.py index 5923722877..23193ff07c 100644 --- a/tests/pytorch/engine/test_inputs_maker.py +++ b/tests/pytorch/engine/test_inputs_maker.py @@ -35,6 +35,7 @@ def __init__(self, self.prefix_cache = SimpleNamespace(match_start_step=match_start_step) self.return_logits = False self.return_routed_experts = False + self.return_ce_loss = False def get_input_multimodals(self): return self._input_multimodals