Skip to content

Commit a9bbf74

Browse files
fix: nan error when all fparam/aparam have equal values (#5321)
For standard deviation of `fparam/aparam`, $\sigma = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (x_i - \bar{x})^2}=\sqrt{\frac{\sum x_i^2}{N} - \left( \frac{\sum x_i}{N} \right)^2}$. When all `fparam`/`aparam` have equal values in one dimension, $\frac{\sum x_i^2}{N} - \left( \frac{\sum x_i}{N} \right)^2$ equals zero. However, it sometimes becomes a very small negative number(for example, 1e-18) due to numerical instability, so $\sqrt{\frac{\sum x_i^2}{N} - \left( \frac{\sum x_i}{N} \right)^2}$ becomes `nan`. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved numerical stability in variance/std calculations by ensuring intermediate variance values are non-negative before taking the square root. This prevents occasional floating-point underflow from producing invalid results and yields more reliable statistical outputs across edge-case inputs. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3ab3779 commit a9bbf74

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

deepmd/utils/env_mat_stat.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,12 @@ def compute_std(self, default: float = 1e-1, protection: float = 1e-2) -> float:
8989
if self.number == 0:
9090
return default
9191
val = np.sqrt(
92-
self.squared_sum / self.number
93-
- np.multiply(self.sum / self.number, self.sum / self.number)
92+
np.clip(
93+
self.squared_sum / self.number
94+
- np.multiply(self.sum / self.number, self.sum / self.number),
95+
a_min=0,
96+
a_max=None,
97+
)
9498
)
9599
if np.abs(val) < protection:
96100
val = protection

0 commit comments

Comments
 (0)