feat: allow NaN in targets when masks are provided for metric computation #821

Closed
haoyu-haoyu wants to merge 2 commits into WenjieDu:dev from haoyu-haoyu:fix/allow-nan-targets-with-masks
@haoyu-haoyu
Contributor

Summary

Implements the feature requested in #708 (and attempted in the now-closed PR #707).

When evaluating imputation on irregularly sampled time series, the ground truth (targets) itself may contain NaN values at naturally missing positions. Previously, _check_inputs() rejected any NaN in targets unconditionally, blocking this legitimate workflow.

Why PR #707 failed

The previous attempt used logical_not incorrectly (reversed mask semantics — PyPOTS convention is mask=1 for observed, mask=0 for missing) and didn't complete tests before being auto-closed.

The hidden trap: NaN * 0 = NaN

Simply removing the NaN assertion is not sufficient. In both NumPy and PyTorch, NaN * 0 produces NaN, not 0. So even when the mask correctly zeroes out NaN positions, the arithmetic (predictions - targets) * masks will propagate NaN through the result. The fix must replace NaN with 0 in targets before any arithmetic.
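
A minimal NumPy sketch of the trap described above. The arrays are illustrative values, not taken from the PyPOTS test suite; the point is only that masking after the subtraction leaves NaN in place, while replacing NaN in targets first does not.

```python
import numpy as np

predictions = np.array([1.0, 2.0, 3.0])
targets = np.array([1.0, np.nan, 3.0])
masks = np.array([1.0, 0.0, 1.0])  # position 1 is already masked out

# Masking after the subtraction does not help: NaN * 0 is still NaN.
naive = np.abs(predictions - targets) * masks
naive_has_nan = bool(np.isnan(naive).any())

# Replacing NaN in targets first keeps every masked term finite.
safe_targets = np.where(np.isnan(targets), 0.0, targets)
safe = np.abs(predictions - safe_targets) * masks
safe_has_nan = bool(np.isnan(safe).any())

print(naive_has_nan, safe_has_nan)
```

The same IEEE 754 rule applies to PyTorch tensors, which is why the replacement has to happen before any arithmetic in both backends.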

Solution

_check_inputs() now returns (lib, targets, masks) instead of just lib:

# When masks provided AND targets contain NaN:
# 1. Extend mask to exclude NaN positions
nan_free = (~isnan(targets)).to(masks.dtype)
masks = masks * nan_free

# 2. Replace NaN with 0 to prevent NaN * 0 = NaN propagation
targets = where(isnan(targets), zeros_like(targets), targets)

Three behaviors:

| Scenario | Result |
| --- | --- |
| targets NaN-free | Unchanged (backward compatible) |
| targets has NaN + masks provided | Auto-extend mask + replace NaN with 0 |
| targets has NaN + no masks | ValueError with clear message |
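
The three scenarios can be sketched as a single NumPy-only function. `sanitize_targets` is a hypothetical name used for illustration; it is not the actual `_check_inputs()` body, which also detects the backend and validates predictions.

```python
import numpy as np

def sanitize_targets(targets, masks=None):
    """Hypothetical helper mirroring the three scenarios (NumPy only)."""
    nan_positions = np.isnan(targets)
    if not nan_positions.any():
        # Scenario 1: NaN-free targets pass through unchanged.
        return targets, masks
    if masks is None:
        # Scenario 3: no mask means no way to pick evaluation positions.
        raise ValueError("targets contains NaN but no masks were provided")
    # Scenario 2: drop NaN positions from the mask, then zero them in targets.
    masks = masks * (~nan_positions).astype(masks.dtype)
    targets = np.where(nan_positions, 0.0, targets)
    return targets, masks

clean_t, clean_m = sanitize_targets(
    np.array([1.0, np.nan, 3.0]), np.array([1, 1, 1])
)
print(clean_t, clean_m)
```

Returning the cleaned arrays rather than mutating in place keeps the caller's inputs untouched, which matches the description of `_check_inputs()` returning new `targets` and `masks`.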

Example

import numpy as np
from pypots.nn.functional import calc_mae

predictions = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
targets = np.array([1.0, np.nan, 3.0, 4.0, np.nan])  # NaN at positions 1, 4
masks = np.array([1, 1, 1, 1, 1])  # Initially all observed

# Before: AssertionError ("targets mustn't contain NaN")
# After:  mae = 0.0 (NaN positions auto-masked, remaining values match)
mae = calc_mae(predictions, targets, masks)
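
To see why the example yields 0.0, the arithmetic can be traced by hand. This sketch assumes the masked MAE has the form sum(|p - t| * m) / (sum(m) + eps) with a small eps such as 1e-12; that formula is an assumption about `calc_mae`, not something stated in this PR.

```python
import numpy as np

predictions = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
targets = np.array([1.0, np.nan, 3.0, 4.0, np.nan])
masks = np.array([1, 1, 1, 1, 1])

# Step 1: positions 1 and 4 drop out of the mask.
effective_masks = masks * ~np.isnan(targets)

# Step 2: NaN targets become 0 so the subtraction stays finite.
filled_targets = np.where(np.isnan(targets), 0.0, targets)

# Step 3: masked absolute error, averaged over the 3 remaining positions.
abs_error = np.abs(predictions - filled_targets) * effective_masks
mae = abs_error.sum() / (effective_masks.sum() + 1e-12)  # eps value is assumed

print(effective_masks, mae)
```

Positions 1 and 4 contribute nothing to either the numerator or the denominator, so the result depends only on the three observed positions, where predictions and targets agree.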

Files changed

| File | Change |
| --- | --- |
| pypots/nn/functional/error.py | _check_inputs returns (lib, targets, masks); all calc_* functions updated |

Test plan

import numpy as np, torch
from pypots.nn.functional import calc_mae, calc_mse, calc_rmse, calc_mre

# Backward compat — NaN-free targets work as before
calc_mae(np.array([1,2,3]), np.array([1,2,4]))  # ✓

# NaN targets with masks — new behavior
t = np.array([1.0, np.nan, 3.0])
p = np.array([1.0, 99.0, 3.0])
m = np.array([1, 1, 1])
assert calc_mae(p, t, m) == 0.0  # NaN position auto-excluded

# NaN targets without masks — clear error
try:
    calc_mae(p, t)  # ValueError: targets contains NaN but no masks
except ValueError:
    pass  # ✓

# Works with PyTorch too
calc_mae(torch.tensor(p), torch.tensor(t), torch.tensor(m, dtype=torch.float))  # ✓

Fixes #708

…tion

When evaluating imputation on irregularly sampled time series, the
ground truth (targets) itself may contain NaN values at naturally
missing positions.  Previously, _check_inputs() rejected any NaN
in targets unconditionally, blocking this legitimate workflow.

Changes to _check_inputs():
- Return (lib, targets, masks) tuple instead of just lib, so callers
  receive NaN-safe versions of targets and masks
- When masks are provided AND targets contain NaN:
  1. Extend masks to exclude NaN positions (mask *= ~isnan(targets))
  2. Replace NaN in targets with 0.0 to prevent NaN * 0 = NaN
     propagation in arithmetic
- When masks are NOT provided AND targets contain NaN: raise
  ValueError with clear message (no way to determine eval positions)
- predictions NaN check unchanged (model output should never be NaN)

All calc_* functions updated to unpack the new 3-tuple return.
Fully backward compatible: existing code with NaN-free targets
works identically.

Fixes WenjieDu#708

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@WenjieDu
Owner

Hi @haoyu-haoyu, please add unit testing cases for your updates. Thanks :)

9 test cases covering all scenarios for _check_inputs NaN handling:
- backward compat (NaN-free inputs unchanged)
- NaN targets + masks → auto-exclude (numpy + torch)
- NaN targets without masks → ValueError
- all-NaN targets → metric ≈ 0
- MSE/RMSE/MRE no NaN propagation
- predictions NaN still rejected
- masks already excluding NaN → same result

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@haoyu-haoyu
Contributor Author

Hi @WenjieDu, thanks for the feedback! Unit tests added in tests/nn/error_metrics.py:

9 test cases covering all scenarios:

  1. Backward compatibility — NaN-free inputs produce same results
  2. NaN-free inputs with masks — same results
  3. NaN targets + masks (numpy) — auto-exclude NaN positions
  4. NaN targets + masks (torch) — same behavior
  5. NaN targets without masks — raises ValueError
  6. All-NaN targets — metric returns ~0
  7. MSE/RMSE/MRE — no NaN propagation
  8. Predictions with NaN — still rejected
  9. Masks already exclude NaN — same result as auto-exclude
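
A rough sketch of what cases 6, 7, and 9 check, written against a stand-in function rather than the real `calc_*` functions. `masked_errors`, its metric formulas, and the `eps` value are assumptions for illustration; the actual tests live in tests/nn/error_metrics.py.

```python
import numpy as np

def masked_errors(predictions, targets, masks, eps=1e-12):
    """Stand-in for calc_mae/calc_mse/calc_rmse/calc_mre; not the PyPOTS code."""
    nan_positions = np.isnan(targets)
    masks = masks * ~nan_positions                    # auto-exclude NaN targets
    targets = np.where(nan_positions, 0.0, targets)   # avoid NaN * 0 = NaN
    diff = (predictions - targets) * masks
    n = masks.sum() + eps
    mae = np.abs(diff).sum() / n
    mse = (diff ** 2).sum() / n
    mre = np.abs(diff).sum() / (np.abs(targets * masks).sum() + eps)
    return {"mae": mae, "mse": mse, "rmse": np.sqrt(mse), "mre": mre}

p = np.array([1.0, 99.0, 5.0])
t = np.array([1.0, np.nan, 3.0])

# Case 9: a mask that already excludes the NaN gives the same metrics.
auto = masked_errors(p, t, np.array([1.0, 1.0, 1.0]))
manual = masked_errors(p, t, np.array([1.0, 0.0, 1.0]))

# Case 6: all-NaN targets leave an empty mask, so every metric is 0.
all_nan = masked_errors(p, np.full(3, np.nan), np.ones(3))

# Case 7: no metric carries NaN through.
no_nan = all(np.isfinite(v) for v in auto.values())
print(auto, all_nan, no_nan)
```

The prediction of 99.0 at the NaN position is deliberately large: if the position leaked into the metric, the error would be obvious.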

@WenjieDu WenjieDu changed the base branch from main to dev March 18, 2026 18:29
@WenjieDu WenjieDu closed this Mar 18, 2026
@WenjieDu WenjieDu reopened this Mar 18, 2026

@coveralls
Collaborator

Pull Request Test Coverage Report for Build 23260746779

Details

  • 15 of 15 (100.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.03%) to 79.948%

Totals Coverage Status
Change from base Build 23238548275: 0.03%
Covered Lines: 15187
Relevant Lines: 18996

💛 - Coveralls

@github-actions

github-actions Bot commented Apr 2, 2026

This pull request had no activity for 14 days. It will be closed in 1 week unless there is some new activity.

@github-actions github-actions Bot added the stale label Apr 2, 2026
@haoyu-haoyu
Contributor Author

@WenjieDu Friendly ping — unit tests were added per your feedback (9 test cases, 100% coverage on new code, CI green, SonarCloud passed). Would you have a chance to take another look? Happy to address any further feedback. Thanks!

@haoyu-haoyu
Contributor Author

Note: the CI failure on Python 3.11 + PyTorch 2.3.0 is a pre-existing environment issue unrelated to this PR:

Disabling PyTorch because PyTorch >= 2.4 is required but found 2.3.0
ImportError: cannot import name 'GenerationMixin' from 'transformers.generation'

The transformers library requires PyTorch >= 2.4, but the CI matrix pins PyTorch 2.3.0. The Python 3.9 + PyTorch 2.3.0 job passes with a compatible transformers version. This PR's changes (in pypots/nn/functional/) are not involved in the import chain that fails.

@github-actions

github-actions Bot commented May 3, 2026

This PR is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 10 days.

@github-actions github-actions Bot added the stale label May 3, 2026
@github-actions

This PR was closed because it has been stalled for 10 days with no activity.

@github-actions github-actions Bot closed this May 13, 2026
Development

Successfully merging this pull request may close these issues.

Enable metrics being computed even with NaN targets through the usage of masks

3 participants