Fix training crashes from corrupt beatmaps and wandb NaN loss#98
Fix training crashes from corrupt beatmaps and wandb NaN loss#98rosacry wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Pull request overview
This PR improves training robustness by preventing two known crash modes: malformed beatmaps during dataset iteration and wandb failures during checkpoint artifact logging.
Changes:
- Skip corrupt beatmaps by catching
ValueError/IndexErroraround_get_next_beatmap()iteration. - Sanitize
current_lossbefore adding it to wandb artifact metadata to avoid NaN/Inf-related wandb crashes. - Wrap wandb artifact logging in a
try/exceptso wandb issues don’t block checkpoint saving.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
osuT5/osuT5/utils/train_utils.py |
Makes wandb artifact logging resilient to NaN/Inf loss values and to wandb logging exceptions. |
osuT5/osuT5/dataset/mmrs_dataset.py |
Prevents training crashes by skipping beatmaps that raise parsing/processing errors during iteration. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| except Exception as e: | ||
| logger.warning(f"Failed to log checkpoint artifact to wandb: {e}") |
There was a problem hiding this comment.
The broad except Exception as e logs only the exception message and drops the traceback, which makes wandb failures hard to diagnose. Use logger.exception(...) (or logger.warning(..., exc_info=True)) so the stack trace is preserved while still keeping checkpoint saving non-blocking.
| except Exception as e: | |
| logger.warning(f"Failed to log checkpoint artifact to wandb: {e}") | |
| except Exception: | |
| logger.exception("Failed to log checkpoint artifact to wandb") |
| safe_loss = shared.current_loss if math.isfinite(shared.current_loss) else 0.0 | ||
| art = wandb.Artifact( | ||
| f"osuT5-{wandb.run.id}", | ||
| type="model", | ||
| metadata={ | ||
| "format": "accelerate", | ||
| "src_seq_len": args.data.src_seq_len, | ||
| "tgt_seq_len": args.data.tgt_seq_len, | ||
| "num_classes": args.data.num_classes, | ||
| "num_diff_classes": args.data.num_diff_classes, | ||
| "max_difficulty": args.data.max_diff, | ||
| "class_dropout_prob": args.data.class_dropout_prob, | ||
| "diff_dropout_prob": args.data.diff_dropout_prob, | ||
| "spectrogram": args.model.spectrogram, | ||
| "current_train_step": shared.current_train_step, | ||
| "current_epoch": shared.current_epoch, | ||
| "current_loss": safe_loss, | ||
| }, |
There was a problem hiding this comment.
safe_loss maps non-finite shared.current_loss (including the initial np.inf) to 0.0, which makes artifact metadata misleading (it will look like loss improved to 0). Prefer recording None/"non_finite" or omitting current_loss from metadata when it isn't finite, and optionally add a separate boolean flag (e.g., current_loss_is_finite).
| safe_loss = shared.current_loss if math.isfinite(shared.current_loss) else 0.0 | |
| art = wandb.Artifact( | |
| f"osuT5-{wandb.run.id}", | |
| type="model", | |
| metadata={ | |
| "format": "accelerate", | |
| "src_seq_len": args.data.src_seq_len, | |
| "tgt_seq_len": args.data.tgt_seq_len, | |
| "num_classes": args.data.num_classes, | |
| "num_diff_classes": args.data.num_diff_classes, | |
| "max_difficulty": args.data.max_diff, | |
| "class_dropout_prob": args.data.class_dropout_prob, | |
| "diff_dropout_prob": args.data.diff_dropout_prob, | |
| "spectrogram": args.model.spectrogram, | |
| "current_train_step": shared.current_train_step, | |
| "current_epoch": shared.current_epoch, | |
| "current_loss": safe_loss, | |
| }, | |
| current_loss_is_finite = math.isfinite(shared.current_loss) | |
| artifact_metadata = { | |
| "format": "accelerate", | |
| "src_seq_len": args.data.src_seq_len, | |
| "tgt_seq_len": args.data.tgt_seq_len, | |
| "num_classes": args.data.num_classes, | |
| "num_diff_classes": args.data.num_diff_classes, | |
| "max_difficulty": args.data.max_diff, | |
| "class_dropout_prob": args.data.class_dropout_prob, | |
| "diff_dropout_prob": args.data.diff_dropout_prob, | |
| "spectrogram": args.model.spectrogram, | |
| "current_train_step": shared.current_train_step, | |
| "current_epoch": shared.current_epoch, | |
| "current_loss_is_finite": current_loss_is_finite, | |
| "current_loss": shared.current_loss if current_loss_is_finite else None, | |
| } | |
| art = wandb.Artifact( | |
| f"osuT5-{wandb.run.id}", | |
| type="model", | |
| metadata=artifact_metadata, |
- Skip corrupt beatmaps during training instead of crashing (catch ValueError/IndexError) - Sanitize NaN/Inf loss values before logging to wandb artifacts - Wrap wandb artifact logging in try/except to prevent checkpoint save failures
0af2864 to
dae0669
Compare
|
|
||
| wandb.log_artifact(art, aliases=["best"] if is_best else None) | ||
| logger.info(f"Logged checkpoint to wandb: {art.name}") | ||
| try: |
There was a problem hiding this comment.
does this need a try or was the only exception encountered about the non-finite loss? I think we can get the rid of the try and reduce nesting too with some early returns while we're at it.
Fixes two training stability issues that cause crashes during long training runs.
Changes
Corrupt beatmap handling (
mmrs_dataset.py)_get_next_beatmap()in try/except forValueErrorandIndexErrorwandb artifact logging (
train_utils.py)current_losswithmath.isfinite()before passing to wandb artifact metadata (NaN/Inf values crash wandb)