Skip to content

feat: stratified per-sample validation timesteps#1436

Draft
BitcrushedHeart wants to merge 2 commits into
Nerogar:masterfrom
BitcrushedHeart:validation-timestep-stratified
Draft

feat: stratified per-sample validation timesteps#1436
BitcrushedHeart wants to merge 2 commits into
Nerogar:masterfrom
BitcrushedHeart:validation-timestep-stratified

Conversation

@BitcrushedHeart
Copy link
Copy Markdown
Contributor

Description

Replaces the single hardcoded validation timestep (the midpoint of the scheduler, e.g. 499 for 1000-step diffusion) with deterministic stratified sampling across the full timestep distribution. Validation loss is now averaged over the entire noise range instead of being measured at one point.

The previous behaviour gives a snapshot of the model's compositional/detail balance at the midpoint, which is informative but not representative of how the model behaves at low or high noise. Loss at timestep 499 alone tends to wobble run-to-run in ways that don't track final sample quality. Spreading samples across the distribution and averaging gives a more stable signal that correlates better with actual generative behaviour.

Two users with identical configs and datasets still get bit-identical validation loss - the stratified positions and per-sample noise seeds are derived from a fixed seed via a NumPy SeedSequence, so the assignments are fully reproducible.

How It Works

For a validation set of N samples, sample i gets:

  1. A stratified position in the unit interval: pos_i = (i + jitter_i) / N, where jitter_i is drawn from np.random.default_rng([VAL_SEED, TIMESTEP_TAG, i]).random(). Each sample's position falls in its own [i/N, (i+1)/N) bucket so coverage is even at any N.
  2. A per-concept (or global) timestep shift applied in the [0, 1] domain: u_i = s * pos_i / ((s - 1) * pos_i + 1). This is algebraically equivalent to the existing N-domain shift in ModelSetupNoiseMixin._get_timestep_discrete, just rearranged to operate on a unit value before the integer mapping.
  3. The unit position multiplied by num_train_timesteps and clamped to a valid integer.
  4. A per-sample noise seed: np.random.default_rng([VAL_SEED, NOISE_TAG, i]).integers(0, 2**31).

Both streams are derived from i and recomputed each validation pass, so resumed runs reproduce the same timesteps and noise without needing to persist any extra state in meta.json.

The trainer mutates the validation batch dict with two private keys (__val_timestep_unit__, __val_noise_seed__) before passing it to model_setup.predict(). _get_timestep_discrete and _get_timestep_continuous pick those up; if absent they fall back to the existing midpoint behaviour. That keeps the Tools-menu "Calculate losses" path (GenerateLossesModel) unchanged - it doesn't set the keys, so it still reports loss at the midpoint as before.

UI

A new "Validation Timestep Shift" entry sits directly underneath "Timestep Shift" in the noise frame of the Training tab. Default 1.0 (identity, matches the existing distribution). When Dynamic Timestep Shifting is supported, that toggle moves down one row.

In the concept editor, VALIDATION concepts get a per-concept "Validation Timestep Shift" entry. Leaving it blank means inherit the global value; setting a numeric value overrides it for that concept's samples. The field is harmless on non-validation concepts (the trainer only consults it for samples that come from VALIDATION concepts) but is shown for all concepts to keep the editor uniform.

Backwards Compatibility

validation_timestep_shift is added to TrainConfig with default 1.0 and to ConceptConfig with default None. Old saved configs missing the field load through BaseConfig.from_dict() and pick up the defaults silently - no migration function needed. No existing field is renamed or removed.

The single-midpoint behaviour is preserved for any caller that doesn't populate __val_timestep_unit__ in the batch. The only place that does set it is the validation loop in GenericTrainer.__validate().

