fix: allocate CRPS accumulator on the input device #837
Merged
The CRPS accumulator was created with torch.tensor(0.0), which always lives on CPU. The subsequent in-place += with a device tensor raised RuntimeError whenever the inputs were on CUDA or MPS. Allocating the accumulator on predictions.device makes GPU evaluation work.

Description
calc_quantile_crps and calc_quantile_crps_sum (in pypots/nn/functional/error.py) initialize their running accumulator with CRPS = torch.tensor(0.0), which lives on CPU by default. The subsequent in-place CRPS += q_loss / denominator then tries to mix that CPU scalar with tensors that inherit the input device (CUDA, MPS, …) and raises a RuntimeError.
This makes both metrics unusable whenever predictions/targets are on anything other than CPU, a surprising limitation for a PyTorch-first toolkit, and out of step with sibling metrics like calc_mae, which stay device-agnostic.
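The failure mode can be illustrated with a minimal sketch. The helper name try_cpu_accumulator is hypothetical and is not part of PyPOTS; it only mirrors the accumulator pattern described above. On a CPU-only machine the call succeeds and returns None; per the description, a CUDA or MPS device is expected to hit the RuntimeError instead.

```python
import torch


def try_cpu_accumulator(device):
    """Accumulate device-resident values into a scalar allocated without a device.

    Hypothetical helper for illustration only. Returns None on success,
    or the error message if PyTorch rejects the in-place update.
    """
    values = torch.ones(3, device=device)
    # No device argument: the scalar is created on the default device (CPU).
    acc = torch.tensor(0.0)
    try:
        # In-place update whose right-hand side lives on `device`.
        acc += values.sum() / values.numel()
    except RuntimeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device, try_cpu_accumulator(device))
```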
Changes
pypots/nn/functional/error.py: allocate the CRPS accumulator on predictions.device in both calc_quantile_crps and calc_quantile_crps_sum. Two one-line changes, no behavioral change for CPU callers (the accumulator's device matches what it already was).
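The shape of the change can be sketched roughly as follows. This is a simplified stand-in, not the code in pypots/nn/functional/error.py: the names crps_sketch and quantile_loss, the quantile grid, and the tensor layout (samples on dimension 1) are assumptions made for illustration. The only point it demonstrates is allocating the accumulator with device=predictions.device.

```python
import torch


def quantile_loss(q_pred, targets, q, masks):
    # Pinball (quantile) loss summed over observed entries. Hypothetical helper.
    diff = targets - q_pred
    indicator = (diff <= 0).to(diff.dtype)
    return 2 * torch.sum(torch.abs(diff * masks * (q - indicator)))


def crps_sketch(predictions, targets, masks):
    """Quantile-based CRPS approximation (illustrative only).

    predictions: [batch, n_samples, ...]; targets and masks: [batch, ...].
    """
    quantiles = [round(0.05 * k, 2) for k in range(1, 20)]  # assumed grid 0.05..0.95
    denominator = torch.sum(torch.abs(targets * masks))
    # The accumulator is created on the same device as the inputs, so the
    # in-place update below never mixes devices.
    crps = torch.tensor(0.0, device=predictions.device)
    for q in quantiles:
        q_pred = torch.quantile(predictions, q, dim=1)  # reduce over samples
        crps += quantile_loss(q_pred, targets, q, masks) / denominator
    return crps.item() / len(quantiles)
```

For CPU inputs predictions.device is already the CPU, which is why the change leaves CPU callers unaffected.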
Testing
CPU path is untouched, so existing call sites keep their previous behavior.
The device-matching path was verified manually:
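A check along these lines could look like the sketch below. It is not the script used for this PR: running_mean_abs_error is a hypothetical stand-in metric that uses the same accumulator pattern, and in practice it would be replaced by the two CRPS functions in pypots/nn/functional/error.py. The sketch runs the metric on every device PyTorch reports as available and compares each result with the CPU value.

```python
import torch


def available_devices():
    # CPU is always present; CUDA and MPS are added only when PyTorch reports them.
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        devices.append(torch.device("mps"))
    return devices


def running_mean_abs_error(predictions, targets):
    # Hypothetical stand-in metric with a device-matched accumulator.
    total = torch.tensor(0.0, device=predictions.device)
    for p, t in zip(predictions, targets):
        total += torch.mean(torch.abs(p - t))
    return total.item() / len(predictions)


def check_all_devices(tolerance=1e-5):
    torch.manual_seed(0)
    predictions = torch.randn(4, 8)
    targets = torch.randn(4, 8)
    reference = running_mean_abs_error(predictions, targets)  # CPU baseline
    results = {}
    for device in available_devices():
        value = running_mean_abs_error(predictions.to(device), targets.to(device))
        # Every device should agree with the CPU result up to float tolerance.
        assert abs(value - reference) < tolerance, (device, value, reference)
        results[str(device)] = value
    return results


if __name__ == "__main__":
    print(check_all_devices())
```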
The fix is scoped to pypots/nn/functional/error.py only and does not touch the existing _check_inputs signature, so there is no interaction with the in-flight #821.