Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions pypots/nn/functional/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union, Optional
from typing import Union, Optional, Tuple

import numpy as np
import torch
Expand All @@ -16,7 +16,21 @@ def _check_inputs(
targets: Union[np.ndarray, torch.Tensor, list],
masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None,
check_shape: bool = True,
):
) -> Tuple:
"""Validate inputs for metric computation.

Returns a tuple ``(lib, targets, masks)`` where ``targets`` and ``masks``
may have been updated to exclude NaN positions in ``targets`` (only when
``masks`` is provided).

When ``masks`` is given, NaN values in ``targets`` are tolerated: the mask
is automatically extended so that NaN positions are excluded from the
metric, and the NaN values are replaced with zeros to prevent arithmetic
propagation (``NaN * 0 = NaN`` in both NumPy and PyTorch).

When ``masks`` is *not* given, NaN in ``targets`` raises an error because
there is no mask to indicate which positions should be evaluated.
"""
# check type
assert isinstance(predictions, type(targets)), (
f"types of `predictions` and `targets` must match, but got"
Expand All @@ -30,9 +44,30 @@ def _check_inputs(
assert (
prediction_shape == target_shape
), f"shape of `predictions` and `targets` must match, but got {prediction_shape} and {target_shape}"
# check NaN
# check NaN in predictions — model output should never contain NaN
assert not lib.isnan(predictions).any(), "`predictions` mustn't contain NaN values, but detected NaN in it"
assert not lib.isnan(targets).any(), "`targets` mustn't contain NaN values, but detected NaN in it"

# handle NaN in targets
has_nan_targets = bool(lib.isnan(targets).any())
if has_nan_targets:
if masks is None:
raise ValueError(
"`targets` contains NaN values but no `masks` were provided. "
"Either remove NaN from `targets` or provide `masks` to indicate "
"which positions should be evaluated."
)
# Extend masks to also exclude NaN target positions.
# Convention: mask=1 means observed, mask=0 means missing.
if isinstance(targets, torch.Tensor):
nan_free = (~torch.isnan(targets)).to(masks.dtype)
targets = torch.where(torch.isnan(targets), torch.zeros_like(targets), targets)
else:
nan_free = (~np.isnan(targets)).astype(masks.dtype)
targets = np.where(np.isnan(targets), 0, targets)
masks = masks * nan_free
else:
# no NaN — nothing to do
pass

if masks is not None:
# check type
Expand All @@ -49,7 +84,7 @@ def _check_inputs(
# check NaN
assert not lib.isnan(masks).any(), "`masks` mustn't contain NaN values, but detected NaN in it"

return lib
return lib, targets, masks


def calc_mae(
Expand Down Expand Up @@ -95,7 +130,7 @@ def calc_mae(

"""
# check shapes and values of inputs
lib = _check_inputs(predictions, targets, masks)
lib, targets, masks = _check_inputs(predictions, targets, masks)

if masks is not None:
return lib.sum(lib.abs(predictions - targets) * masks) / (lib.sum(masks) + 1e-12)
Expand Down Expand Up @@ -146,7 +181,7 @@ def calc_mse(

"""
# check shapes and values of inputs
lib = _check_inputs(predictions, targets, masks)
lib, targets, masks = _check_inputs(predictions, targets, masks)

if masks is not None:
return lib.sum(lib.square(predictions - targets) * masks) / (lib.sum(masks) + 1e-12)
Expand Down Expand Up @@ -246,7 +281,7 @@ def calc_mre(

"""
# check shapes and values of inputs
lib = _check_inputs(predictions, targets, masks)
lib, targets, masks = _check_inputs(predictions, targets, masks)

if masks is not None:
return lib.sum(lib.abs(predictions - targets) * masks) / (lib.sum(lib.abs(targets * masks)) + 1e-12)
Expand Down Expand Up @@ -300,7 +335,7 @@ def calc_quantile_crps(

"""
# check shapes and values of inputs
_ = _check_inputs(predictions, targets, masks, check_shape=False)
_, targets, masks = _check_inputs(predictions, targets, masks, check_shape=False)

if isinstance(predictions, np.ndarray):
predictions = torch.from_numpy(predictions)
Expand Down Expand Up @@ -359,7 +394,7 @@ def calc_quantile_crps_sum(

"""
# check shapes and values of inputs
_ = _check_inputs(predictions, targets, masks, check_shape=False)
_, targets, masks = _check_inputs(predictions, targets, masks, check_shape=False)

if isinstance(predictions, np.ndarray):
predictions = torch.from_numpy(predictions)
Expand Down
Empty file added tests/nn/__init__.py
Empty file.
101 changes: 101 additions & 0 deletions tests/nn/error_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Test cases for error metric functions with NaN target support.
"""

# Created by Haoyu Wang
# License: BSD-3-Clause

import unittest

import numpy as np
import pytest
import torch

from pypots.nn.functional.error import calc_mae, calc_mse, calc_rmse, calc_mre


class TestErrorMetricsNaNTargets(unittest.TestCase):
"""Tests for NaN-tolerant metric computation when masks are provided."""

def test_backward_compat_nan_free(self):
"""Verify NaN-free inputs produce the same results as before."""
p = np.array([1.0, 2.0, 1.0, 4.0, 6.0])
t = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
mae = calc_mae(p, t)
assert abs(mae - 0.6) < 1e-6, f"Expected 0.6, got {mae}"

def test_backward_compat_with_masks(self):
"""Verify NaN-free inputs with masks produce the same results."""
p = np.array([1.0, 2.0, 1.0, 4.0, 6.0])
t = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
m = np.array([0.0, 0.0, 0.0, 1.0, 1.0])
mae = calc_mae(p, t, m)
assert abs(mae - 0.5) < 1e-6, f"Expected 0.5, got {mae}"

def test_nan_targets_with_masks_numpy(self):
"""NaN positions in targets should be auto-excluded via mask."""
p = np.array([1.0, 2.0, 1.0, 4.0, 6.0])
t = np.array([1.0, np.nan, 3.0, 4.0, np.nan])
m = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
# Valid positions: 0, 2, 3 → errors: |1-1|+|1-3|+|4-4| = 2, count=3
mae = calc_mae(p, t, m)
expected = 2.0 / 3.0
assert abs(mae - expected) < 1e-6, f"Expected {expected}, got {mae}"

def test_nan_targets_with_masks_torch(self):
"""Same as above but with PyTorch tensors."""
p = torch.tensor([1.0, 2.0, 1.0, 4.0, 6.0])
t = torch.tensor([1.0, float("nan"), 3.0, 4.0, float("nan")])
m = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
mae = calc_mae(p, t, m)
expected = 2.0 / 3.0
assert abs(mae.item() - expected) < 1e-6

def test_nan_targets_without_masks_raises(self):
"""NaN targets without masks should raise ValueError."""
p = np.array([1.0, 2.0, 3.0])
t = np.array([1.0, np.nan, 3.0])
with pytest.raises(ValueError, match="no `masks` were provided"):
calc_mae(p, t)

def test_all_nan_targets(self):
"""When all targets are NaN, all positions are masked → metric ≈ 0."""
p = np.array([1.0, 2.0, 3.0])
t = np.array([np.nan, np.nan, np.nan])
m = np.array([1.0, 1.0, 1.0])
mae = calc_mae(p, t, m)
assert abs(mae) < 1e-6, f"Expected ~0, got {mae}"

def test_no_nan_propagation_mse_rmse_mre(self):
"""Verify MSE, RMSE, MRE also handle NaN targets without propagation."""
p = np.array([1.0, 2.0, 1.0, 4.0, 6.0])
t = np.array([1.0, np.nan, 3.0, 4.0, np.nan])
m = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
mse = calc_mse(p, t, m)
rmse = calc_rmse(p, t, m)
mre = calc_mre(p, t, m)
assert not np.isnan(mse), "MSE should not be NaN"
assert not np.isnan(rmse), "RMSE should not be NaN"
assert not np.isnan(mre), "MRE should not be NaN"

def test_predictions_nan_still_rejected(self):
"""Predictions with NaN should still raise AssertionError."""
p = np.array([1.0, np.nan, 3.0])
t = np.array([1.0, 2.0, 3.0])
m = np.array([1.0, 1.0, 1.0])
with pytest.raises(AssertionError, match="predictions"):
calc_mae(p, t, m)

def test_masks_already_exclude_nan(self):
"""When masks already exclude NaN positions, result is unchanged."""
p = np.array([1.0, 2.0, 1.0, 4.0, 6.0])
t = np.array([1.0, np.nan, 3.0, 4.0, np.nan])
m_auto = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
m_manual = np.array([1.0, 0.0, 1.0, 1.0, 0.0])
mae_auto = calc_mae(p, t, m_auto)
mae_manual = calc_mae(p, t, m_manual)
assert abs(mae_auto - mae_manual) < 1e-6


if __name__ == "__main__":
unittest.main()
Loading