feat: stratified per-sample validation timesteps#1436
Draft
BitcrushedHeart wants to merge 2 commits into
Draft
Conversation
Replaces the hardcoded validation timestep (midpoint of the scheduler) with deterministic stratified sampling across the full timestep distribution. Validation loss is now averaged over the whole range of noise levels, so it reflects the model's compositional and detail behaviour rather than just the scheduler midpoint. Two users with identical configs and datasets still get identical validation loss. Per-sample assignments are recomputed from the sample index each epoch via a NumPy SeedSequence, so resumed runs reproduce the same timesteps and noise without needing to serialise extra state. A new global Validation Timestep Shift (default 1.0) lives directly under Timestep Shift in the Training tab, and per-concept VALIDATION concepts can override it through a nullable entry in the concept editor (blank inherits). Old configs without the new fields load with the documented defaults, so this is fully backwards compatible.
Drop redundant int()/float() casts on already-typed parameters, replace getattr fallback with direct attribute access, remove the over-defensive OSError guard on the concept-file fallback (the validation data loader has already opened the same file by this point), and shorten the new tooltip to match the surrounding register.
Collaborator
|
Contributor
Author
|
Moving to draft for reworking. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Replaces the single hardcoded validation timestep (the midpoint of the scheduler, e.g. 499 for 1000-step diffusion) with deterministic stratified sampling across the full timestep distribution. Validation loss is now averaged over the entire noise range instead of being measured at one point.
The previous behaviour gives a snapshot of the model's compositional/detail balance at the midpoint, which is informative but not representative of how the model behaves at low or high noise. Loss at timestep 499 alone tends to wobble run-to-run in ways that don't track final sample quality. Spreading samples across the distribution and averaging gives a more stable signal that correlates better with actual generative behaviour.
Two users with identical configs and datasets still get bit-identical validation loss - the stratified positions and per-sample noise seeds are derived from a fixed seed via a NumPy SeedSequence, so the assignments are fully reproducible.
How It Works
For a validation set of N samples, sample i gets:
np.random.default_rng([VAL_SEED, TIMESTEP_TAG, i]).random(). Each sample's position falls in its own [i/N, (i+1)/N) bucket so coverage is even at any N.ModelSetupNoiseMixin._get_timestep_discrete, just rearranged to operate on a unit value before the integer mapping.num_train_timestepsand clamped to a valid integer.np.random.default_rng([VAL_SEED, NOISE_TAG, i]).integers(0, 2**31).Both streams are derived from
iand recomputed each validation pass, so resumed runs reproduce the same timesteps and noise without needing to persist any extra state inmeta.json.The trainer mutates the validation batch dict with two private keys (
__val_timestep_unit__,__val_noise_seed__) before passing it tomodel_setup.predict()._get_timestep_discreteand_get_timestep_continuouspick those up; if absent they fall back to the existing midpoint behaviour. That keeps the Tools-menu "Calculate losses" path (GenerateLossesModel) unchanged - it doesn't set the keys, so it still reports loss at the midpoint as before.UI
A new "Validation Timestep Shift" entry sits directly underneath "Timestep Shift" in the noise frame of the Training tab. Default 1.0 (identity, matches the existing distribution). When Dynamic Timestep Shifting is supported, that toggle moves down one row.
In the concept editor, VALIDATION concepts get a per-concept "Validation Timestep Shift" entry. Leaving it blank means inherit the global value; setting a numeric value overrides it for that concept's samples. The field is harmless on non-validation concepts (the trainer only consults it for samples that come from VALIDATION concepts) but is shown for all concepts to keep the editor uniform.
Backwards Compatibility
validation_timestep_shiftis added to TrainConfig with default 1.0 and to ConceptConfig with default None. Old saved configs missing the field load throughBaseConfig.from_dict()and pick up the defaults silently - no migration function needed. No existing field is renamed or removed.The single-midpoint behaviour is preserved for any caller that doesn't populate
__val_timestep_unit__in the batch. The only place that does set it is the validation loop inGenericTrainer.__validate().Files Changed
modules/util/validation_timestep.py- new helper module with the stratified position, noise seed, and shift transformmodules/util/config/TrainConfig.py- globalvalidation_timestep_shiftfield, default 1.0modules/util/config/ConceptConfig.py- per-concept nullable override, default Nonemodules/modelSetup/mixin/ModelSetupNoiseMixin.py-validation_overrideparameter on the deterministic branches of_get_timestep_discreteand_get_timestep_continuousmodules/trainer/GenericTrainer.py- precomputes per-sample positions and noise seeds, resolves global vs per-concept shift, mutates the validation batch dictmodules/modelSetup/Base*Setup.py(14 files) - one-line change to the deterministic batch seed and one-line addition ofvalidation_override=batch.get("__val_timestep_unit__")to the existing timestep callmodules/ui/TrainingTab.py- new entry under Timestep Shift in the noise framemodules/ui/ConceptWindow.py- per-concept entry in the general tab21 files changed, 135 insertions, 24 deletions.
Testing Notes
Local pytest suite covers stratification (every position in its bucket at N = 10/50/100), shift transform monotonicity and endpoint stability at shifts 0.5/1.0/2.0/3.0, algebraic equivalence to the existing N-domain shift, determinism across reruns, persistence across simulated epoch boundaries, per-concept override resolution, backward-compat (configs missing the field load with defaults and round-trip cleanly), N=1 edge case, mixin override correctness for both discrete and continuous, and the
GenerateLossesModelno-override path still returning the midpoint. All 34 tests pass.Tested on Windows 11, Python 3.10. The local test file is not part of this PR.
Didn't run a full GPU validation cycle since my GPU was busy on another job; the change is metadata-plumbing plus a deterministic algorithm, so the blast radius on actual training and forward-pass behaviour is limited to the timestep that gets fed into existing code paths.