Files Changed

  • modules/util/validation_timestep.py - new helper module with the stratified position, noise seed, and shift transform
  • modules/util/config/TrainConfig.py - global validation_timestep_shift field, default 1.0
  • modules/util/config/ConceptConfig.py - per-concept nullable override, default None
  • modules/modelSetup/mixin/ModelSetupNoiseMixin.py - validation_override parameter on the deterministic branches of _get_timestep_discrete and _get_timestep_continuous
  • modules/trainer/GenericTrainer.py - precomputes per-sample positions and noise seeds, resolves global vs per-concept shift, mutates the validation batch dict
  • modules/modelSetup/Base*Setup.py (14 files) - one-line change to the deterministic batch seed and one-line addition of validation_override=batch.get("__val_timestep_unit__") to the existing timestep call
  • modules/ui/TrainingTab.py - new entry under Timestep Shift in the noise frame
  • modules/ui/ConceptWindow.py - per-concept entry in the general tab

21 files changed, 135 insertions, 24 deletions.

Testing Notes

Local pytest suite covers stratification (every position in its bucket at N = 10/50/100), shift transform monotonicity and endpoint stability at shifts 0.5/1.0/2.0/3.0, algebraic equivalence to the existing N-domain shift, determinism across reruns, persistence across simulated epoch boundaries, per-concept override resolution, backward-compat (configs missing the field load with defaults and round-trip cleanly), N=1 edge case, mixin override correctness for both discrete and continuous, and the GenerateLossesModel no-override path still returning the midpoint. All 34 tests pass.

Tested on Windows 11, Python 3.10. The local test file is not part of this PR.

Didn't run a full GPU validation cycle since my GPU was busy on another job; the change is metadata-plumbing plus a deterministic algorithm, so the blast radius on actual training and forward-pass behaviour is limited to the timestep that gets fed into existing code paths.

Replaces the hardcoded validation timestep (midpoint of the scheduler) with
deterministic stratified sampling across the full timestep distribution.

Validation loss is now averaged over the whole range of noise levels, so it
reflects the model's compositional and detail behaviour rather than just the
scheduler midpoint. Two users with identical configs and datasets still get
identical validation loss.

Per-sample assignments are recomputed from the sample index each epoch via a
NumPy SeedSequence, so resumed runs reproduce the same timesteps and noise
without needing to serialise extra state.

A new global Validation Timestep Shift (default 1.0) lives directly under
Timestep Shift in the Training tab, and per-concept VALIDATION concepts can
override it through a nullable entry in the concept editor (blank inherits).

Old configs without the new fields load with the documented defaults, so this
is fully backwards compatible.
Drop redundant int()/float() casts on already-typed parameters, replace getattr
fallback with direct attribute access, remove the over-defensive OSError guard
on the concept-file fallback (the validation data loader has already opened the
same file by this point), and shorten the new tooltip to match the surrounding
register.
@dxqb
Copy link
Copy Markdown
Collaborator

dxqb commented May 10, 2026

  • if you use AI to submit PRs, please guide it towards briefness and only have it mention the points you think are important to understand the PR
  • this PR is similar to Validation timesteps #821 - could you unify them and finish one of them?
  • I'm not sure how useful it is to sample timesteps from a distribution for validation. the sampled timesteps have to be deterministic and most people's validation set will be quite small. so your random sampling is very much luck. This has been discussed in this thread [Feat]: Same timestep used for all validation steps #772 and that's why Validation timesteps #821 chose to use user-defined values
  • does this PR still report the average across all timesteps to the validation graph, right? In the same discussion [Feat]: Same timestep used for all validation steps #772 there are graphs that show an average across timesteps doesn't make much sense because the scales are different. The higher timesteps will dominate the average
  • GenericTrainer.py should remain agnostic about what model it trains - could even be an audio model that doesn't have timesteps. so it shouldn't have to know about timesteps. See Validation timesteps #821 for one idea how to solve that conflict
  • I haven't fully understood some of the features of this PR (such as the noise sampling), so please correct me where I'm wrong with the points above

@BitcrushedHeart BitcrushedHeart marked this pull request as draft May 17, 2026 21:59
@BitcrushedHeart
Copy link
Copy Markdown
Contributor Author

Moving to draft for reworking.

@dxqb dxqb mentioned this pull request May 24, 2026
2 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants