fix: validate inputs and preserve numpy return type in calc_quantile_loss#839
Merged
WenjieDu merged 2 commits intoApr 25, 2026
Merged
Conversation
added 2 commits
April 17, 2026 13:08
…loss Two related consistency fixes for pypots/nn/functional/error.py: 1. calc_quantile_loss was the only calc_* function that did not call _check_inputs(). NaN or shape-mismatched inputs therefore produced silently wrong numeric results instead of the clear AssertionError that calc_mae/calc_mse/calc_rmse/calc_mre raise. This adds the same guard so all error metrics share a single validation contract. 2. After numpy support was introduced in WenjieDu#822, numpy inputs were always returned as a torch.Tensor, breaking the existing Union[float, torch.Tensor] contract that sibling metrics honor (numpy in -> numpy out). The function now converts back to a numpy scalar when the caller passed numpy arrays. Verified with the sibling metrics on both numpy and torch paths.
…s_sum working The initial fix passed _check_inputs with default check_shape=True, which rejects the intentional broadcasting in the calc_quantile_crps_sum code path (q_pred has shape (B,) while targets has shape (B, T)). Passing check_shape=False keeps the NaN/type guards that motivated this change while allowing both internal callers (calc_quantile_crps and calc_quantile_crps_sum) to keep broadcasting as they did before. _check_inputs still validates mask.shape == targets.shape, so the mask contract is unchanged.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.



Description
Two small, related fixes to
pypots/nn/functional/error.pythat makecalc_quantile_lossbehave like its siblings in the same module.1. Missing
_check_inputscall.Every other
calc_*in the file (calc_mae,calc_mse,calc_rmse,calc_mre,calc_quantile_crps,calc_quantile_crps_sum) starts withlib = _check_inputs(...)so NaN / shape / dtype problems surface as aclear
AssertionError.calc_quantile_losswas the one exception andsilently returned
tensor(nan)on NaN input, e.g.Adding the same guard makes the module's validation contract uniform.
2. numpy-in / numpy-out parity.
The function signature returns
Union[float, torch.Tensor], matching thesibling metrics which preserve the caller's type (numpy in → numpy out).
After #822 added numpy support,
calc_quantile_lossalways convertednumpy inputs to torch tensors and returned a
torch.Tensor, breakingthat contract:
Restoring the numpy return on numpy inputs keeps downstream code that
expected a plain numpy scalar working.
Changes
pypots/nn/functional/error.py—calc_quantile_loss:_check_inputs(predictions, targets, eval_points)first;numpy_in = isinstance(predictions, np.ndarray)before theconversion block;
numpy_in, returnquantile_loss.detach().cpu().numpy()so thedeclared
Union[float, torch.Tensor]contract holds.The change is confined to one function and does not touch the
_check_inputssignature, so it does not conflict with the in-flight #821.Testing
AssertionError: predictions mustn't contain NaN values, matchingcalc_mae/calc_mse/ …AssertionError: shape of predictions and targets must match …instead of relying on downstream broadcast errors.calc_quantile_crpsstill callscalc_quantile_lossinternally ontorch tensors; that path is numerically unchanged.