Skip to content

Implement RLHF DPO (Direct Preference Optimization) training#1403

Draft
BitcrushedHeart wants to merge 24 commits into
Nerogar:masterfrom
BitcrushedHeart:RLHF
Draft

Implement RLHF DPO (Direct Preference Optimization) training#1403
BitcrushedHeart wants to merge 24 commits into
Nerogar:masterfrom
BitcrushedHeart:RLHF

Conversation

@BitcrushedHeart
Copy link
Copy Markdown
Contributor

DPO in OneTrainer

What This Is

OneTrainer's DPO implementation lets you show the model two images for the same prompt - one you prefer, one you don't - and trains an adapter to produce more of the former. The reference model is kept adapter-sized (either the raw base or a frozen snapshot of your existing LoRA), so the whole thing runs on a single consumer GPU without doubling your VRAM budget.

It's model-agnostic. The loss function doesn't know or care whether you're training on Flux, SDXL, SD 1.5, Z-Image, or anything else OneTrainer supports. It hooks into the existing predict() pipeline and works with whatever prediction targets the model already uses.

How DPO Works Here

Standard DPO compares how much more the policy model prefers the chosen sample over the rejected one, relative to how much the reference model does. The loss is:

L = -log(sigmoid(beta * ((score_policy(chosen) - score_ref(chosen))
                        - (score_policy(rejected) - score_ref(rejected)))))

In language models, those scores are token log-probabilities. We don't have those. What we do have is the prediction error from the denoising objective - the same MSE that standard training minimises. A lower MSE means the model "agrees" more with that image at that timestep, so we use negative MSE as a score proxy:

score(image) = -mean((predicted - target)^2)

This is reduced across all non-batch dimensions, giving one scalar per sample. The key insight is that this works regardless of prediction type - epsilon, v-prediction, flow velocity - because predict() already returns the right (predicted, target) pair for each model family. The DPO loss never needs to know which one it's looking at.

The default beta is 5000, which looks absurdly high compared to the 0.1-0.5 typical in LLM DPO. The reason is scale: MSE values are tiny compared to token log-probs, so you need a correspondingly larger beta to keep the sigmoid in a useful gradient range rather than having it saturate immediately.

