fix: allocate CRPS accumulator on the input device #837

Merged
WenjieDu merged 10 commits into WenjieDu:dev from shaun0927:fix/crps-device-mismatch on Apr 25, 2026

@shaun0927

Description

calc_quantile_crps and calc_quantile_crps_sum (in
pypots/nn/functional/error.py) initialize their running accumulator with
CRPS = torch.tensor(0.0), which is created on the default device (CPU unless
that default has been changed). The subsequent in-place
CRPS += q_loss / denominator then mixes that CPU scalar with tensors that
inherit the input device (CUDA, MPS, …), and raises:

RuntimeError: Expected all tensors to be on the same device, but got CRPS
is on cpu, different from other tensors on cuda:0.
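
The failure mode can be sketched without PyPOTS. The snippet below is a minimal illustration, not the library code: `sum_losses` is a hypothetical stand-in for the accumulation loop, and the loss values are random placeholders. It compares an accumulator created on the default device with one created on the input's device.

```python
import torch


def sum_losses(losses: torch.Tensor, follow_input_device: bool) -> torch.Tensor:
    """Hypothetical stand-in for the CRPS loop: add up per-quantile losses in place."""
    # device=None means "use the default device", which is normally the CPU.
    device = losses.device if follow_input_device else None
    total = torch.tensor(0.0, device=device)
    for loss in losses:
        total += loss  # in-place add: the result is stored in `total`, on its device
    return total


# Use an accelerator when one is present; on a CPU-only machine both variants succeed.
mps = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif mps is not None and mps.is_available():
    device = "mps"
else:
    device = "cpu"

losses = torch.rand(19, device=device)  # made-up values, one per quantile level

fixed = sum_losses(losses, follow_input_device=True)
print("accumulator device:", fixed.device)

try:
    sum_losses(losses, follow_input_device=False)
    print("default-device accumulator worked (expected when inputs are on CPU)")
except RuntimeError as err:
    print("default-device accumulator failed:", err)
```

On a CPU-only machine both calls succeed, which is consistent with the bug going unnoticed until the inputs are moved to an accelerator.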

This makes both metrics unusable whenever predictions and targets are on
any device other than the CPU. That is a surprising limitation for a
PyTorch-first toolkit, and out of step with sibling metrics such as calc_mae,
which stay device-agnostic.

Changes

  • pypots/nn/functional/error.py: allocate the CRPS accumulator on
    predictions.device in both calc_quantile_crps and
    calc_quantile_crps_sum. The change is one line per function and does not
    alter behavior for CPU callers, since the accumulator stays on CPU when
    the inputs are on CPU.

Testing

CPU path is untouched, so existing call sites keep their previous behavior.
The device-matching path was verified manually:

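import torch
from pypots.nn.functional import calc_quantile_crps

# Pick the first available accelerator; fall back to CPU so the snippet still runs.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

P = torch.randn(2, 3, 5, device=device)  # predictions
T = torch.randn(2, 3, 5, device=device)  # targets
M = torch.ones(2, 3, 5, device=device)   # masks (all observed)
calc_quantile_crps(P, T, M)   # no longer raises on CUDA / MPS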

The fix is scoped to pypots/nn/functional/error.py only and does not touch
the existing _check_inputs signature, so there is no interaction with the
in-flight #821.
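
The manual check could be turned into an automated one along these lines. This is a sketch under stated assumptions, not part of this PR: `available_devices`, `masked_abs_error`, and `check_device_agnostic` are hypothetical helpers, and `masked_abs_error` is a simple stand-in metric so the snippet runs without PyPOTS installed. With PyPOTS available, calc_quantile_crps could be passed in its place.

```python
import torch


def available_devices() -> list:
    """Return CPU plus any accelerator backends PyTorch reports as usable."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        devices.append("mps")
    return devices


def masked_abs_error(predictions, targets, masks) -> float:
    """Hypothetical stand-in metric that returns a Python float."""
    # The accumulator follows the inputs, mirroring the fix described above.
    total = torch.tensor(0.0, device=predictions.device)
    total += torch.sum(torch.abs(predictions - targets) * masks)
    return total.item() / masks.sum().item()


def check_device_agnostic(metric, *cpu_inputs, tol=1e-4) -> dict:
    """Run `metric` on every available device and compare against the CPU result."""
    reference = metric(*cpu_inputs)
    results = {}
    for name in available_devices():
        moved = [t.to(name) for t in cpu_inputs]
        results[name] = metric(*moved)
        assert abs(results[name] - reference) <= tol * max(1.0, abs(reference))
    return results


torch.manual_seed(0)
P = torch.randn(2, 3, 5)
T = torch.randn(2, 3, 5)
M = torch.ones(2, 3, 5)
results = check_device_agnostic(masked_abs_error, P, T, M)
print(results)
```

On a CPU-only runner the loop covers only "cpu", so a check like this guards against the device mismatch only where an accelerator is actually available.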

WenjieDu and others added 10 commits April 8, 2026 04:16
…s.generation') (WenjieDu#829)

* add version upper limitation for transformers
…string

Correct arg typing error in ModerTCN
The CRPS accumulator was created with torch.tensor(0.0), which always
lives on CPU. The subsequent in-place += with a device tensor raised
RuntimeError whenever the inputs were on CUDA or MPS. Allocating the
accumulator on predictions.device makes GPU evaluation work.

WenjieDu changed the base branch from main to dev April 25, 2026 18:07
WenjieDu merged commit 837138c into WenjieDu:dev Apr 25, 2026
2 checks passed