From 6e8bca5099ccd7b3703062503a33145886a7db19 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 8 Jan 2026 01:26:56 +0800 Subject: [PATCH 1/3] feat: add NaN detection during training Fix #4985. This implementation is much simpler than #4986. Signed-off-by: Jinzhe Zeng --- deepmd/loggers/training.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index c7fe94e24d..40de2a3f01 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import datetime +import logging +import math +log = logging.getLogger(__name__) def format_training_message( batch: int, @@ -19,7 +22,23 @@ def format_training_message_per_task( task_name: str, rmse: dict[str, float], learning_rate: float | None, + check_total_rmse_nan: bool = True, ) -> str: + """Format training messages for a specific task. + + Parameters + ---------- + batch : int + The batch index + task_name : str + The task name + rmse : dict[str, float] + The root-mean-squared errors. + learning_rate : float | None + The learning rate + check_total_rmse_nan : bool + Whether throw the error if the total RMSE is NaN + """ if task_name: task_name += ": " if learning_rate is None: @@ -28,8 +47,16 @@ def format_training_message_per_task( lr = f", lr = {learning_rate:8.2e}" # sort rmse rmse = dict(sorted(rmse.items())) - return ( + msg = ( f"batch {batch:7d}: {task_name}" f"{', '.join([f'{kk} = {vv:8.2e}' for kk, vv in rmse.items()])}" f"{lr}" ) + if check_total_rmse_nan and math.isnan(rmse.get("rmse", 0.0)): + log.error(msg) + err_msg = ( + f"NaN detected at batch {batch:7d}: {task_name}. " + "Something went wrong, and it is meaningless to continue." + ) + raise RuntimeError(err_msg) + return msg From c4dc138ac2f582d65eb5ba89a11c6d75abe0e8a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:30:24 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/loggers/training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index 40de2a3f01..5e03671ea6 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -5,6 +5,7 @@ log = logging.getLogger(__name__) + def format_training_message( batch: int, wall_time: float, From db5e195d1a71b5596815cdffcc0e875b75223309 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 8 Jan 2026 01:41:29 +0800 Subject: [PATCH 3/3] Update deepmd/loggers/training.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- deepmd/loggers/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index 5e03671ea6..555ab32622 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -38,7 +38,7 @@ def format_training_message_per_task( learning_rate : float | None The learning rate check_total_rmse_nan : bool - Whether throw the error if the total RMSE is NaN + Whether to throw an error if the total RMSE is NaN """ if task_name: task_name += ": "