fix: guard optional validation fields in CSAI forward during inference #838

Open
shaun0927 wants to merge 1 commit into WenjieDu:main from shaun0927:fix/csai-inference-keyerror

@shaun0927

Description

In pypots/imputation/csai/core.py, _BCSAI.forward ends with:

if calc_criterion:
    if self.training:
        ...
    else:  # validation stage
        X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
        ...

if not self.training:
    results["x_ori"] = inputs["X_ori"]
    results["indicating_mask"] = inputs["indicating_mask"]

The first block correctly gates X_ori / indicating_mask behind
calc_criterion (those keys only exist in validation batches). The second
block runs whenever self.training is False — including plain predict()
calls where the user's Dataset has no ground-truth fields — and raises
KeyError: 'X_ori' there.
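
The failure can be reproduced without the real model. The sketch below is a plain-Python stand-in for the tail of the forward pass, not the actual `_BCSAI` code: `post_eval_unguarded` is a hypothetical helper, and the `imputed_data` key and the list-valued batch are placeholders chosen for illustration.

```python
def post_eval_unguarded(inputs: dict, training: bool) -> dict:
    """Hypothetical stand-in for the end of forward() before this fix."""
    # Placeholder result entry; the real forward returns model outputs here.
    results = {"imputed_data": inputs["X"]}
    if not training:
        # Unconditional reads: these keys exist only in validation batches.
        results["x_ori"] = inputs["X_ori"]
        results["indicating_mask"] = inputs["indicating_mask"]
    return results


# A predict-only batch carries no ground-truth fields.
predict_batch = {"X": [[1.0, 0.0]], "missing_mask": [[1, 0]]}

try:
    post_eval_unguarded(predict_batch, training=False)
    missing_key = None
except KeyError as err:
    # The first unguarded lookup fails before indicating_mask is reached.
    missing_key = err.args[0]
```

With `training=False` and no `X_ori` in the batch, the first lookup raises, which matches the `KeyError: 'X_ori'` described above.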

The unconditional accesses were introduced when CSAI was re-landed in PR
#788 and were never exercised by a predict-only test. This PR guards the
accesses so they only run when the keys actually exist in the input dict.

Changes

  • pypots/imputation/csai/core.py: wrap the results["x_ori"] and
    results["indicating_mask"] assignments in if "X_ori" in inputs: /
    if "indicating_mask" in inputs: checks. Semantically identical for any
    caller that already provides those keys (validation); now safe for pure
    inference.
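
A minimal sketch of the guarded logic, again as a plain-Python stand-in rather than the actual diff: `post_eval_guarded` is a hypothetical helper, and `imputed_data` plus the toy batches are placeholders, not names taken from core.py.

```python
def post_eval_guarded(inputs: dict, training: bool) -> dict:
    """Hypothetical stand-in for the end of forward() after this fix."""
    # Placeholder result entry; the real forward returns model outputs here.
    results = {"imputed_data": inputs["X"]}
    if not training:
        # Copy the ground-truth fields only when the batch provides them.
        if "X_ori" in inputs:
            results["x_ori"] = inputs["X_ori"]
        if "indicating_mask" in inputs:
            results["indicating_mask"] = inputs["indicating_mask"]
    return results


# Validation-style batch: both optional keys are present.
validation_batch = {
    "X": [[1.0, 0.0]],
    "missing_mask": [[1, 0]],
    "X_ori": [[1.0, 2.0]],
    "indicating_mask": [[0, 1]],
}
# Inference-style batch: no ground-truth fields.
inference_batch = {"X": [[1.0, 0.0]], "missing_mask": [[1, 0]]}

validation_out = post_eval_guarded(validation_batch, training=False)
inference_out = post_eval_guarded(inference_batch, training=False)
```

Here `validation_out` carries both `x_ori` and `indicating_mask`, while `inference_out` simply omits them instead of raising.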

Testing

  • Validation flow is unchanged: when the caller feeds a batch with both
    keys present, both entries are added to results just as before.
  • Pure-inference flow (user passes only X, missing_mask, the deltas_*
    / last_obs_* pair) now returns the regular result dict without raising
    KeyError.

No changes to model weights, training loss, or validation metric; the fix
is limited to guarding two dictionary key lookups.

The post-eval block unconditionally reads inputs["X_ori"] and
inputs["indicating_mask"] whenever self.training is False. These keys
are only populated in validation batches, so predict() paths where the
user's Dataset omits ground-truth fields crash with KeyError. Gating
the accesses keeps validation behavior unchanged and restores plain
inference on minimal input dicts.

@github-actions

This PR is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 10 days.

github-actions bot added the stale label on May 17, 2026