-
Notifications
You must be signed in to change notification settings - Fork 610
Expand file tree
/
Copy pathnan_detector.py
More file actions
54 lines (43 loc) · 1.66 KB
/
nan_detector.py
File metadata and controls
54 lines (43 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for detecting NaN values in loss during training."""
import logging
import math
# Module-level logger; check_total_loss_nan logs the offending step and loss
# through it right before raising LossNaNError.
log = logging.getLogger(__name__)
class LossNaNError(RuntimeError):
    """Exception raised when NaN is detected in total loss during training.

    Attributes
    ----------
    step : int
        The training step at which NaN was detected.
    total_loss : float
        The offending total loss value.
    """

    def __init__(self, step: int, total_loss: float) -> None:
        """Initialize the exception.

        Parameters
        ----------
        step : int
            The training step where NaN was detected
        total_loss : float
            The total loss value that contains NaN
        """
        self.step = step
        self.total_loss = total_loss
        # Only the first fragment interpolates values; the rest are plain
        # literals, implicitly concatenated into one message.
        message = (
            f"NaN detected in total loss at training step {step}: {total_loss}. "
            "Training stopped to prevent wasting time with corrupted parameters. "
            "This typically indicates unstable training conditions such as "
            "learning rate too high, poor data quality, or numerical instability."
        )
        super().__init__(message)

    def __reduce__(self) -> tuple:
        """Support pickling/copying by rebuilding from the constructor arguments.

        ``BaseException`` pickles itself as ``cls(*self.args)``. Here ``args``
        holds only the formatted message, so the default would call
        ``LossNaNError(message)`` and fail with ``TypeError`` on unpickling
        (e.g. when the error is sent back from a worker process). Rebuild
        from ``(step, total_loss)`` instead, and pass the instance ``__dict__``
        as state so any extra attributes (such as ``__notes__``) survive.
        """
        return (type(self), (self.step, self.total_loss), self.__dict__)
def check_total_loss_nan(step: int, total_loss: float) -> None:
    """Raise ``LossNaNError`` when ``total_loss`` is NaN; otherwise do nothing.

    Meant to be called from the training loop once the total loss has been
    computed and converted to a CPU float value.

    Parameters
    ----------
    step : int
        Current training step, reported in the log line and the exception.
    total_loss : float
        Total loss value to check for NaN.

    Raises
    ------
    LossNaNError
        If ``total_loss`` is NaN. An error is logged just before raising.
    """
    # Guard clause: only NaN is rejected; finite values and +/-inf pass through.
    if not math.isnan(total_loss):
        return
    log.error(f"NaN detected in total loss at step {step}: {total_loss}")
    raise LossNaNError(step, total_loss)