Two optional extras sit on top of the base loss. Label smoothing interpolates between the DPO loss and its complement, which helps when your preference labels are noisy (i.e. you weren't fully sure which image was better). Supervised mix adds a standard training loss from the chosen image to prevent the adapter drifting too far from what it already knows. The supervised term reuses the chosen policy forward pass, so it doesn't cost an extra forward pass.

Training Types

There are two modes, and you don't pick them directly - the code infers which one you're using from whether you've loaded a base adapter.

New Adapter means no base adapter is loaded. The reference model is the raw base with no adapter applied. The policy is a fresh adapter training from scratch. Use this for general quality preferences - "prefer sharp images over blurry ones" - where you don't have a prior fine-tune to build on.

Existing Adapter means you've loaded a supervised LoRA/OFT checkpoint. The reference model is a frozen snapshot of those adapter weights taken at the start of training. The policy starts from the same weights and diverges during DPO. Use this when you've already got a character or style LoRA and want to refine its outputs - "my character LoRA is good but sometimes fumbles hands, so here's examples of good hands vs bad hands."

The output is always an adapter file in both cases.

The Reference Model

This is the part that makes DPO practical on hobbyist hardware.

For New Adapter mode, the reference pass just temporarily unhooks the adapter from the model, runs a forward pass through the raw base weights, and hooks the adapter back in. Simple.

For Existing Adapter mode, the implementation clones every adapter parameter tensor once at the start of training - this is the frozen reference snapshot. During each reference forward pass, it swaps param.data pointers from the live training weights to the frozen snapshot, runs the forward pass, and swaps back in a finally block. This is O(1) pointer assignment per parameter, not a copy. The base model weights are shared between policy and reference the whole time. The only duplicated data is the adapter tensors themselves - typically 50-100MB for a LoRA.

No second model is loaded. No base weights are duplicated. The "reference model" is a list of frozen tensors and a pointer swap.

Forward Pass Scheduling

DPO needs four forward passes per training step: reference-chosen, reference-rejected, policy-chosen, policy-rejected. How you schedule those is a VRAM/speed tradeoff exposed through the execution mode setting.

Sequential (default): Run all four one at a time, deleting activations between each. Both reference passes are no_grad so they don't retain activations anyway. The policy chosen pass runs, its output is deleted, then the policy rejected pass runs. Peak VRAM is roughly the same as standard LoRA training. Slowest, but fits on 24GB.

Policy Concurrent: Reference passes still run sequentially and get cleaned up. Both policy passes keep their activations alive simultaneously - the chosen output isn't deleted before the rejected pass starts. This saves recomputation during backward at the cost of holding two gradient-tracked passes in memory. Uses more VRAM than Sequential.

Full Concurrent: All four outputs stay alive until the scores are computed. Fastest scheduling, highest memory. Uses considerably more VRAM than either of the other modes.

The execution mode doesn't change the maths. Same loss, same gradients, same result - just different memory/speed profiles.

Shared Noise

When shared noise is on (the default), the chosen and rejected forward passes use the same timestep and the same noise. This is achieved through a slightly indirect mechanism: both passes receive TrainProgress objects with the same global_step, which seeds the RNG identically inside predict().

The reasoning is simple - if chosen and rejected get different noise draws, some of the gradient signal comes from the noise difference rather than the preference difference. Sharing noise isolates the preference signal.

When shared noise is off, the rejected pass gets global_step + 1, giving it an independent RNG stream. This can be useful for more diverse gradient signals at the cost of noisier loss curves.

Worth noting: this mechanism depends on predict() seeding its RNG from global_step. If that ever changes, the noise sharing breaks silently. There's a comment in the code about this, but it's a real coupling rather than a guaranteed interface.

Data Pipeline

Concept Types

DPO uses four explicit concept types: DPO_CHOSEN, DPO_REJECTED, DPO_CHOSEN_VAL, and DPO_REJECTED_VAL. There's no adjacency-based inference and no fallback to standard concepts. If you want DPO, you configure DPO concepts.

Pair resolution is strict. The code collects all enabled chosen and rejected concepts, errors if both sides aren't present, errors if the counts don't match, and zips them in config order.

Filename Matching

Within each concept pair, PairByFilename matches samples by filename stem relative to the concept root, with extensions stripped and path separators normalised. This allows subdirectory structures as long as they mirror each other.

The pairing module fails fast if it finds unmatched files on either side, if prompts differ between matched pairs, or if crop resolutions differ. DPO pairs are supposed to differ only in quality, not in what they depict or how they're framed.

Augmentation Suppression

All augmentations are disabled for DPO concepts - both in the runtime dataloader and in the concept editor UI. This includes image and text variations, crop jitter, flips, rotation, colour adjustments, mask transforms, tag shuffling, tag dropout, and capitalisation randomisation.

The logic is simple: if chosen and rejected images get different augmentations applied, you're teaching the model to prefer augmentation A over augmentation B, not to prefer the image quality you actually selected for. Augmentations are disabled entirely rather than synchronised across pairs, because synchronisation would add complexity for minimal benefit.

Curation Tool

The DPO Pair Tool (launched from the Tools tab) handles preference pair creation from generated images.

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata, raw byte scanning for JPEG/WebP - and groups images by exact (prompt, aspectratio) match. From there it supports two scoring modes.

ELO mode shows two images side by side and lets you pick a winner. Keyboard driven - left arrow, right arrow, down to skip. Adaptive pairing prefers images with similar ratings for more informative comparisons.

Direct selection mode shows all images in a group and lets you pick the best and worst directly. Faster, less rigorous.

A "Pairs per group" setting (default 1) controls how many pairs are extracted per prompt group before advancing. Images are removed from the pool after being used in a pair, so there's no reuse within a group. If fewer than two unused images remain, the tool advances early.

Export produces chosen/train/, chosen/val/, rejected/train/, rejected/val/ with matched filenames, caption text files, and a concepts.json ready for OneTrainer. Validation splitting is done at the prompt-group level so no prompt appears in both splits.

Validation

When DPO validation is enabled, the trainer doesn't fall back to per-concept reconstruction MSE. Instead it runs calculate_dpo_loss() under torch.no_grad() on the validation pairs and logs:

  • dpo/val_loss - the raw DPO loss term (not the mixed training loss)
  • dpo/val_accuracy - fraction of val pairs where the policy prefers chosen over rejected more than the reference does
  • dpo/val_chosen_reward - mean chosen preference ratio
  • dpo/val_rejected_reward - mean rejected preference ratio

One caveat: dpo/val_loss logs the pure DPO term, not the total loss including supervised mix. If you're running with supervised_mix > 0, the training loss and validation loss are measuring slightly different things. This doesn't affect early stopping since patience is based on accuracy and chosen reward, not val_loss directly.

Early Stopping

Patience is tracked against dpo/val_accuracy and dpo/val_chosen_reward, not reconstruction MSE.

Both best values start at negative infinity. A metric is considered stalled when the new value is less than or equal to the best seen so far. Two patience modes:

  • Either: increment patience counter if accuracy stalls OR chosen reward stalls
  • Both: increment only if both stall simultaneously

Any improvement resets the counter. When the counter hits the configured patience value, training stops.

This is deliberately tied to the DPO validation metrics rather than MSE because you can have both chosen and rejected MSE decreasing (the model gets better at denoising everything) while the preference separation stays flat or degrades. MSE trends tell you about reconstruction quality; DPO metrics tell you about preference learning. Early stopping should care about the latter.

TensorBoard Metrics

Training logs: loss/train_step, loss/dpo (total after smoothing + supervised mix), dpo/raw_loss (pure DPO term), dpo/chosen_reward, dpo/rejected_reward, dpo/accuracy.

Validation logs: dpo/val_loss, dpo/val_accuracy, dpo/val_chosen_reward, dpo/val_rejected_reward.

Diagnostic interpretation: accuracy approaching 1.0 at initialisation means your pairs are too easy - the policy already prefers chosen over rejected relative to the reference before any training has happened. Beta doesn't affect accuracy directly since accuracy is a pure comparison on the ratios, not the scaled logits. Accuracy stuck at 0.5 means the model isn't learning the preference - check data quality. If dpo/raw_loss diverges while loss/dpo looks stable, the supervised mix is masking a problem. At beta=0, the loss should equal log(2) regardless of model state.

Startup Validation

If RLHF is enabled but the configuration is wrong, the trainer fails during initialisation rather than mid-run.

The training method gate checks that you're using LoRA - DPO with full fine-tuning raises NotImplementedError immediately. Missing or mismatched DPO concepts raise RuntimeError when the dataloader is built, which happens before any training step. Unmatched filenames raise RuntimeError during dataset initialisation. There's also a defensive check in calculate_dpo_loss() for missing rejected latents, though in practice this shouldn't be reachable if the dataloader is configured correctly.

All four failure paths produce explicit exceptions during startup. None of them result in silent skipping or unhandled mid-training crashes.

Configuration

The DPO config fields, all prefixed rlhf_:

  • rlhf_enabled - master toggle
  • rlhf_mode - currently only DPO, exists for future extensibility
  • rlhf_dpo_beta - temperature, default 5000
  • rlhf_dpo_label_smoothing - for noisy preference labels, default 0
  • rlhf_dpo_ref_mode - NEW_ADAPTER or EXISTING_ADAPTER, auto-derived from whether a base adapter is loaded
  • rlhf_dpo_execution_mode - SEQUENTIAL, POLICY_CONCURRENT, or FULL_CONCURRENT
  • rlhf_supervised_mix - weight for supervised loss on chosen images, default 0
  • rlhf_dpo_shared_noise - share noise between chosen/rejected, default true
  • rlhf_dpo_validation - enable DPO-specific validation
  • rlhf_dpo_validation_percentage - train/val split percentage, default 10
  • rlhf_dpo_patience_enabled - enable early stopping
  • rlhf_dpo_patience_value - number of stalled validations before stopping
  • rlhf_dpo_patience_mode - EITHER or BOTH

rlhf_dpo_ref_mode is resynced from effective_dpo_ref_mode() on every config save/load, so it stays aligned with whether a base adapter is loaded regardless of what's in the JSON.

Testing

Unit tests covered the basics - enum contents, config round-tripping through JSON, loss behaviour including the beta=0 producing log(2) sanity check, all three execution modes returning finite loss and metrics, no-grad validation metrics, concept type semantics, and patience logic for both EITHER and BOTH modes.

The heavier coverage, which runs all six combinations of training type x execution mode (NEW_ADAPTER and EXISTING_ADAPTER x SEQUENTIAL, POLICY_CONCURRENT, FULL_CONCURRENT) using tiny fake tensors and a dummy adapter model on CPU. It checked finite loss and backward, adapter gradients, all five training metrics, reference snapshot integrity under Existing Adapter mode, TensorBoard event tags, and hot-swap correctness. No CUDA required.

Constraints

DPO is adapter-only in the LoRA path. The output is always an adapter file. Training type is derived, not user-selected. Pairing requires explicit chosen/rejected concepts with matching filenames, prompts, and crop resolutions. Augmentations are disabled, not synchronised. Validation uses DPO metrics directly. There's no reward model, no external reference path, and no full-parameter DPO.

Code Touchpoints

The implementation lives across:

  • BaseModelSetup.py - loss function, reference model context manager, metric caching
  • GenericTrainer.py - DPO loss branching, validation, patience, TensorBoard logging
  • TrainConfig.py - config fields, effective mode derivation, serialisation
  • DPORefMode.py, DPOExecutionMode.py, ConceptType.py - enums
  • dpo_curation_util.py - export, pair checking, shared utilities
  • PairByFilename.py - filename-based sample pairing
  • DataLoaderMgdsMixin.py - augmentation suppression, concept type filtering
  • DataLoaderText2ImageMixin.py - DPO output module construction, pair resolution
  • RLHFTab.py - settings UI
  • DPOCurationWindow.py - pair curation tool
  • ConceptWindow.py - DPO concept type handling in the editor

@yamatazen
Copy link
Copy Markdown

Any example images?

Images now scale up to fill the display box (not just down), and
prompts are shown in a collapsible expander instead of truncated.
- Fix failing export test to match 7-value return signature
- Add ValueError for unsupported DPO reference modes
- Remove dead hasattr check in reference_model
- Close PIL file handles promptly in DPO curation window
- Read only first 256KB for metadata extraction with fallback
- Add config migration 14 for transfer_* fields
- Add DPO loss math integration tests including beta=0 sanity check
SwarmUI embeds metadata in WebP EXIF UserComment as UTF-16LE
(UNICODE character code), causing prompt extraction to fail.
Fall back to UTF-16LE/BE decoding when UTF-8 finds no metadata.
Output folder is now selected at start. Each pick exports the
chosen/rejected pair to disk instantly and updates a manifest.json,
so closing mid-session loses no work. On restart, completed groups
are auto-skipped via manifest lookup. Groups are shuffled for
variety. Finalization (train/val split + concepts.json) is a
separate optional step at the end.
Metadata scan and dhash dedup now run in a background thread,
feeding a Queue(maxsize=10) of ready groups. The UI shows a live
scan counter and transitions to picking as soon as the first group
is ready. If the user outruns the dedup, a brief "Preparing..."
spinner appears. Groups already completed in the manifest are
skipped before deduping, saving the most expensive work.
After accepting a chosen/rejected pair, if there are still 2+
remaining images in the group, a Yes/No/Cancel dialog lets the
user continue scoring the same prompt or move on. Works in both
ELO and Selection modes. Cancel acts as an undo for mis-clicks.
@Silvicultor
Copy link
Copy Markdown

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata

Very interesting PR! I wanna test it, but right now ComfyUI metadata does not seem to work. If I select an input folder with ComfyUI-generated images (from a very basic SDXL workflow), I get "no images with extractable prompt metadata found".
Would be nice if either (or ideally both) ComfyUI or Forge metadata could be supported.

@BitcrushedHeart
Copy link
Copy Markdown
Contributor Author

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata

Very interesting PR! I wanna test it, but right now ComfyUI metadata does not seem to work. If I select an input folder with ComfyUI-generated images (from a very basic SDXL workflow), I get "no images with extractable prompt metadata found". Would be nice if either (or ideally both) ComfyUI or Forge metadata could be supported.

If you can send me an image on Discord from ComfyUI (any image, it doesn't matter), I can add this for you now and push it. If you have any from Forge as well I'll do the same - I don't have either of these so I don't know what they expect.

Add save-best option that snapshots model weights when validation accuracy
improves and restores them at end of training. Simplify patience to track
accuracy only (remove DPOPatienceMode). Add review mode to the DPO pair
tool with orphan detection and pair removal. Default execution mode is now
Full Concurrent. Polish curation UI with card layout and progress bar.
@BitcrushedHeart BitcrushedHeart force-pushed the RLHF branch 2 times, most recently from 5b55d90 to e1b8b09 Compare April 5, 2026 17:14
Parse plain-text A1111/Forge parameters (positive prompt before
"Negative prompt:" marker) and ComfyUI workflow JSON (trace KSampler
positive conditioning to CLIPTextEncode node, with fallback to longest
text encoder output). Handles SDXL/SD3/Flux text encoder variants.
@BitcrushedHeart
Copy link
Copy Markdown
Contributor Author

It reads prompt and aspect ratio metadata from image files - PNG text chunks for SwarmUI/ComfyUI metadata

Very interesting PR! I wanna test it, but right now ComfyUI metadata does not seem to work. If I select an input folder with ComfyUI-generated images (from a very basic SDXL workflow), I get "no images with extractable prompt metadata found". Would be nice if either (or ideally both) ComfyUI or Forge metadata could be supported.

Forge & Comfy now supported in latest commit.

@Silvicultor
Copy link
Copy Markdown

I did some testing with the PR:
(1) There is a bug in the UI: Currently only the ELO mode in the DPO pair tool in functional on Linux. In the direct selection mode you can’t view the images fully, if you try, the whole screen becomes white and OT Python process needs to be terminated to regain control.
Error log:

Exception in Tkinter callback
Traceback (most recent call last):
  File "/usr/lib/python3.12/tkinter/__init__.py", line 1967, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/venv/lib/python3.12/site-packages/customtkinter/windows/widgets/ctk_button.py", line 554, in _clicked
    self._command()
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/TrainUI.py", line 699, in open_dpo_curation_tool
    DPOCurationWindow(self)
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 75, in __init__
    self.grab_set()
  File "/usr/lib/python3.12/tkinter/__init__.py", line 963, in grab_set
    self.tk.call('grab', 'set', self._w)
_tkinter.TclError: grab failed: window not viewable
Exception in Tkinter callback
Traceback (most recent call last):
  File "/usr/lib/python3.12/tkinter/__init__.py", line 1967, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 542, in <lambda>
    label.bind("<Button-1>", lambda e, p=path: self._selection_preview(p))
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 552, in _selection_preview
    preview.grab_set()
  File "/usr/lib/python3.12/tkinter/__init__.py", line 963, in grab_set
    self.tk.call('grab', 'set', self._w)
_tkinter.TclError: grab failed: window not viewable

(2) Test training with 15 pairs (NoobAI vpred) worked without error. I see minimal improvement of my concept in sample images. Tho the dpo/val_accuracy graph does not really reflect it. But I guess can’t expect that with literally only 1 val pair.

Patience now uses DPOPatienceMode (EITHER/BOTH) and tracks val_loss
alongside accuracy with rounding to 5 decimal places. Add visual
pair review accessible from RLHF tab that merges train/val splits
and supports orphan detection. Fix grab_set crash on Linux by adding
wait_visibility before grab.
@BitcrushedHeart
Copy link
Copy Markdown
Contributor Author

I did some testing with the PR: (1) There is a bug in the UI: Currently only the ELO mode in the DPO pair tool in functional on Linux. In the direct selection mode you can’t view the images fully, if you try, the whole screen becomes white and OT Python process needs to be terminated to regain control. Error log:

Exception in Tkinter callback
Traceback (most recent call last):
  File "/usr/lib/python3.12/tkinter/__init__.py", line 1967, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/venv/lib/python3.12/site-packages/customtkinter/windows/widgets/ctk_button.py", line 554, in _clicked
    self._command()
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/TrainUI.py", line 699, in open_dpo_curation_tool
    DPOCurationWindow(self)
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 75, in __init__
    self.grab_set()
  File "/usr/lib/python3.12/tkinter/__init__.py", line 963, in grab_set
    self.tk.call('grab', 'set', self._w)
_tkinter.TclError: grab failed: window not viewable
Exception in Tkinter callback
Traceback (most recent call last):
  File "/usr/lib/python3.12/tkinter/__init__.py", line 1967, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 542, in <lambda>
    label.bind("<Button-1>", lambda e, p=path: self._selection_preview(p))
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/AI/RLHF_TEST_OT/OneTrainer/modules/ui/DPOCurationWindow.py", line 552, in _selection_preview
    preview.grab_set()
  File "/usr/lib/python3.12/tkinter/__init__.py", line 963, in grab_set
    self.tk.call('grab', 'set', self._w)
_tkinter.TclError: grab failed: window not viewable

(2) Test training with 15 pairs (NoobAI vpred) worked without error. I see minimal improvement of my concept in sample images. Tho the dpo/val_accuracy graph does not really reflect it. But I guess can’t expect that with literally only 1 val pair.

Fixed in 26e3b7a.

Validation accuracy is binary - with 1 validation pair it is literally a coin flip. Val loss would be a better metric for you in this case.

Images with no prompt metadata are now grouped under 'UNCONDITIONAL'
instead of being silently skipped. These groups have no pair limit
and are always shown for selection. Marked with a gold badge in the UI.
@Silvicultor
Copy link
Copy Markdown

@BitcrushedHeart I know you disabled all augmentations for the RLHF concepts, but did you also disable multi-line caption alternatives (the fact that a newline is considered a new caption by OT)?
If not, it might be a good idea to do so, otherwise a single newline in a longer prompt for a model that uses nat. lang. might sabotage the whole training.

- Replace PIL Image.open with raw PNG chunk parsing for metadata
  extraction — skips image data entirely, ~100x faster on large sets
- Detect and offer to fix multiline captions via Check Pairs button
- Auto-prune manifest entries whose files were deleted from disk
Strip <...> segments (e.g. <segment:face,0.5,0.6>) and clean up
resulting grammar artifacts when building prompt groups, so images
whose prompts differ only by angle-bracket metadata land in the
same curation group. Also normalizes manifest lookups for backward
compatibility with older entries.
When two images share a dhash, prefer the newest file by mtime rather
than the first path in sorted order.
The per-step DPO accuracy reported to TensorBoard was computed from only
the final micro-batch in the gradient-accumulation window, because
BaseModelSetup._last_dpo_metrics is overwritten on each calculate_dpo_loss
call. With batch_size=1 and gradient_accumulation_steps=N, that meant
each logged accuracy was mean() over a tensor of size 1 — always 0.0 or
1.0, regardless of effective batch size.

Accumulate loss/dpo_loss/chosen_reward/rejected_reward/accuracy over all
micro-batches in the grad-accum window in GenericTrainer and log the
mean once per optimizer step, alongside accumulated_loss. Validation is
unchanged; it already averages across batches explicitly.
…ional check

The DPO curation tool let the same image land on both sides of a pair
when a session was resumed against an existing manifest — most
visibly for unconditional groups, which have no pairs-per-group cap to
skip them on resume. Pairs now record their source paths, and the
background scan plus next-group pull filter out any image that already
appears as chosen or rejected in the manifest.

Prompt normalization moves into a shared `normalize_prompt_for_grouping`
helper that iterates bracket stripping until stable (so nested or
malformed brackets like `<a<b>>` collapse fully) and treats prompts whose
remaining content has no alphanumeric character as UNCONDITIONAL — stray
punctuation left behind by stripping no longer counts as guidance.

Adds a defensive guard in selection so picking the same image as both
best and worst is rejected rather than silently exported.

Test updated: manifest_pair_counts now keys on the normalized prompt,
so the assertion that previously expected the raw bracketed key was
already failing on this branch and is corrected to the normalized key.
…ks it

Most non-SwarmUI generators (ComfyUI, A1111/Forge, raw exports) don't
write an `aspectratio` metadata field. The DPO scan was using `meta.get
("aspectratio", "").strip()` directly, so every such image collapsed
into the same empty-AR bucket — and unconditional pairs then mixed
images of different shapes, producing tensor-shape mismatches at train
time.

`resolve_aspect_ratio` now falls back to PIL `Image.open(...).size`
when metadata is missing, reducing W:H by gcd so 1024x768 and 4096x3072
both bucket as `4:3`. PIL only reads the file header so the per-file
cost is negligible compared to the existing dhash dedup pass.
dhash was collapsing near-duplicate generations from different seeds,
which is exactly the discrimination signal DPO needs to learn — the
*minor* differences are the point of the pair. Replace it with a
BLAKE3 hash over raw file bytes so only literal duplicates (re-runs
that wrote the same file twice, manual copies) get dropped; anything
with even a one-byte difference now survives into the group.

BLAKE3 on raw bytes is also dramatically faster than dhash, which had
to PIL-decode and resize every image. update_mmap streams via memory
map; an OSError/ValueError fallback path handles empty files which
mmap can't open.

Adds blake3==1.0.8 to requirements-global.txt.
@dxqb dxqb self-requested a review May 10, 2026 17:07
Pre-commit ruff (v0.15.7) flags the dict comprehension
{k: 0.0 for k in micro_dpo_metrics} as C420 ("unnecessary dict
comprehension for iterable"). Replace with dict.fromkeys(...) to
unblock pre-commit.ci on PR Nerogar#1403.
@dxqb
Copy link
Copy Markdown
Collaborator

dxqb commented May 13, 2026

Thanks! I have a few questions mainly about the infrastructure that this PR adds, less so about DPO itself because it has been proven useful and we should clearly have it.

Image pairing

The positive and negative samples are two new concept types in your PR. That's one way to do that - the downside of that is that samples must be matched later - and that is done by filename (PairByFilename).

OneTrainer already has other training methods that need 2 images per sample:

  • masked training needs a mask
  • inpainting training needs a source image

And it could use such infrastructure for:

  • DPO
  • Edit training needs a reference image: Flux2 Edit training #1301
  • edit training could use multiple reference images
  • other RLHF techniques can use 2 images or 2 captions per sample, such as "image sliders"
  • ...

The current infrastructure for pairing these samples is to have a sample.png, and then a sample-condlabel.png or a sample-masklabel.png.
This isn't necessarily great, but it's a start that we already have - and I don't see the advantage of having 2 separate concepts and then matching them later by filename anyway. Why can't it be 1 concept with 2 files? If I have to prepare my datasets to have matching filenames anyway, why don't I prepare it into sample.png and sample-rejectedlabel.png?

Having these hardcoded labels isn't great of course. I was thinking about improving that by making it configurable in the concept as a pattern using https://pypi.org/project/parse/
For example, you could give the patterns {}.png and {}-dpolabel.png for the current way, or change the patterns to positive/{}.png and negative/{}.png if you want your pairs in two different folders.
This would be quite flexible without having to define 2 concepts. What do you think?

Another open question is how scaling and augmentations should work then. How does this work in your PR currently, do you require both images to be the same resolution / aspect ratio?
The current infrastructure for masks and inpainting does require it, but doesn't enfore it. it does inappropriate scaling otherwise. For some techniques I guess enforcing the same resolution is right, but some edit models such as Flux2 support that a reference image has a different resolution and aspect ratio than the training target.

Validation

Could you explain the value you found in having DPO validation loss?
Because I don't see it in theory. DPO loss starts at about 0.7 because of the logsigmoid loss function and then goes towards 0. Depending on your DPO beta, it learns appropriately and goes slowly towards 0 but it eventually reaches 0, when the difference between the positive vs the negative sample (or the model loss vs. the ref loss) has grown big enough. If your DPO beta isn't well-tuned, it reward-hacks and reaches a loss of 0 very quickly. Samples have degraded image quality then.

How is a DPO-validation-loss a useful value in this? You see training progress already in your train loss. Having a low validation loss doesn't mean much because of reward hacking, does it?

early stopping and patience follows from validation I think, so I'll have a look at that later.

Forward pass scheduling

Off the cuff, I'd say we only need Policy Concurrent. This is also how PRIOR_PREDICTION already works.
Could you explain how the other two are useful?

TensorBoard Metrics

Fully agree that this is useful, especially to see the chosen and rejected reward, for tuning DPO beta and LR.

Shared Noise

Again from theory only, I don't see how this can be useful.
By sampling 2 different timesteps for the chosen and rejected samples, you are comparing different tasks: Is this banana better than this car?
Could you explain how this is useful?

@dxqb dxqb removed their request for review May 13, 2026 22:15
calculate_dpo_loss previously ran 4 sequential forwards per step (ref_chosen,
ref_rejected, policy_chosen, policy_rejected) regardless of execution mode --
the three modes only varied when output tensors were freed and did no actual
concurrency. This change makes the modes a real speed knob:

- SEQUENTIAL: 4 passes, unchanged (lowest VRAM)
- POLICY_CONCURRENT: 3 passes (2 ref serial + 1 batched policy)
- FULL_CONCURRENT: 2 passes (1 batched ref + 1 batched policy)

Two new static helpers on BaseModelSetup merge chosen/rejected into a single
[2B, ...] batch by concatenating on dim 0, then split the model output back
into chosen/rejected halves. predict() in every model setup already reads
batch['latent_image'].shape[0] dynamically, so no per-model changes are needed.

Batching is gated on rlhf_dpo_shared_noise (the default). With shared noise
off, chosen and rejected use different global_step seeds for timestep/noise
and cannot share a forward, so all modes fall back to the original 4-pass
implementation. Loss math and DPO metrics are unchanged.

Execution Mode tooltip in RLHFTab.py updated to describe the new tiers
honestly.
@dxqb dxqb marked this pull request as draft May 24, 2026 14:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants