
Fix training crashes from corrupt beatmaps and wandb NaN loss#98

Draft
rosacry wants to merge 1 commit into OliBomby:main from rosacry:fix/training-robustness

rosacry (Contributor) commented Apr 9, 2026

Fixes two training stability issues that cause crashes during long training runs.

Changes

Corrupt beatmap handling (mmrs_dataset.py)

  • Wraps _get_next_beatmap() in try/except for ValueError and IndexError
  • Logs the corrupt beatmap filename and skips to the next one instead of crashing
  • Encountered during training on large datasets with occasional malformed .osu files
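
The skip-on-error behavior described above can be sketched roughly as follows. This is a minimal illustration, not the actual code in mmrs_dataset.py: `iter_valid_beatmaps`, `fake_parse`, and the sample file names are hypothetical stand-ins, and only the caught exception types (ValueError, IndexError) come from the PR description.

```python
import logging
from typing import Callable, Iterable, Iterator

logger = logging.getLogger(__name__)


def iter_valid_beatmaps(
    paths: Iterable[str],
    parse_beatmap: Callable[[str], dict],
) -> Iterator[dict]:
    """Yield parsed beatmaps, logging and skipping any that fail to parse."""
    for path in paths:
        try:
            beatmap = parse_beatmap(path)
        except (ValueError, IndexError) as e:
            # Name the offending file so it can be inspected or removed later.
            logger.warning("Skipping corrupt beatmap %s: %s", path, e)
            continue
        yield beatmap


def fake_parse(path: str) -> dict:
    """Hypothetical parser that rejects one malformed file."""
    if "corrupt" in path:
        raise ValueError("malformed timing point")
    return {"path": path}


parsed = list(iter_valid_beatmaps(["a.osu", "corrupt.osu", "b.osu"], fake_parse))
```

Catching only the two named exception types keeps unrelated failures (for example, a missing file or an out-of-memory error) visible instead of silently skipping them.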

wandb artifact logging (train_utils.py)

  • Sanitizes current_loss with math.isfinite() before passing to wandb artifact metadata (NaN/Inf values crash wandb)
  • Wraps the entire artifact logging block in try/except so a wandb failure doesn't prevent checkpoint saving
  • Encountered when loss spikes to NaN during early training steps
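
As a rough sketch of the two safeguards listed above, assuming a hypothetical `log_fn` callable in place of the real wandb calls: `sanitize_loss` and `log_checkpoint_metadata` are illustrative names, and the 0.0 fallback mirrors the value used in this PR's diff.

```python
import logging
import math
from typing import Callable

logger = logging.getLogger(__name__)


def sanitize_loss(loss: float, fallback: float = 0.0) -> float:
    """Return loss unchanged if it is finite, otherwise the fallback value."""
    return loss if math.isfinite(loss) else fallback


def log_checkpoint_metadata(log_fn: Callable[[dict], None], loss: float) -> bool:
    """Try to log metadata; report failure instead of raising."""
    metadata = {"current_loss": sanitize_loss(loss)}
    try:
        log_fn(metadata)
    except Exception as e:
        # A logging failure should not interrupt checkpoint saving.
        logger.warning("Failed to log checkpoint artifact: %s", e)
        return False
    return True


def failing_log_fn(metadata: dict) -> None:
    """Hypothetical logger that always fails, to exercise the except path."""
    raise RuntimeError("simulated logging failure")


recorded: list = []
ok = log_checkpoint_metadata(recorded.append, float("nan"))
failed = log_checkpoint_metadata(failing_log_fn, 1.5)
```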

Copilot AI review requested due to automatic review settings April 9, 2026 17:58

Copilot AI (Contributor) left a comment

Pull request overview

This PR improves training robustness by preventing two known crash modes: malformed beatmaps during dataset iteration and wandb failures during checkpoint artifact logging.

Changes:

  • Skip corrupt beatmaps by catching ValueError/IndexError around _get_next_beatmap() iteration.
  • Sanitize current_loss before adding it to wandb artifact metadata to avoid NaN/Inf-related wandb crashes.
  • Wrap wandb artifact logging in a try/except so wandb issues don’t block checkpoint saving.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| osuT5/osuT5/utils/train_utils.py | Makes wandb artifact logging resilient to NaN/Inf loss values and to wandb logging exceptions. |
| osuT5/osuT5/dataset/mmrs_dataset.py | Prevents training crashes by skipping beatmaps that raise parsing/processing errors during iteration. |


Comment on lines +101 to +102
    except Exception as e:
        logger.warning(f"Failed to log checkpoint artifact to wandb: {e}")

Copilot AI Apr 9, 2026


The broad except Exception as e logs only the exception message and drops the traceback, which makes wandb failures hard to diagnose. Use logger.exception(...) (or logger.warning(..., exc_info=True)) so the stack trace is preserved while still keeping checkpoint saving non-blocking.

Suggested change
-    except Exception as e:
-        logger.warning(f"Failed to log checkpoint artifact to wandb: {e}")
+    except Exception:
+        logger.exception("Failed to log checkpoint artifact to wandb")
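
To illustrate the difference this suggestion makes, here is a small self-contained sketch using only the standard `logging` module. The logger name `traceback_demo` and the function `flaky_upload` are made up for the demonstration and do not appear in the PR.

```python
import io
import logging

# Capture log output in memory so the two styles can be compared.
stream = io.StringIO()
logger = logging.getLogger("traceback_demo")
logger.setLevel(logging.WARNING)
logger.propagate = False
logger.addHandler(logging.StreamHandler(stream))


def flaky_upload() -> None:
    """Hypothetical stand-in for a call that fails."""
    raise RuntimeError("upload failed")


# Style 1: only the exception message is logged.
try:
    flaky_upload()
except Exception as e:
    logger.warning(f"Message only: {e}")
message_only_output = stream.getvalue()

# Style 2: logger.exception logs at ERROR level and appends the traceback.
try:
    flaky_upload()
except Exception:
    logger.exception("With traceback")
with_traceback_output = stream.getvalue()[len(message_only_output):]
```

Note that `logger.exception` logs at ERROR level; `logger.warning(..., exc_info=True)` keeps the WARNING level while still attaching the traceback.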

Comment on lines +72 to +89
safe_loss = shared.current_loss if math.isfinite(shared.current_loss) else 0.0
art = wandb.Artifact(
    f"osuT5-{wandb.run.id}",
    type="model",
    metadata={
        "format": "accelerate",
        "src_seq_len": args.data.src_seq_len,
        "tgt_seq_len": args.data.tgt_seq_len,
        "num_classes": args.data.num_classes,
        "num_diff_classes": args.data.num_diff_classes,
        "max_difficulty": args.data.max_diff,
        "class_dropout_prob": args.data.class_dropout_prob,
        "diff_dropout_prob": args.data.diff_dropout_prob,
        "spectrogram": args.model.spectrogram,
        "current_train_step": shared.current_train_step,
        "current_epoch": shared.current_epoch,
        "current_loss": safe_loss,
    },

Copilot AI Apr 9, 2026


safe_loss maps non-finite shared.current_loss (including the initial np.inf) to 0.0, which makes artifact metadata misleading (it will look like loss improved to 0). Prefer recording None/"non_finite" or omitting current_loss from metadata when it isn't finite, and optionally add a separate boolean flag (e.g., current_loss_is_finite).

Suggested change
- safe_loss = shared.current_loss if math.isfinite(shared.current_loss) else 0.0
- art = wandb.Artifact(
-     f"osuT5-{wandb.run.id}",
-     type="model",
-     metadata={
-         "format": "accelerate",
-         "src_seq_len": args.data.src_seq_len,
-         "tgt_seq_len": args.data.tgt_seq_len,
-         "num_classes": args.data.num_classes,
-         "num_diff_classes": args.data.num_diff_classes,
-         "max_difficulty": args.data.max_diff,
-         "class_dropout_prob": args.data.class_dropout_prob,
-         "diff_dropout_prob": args.data.diff_dropout_prob,
-         "spectrogram": args.model.spectrogram,
-         "current_train_step": shared.current_train_step,
-         "current_epoch": shared.current_epoch,
-         "current_loss": safe_loss,
-     },
+ current_loss_is_finite = math.isfinite(shared.current_loss)
+ artifact_metadata = {
+     "format": "accelerate",
+     "src_seq_len": args.data.src_seq_len,
+     "tgt_seq_len": args.data.tgt_seq_len,
+     "num_classes": args.data.num_classes,
+     "num_diff_classes": args.data.num_diff_classes,
+     "max_difficulty": args.data.max_diff,
+     "class_dropout_prob": args.data.class_dropout_prob,
+     "diff_dropout_prob": args.data.diff_dropout_prob,
+     "spectrogram": args.model.spectrogram,
+     "current_train_step": shared.current_train_step,
+     "current_epoch": shared.current_epoch,
+     "current_loss_is_finite": current_loss_is_finite,
+     "current_loss": shared.current_loss if current_loss_is_finite else None,
+ }
+ art = wandb.Artifact(
+     f"osuT5-{wandb.run.id}",
+     type="model",
+     metadata=artifact_metadata,

@rosacry rosacry marked this pull request as draft April 9, 2026 18:01
- Skip corrupt beatmaps during training instead of crashing (catch ValueError/IndexError)
- Sanitize NaN/Inf loss values before logging to wandb artifacts
- Wrap wandb artifact logging in try/except to prevent checkpoint save failures
@rosacry rosacry force-pushed the fix/training-robustness branch from 0af2864 to dae0669 Compare April 9, 2026 18:11

wandb.log_artifact(art, aliases=["best"] if is_best else None)
logger.info(f"Logged checkpoint to wandb: {art.name}")
try:

Owner


Does this need a try, or was the only exception encountered the one about the non-finite loss? I think we can get rid of the try and reduce nesting too with some early returns while we're at it.
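
One way the early-return restructuring suggested here might look, as a sketch only: the guard conditions (`wandb_enabled`, `is_main_process`), the function name `maybe_log_checkpoint`, and the `log_fn` callable are assumptions for illustration and are not taken from train_utils.py.

```python
import math
from typing import Callable, Optional


def maybe_log_checkpoint(
    wandb_enabled: bool,
    is_main_process: bool,
    loss: float,
    log_fn: Callable[[Optional[float]], None],
) -> bool:
    """Log a checkpoint if all (hypothetical) preconditions hold."""
    # Each guard returns early, so the main path stays at one indent level.
    if not wandb_enabled:
        return False
    if not is_main_process:
        return False

    # Handle the non-finite loss directly instead of relying on a try/except.
    logged_loss = loss if math.isfinite(loss) else None
    log_fn(logged_loss)
    return True


calls: list = []
skipped = maybe_log_checkpoint(False, True, 0.5, calls.append)
logged = maybe_log_checkpoint(True, True, float("inf"), calls.append)
```

With the non-finite case handled explicitly, any remaining exception from the logging call would propagate, which is the trade-off the question above raises.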


3 participants