Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f5ff304
Implement RLHF DPO (Direct Preference Optimization) training
BitcrushedHeart Mar 29, 2026
cc0f749
DPO Pair Tool: fix image scaling and add expandable prompt display
BitcrushedHeart Apr 1, 2026
3c76232
fix: address code review issues in RLHF DPO implementation
BitcrushedHeart Apr 2, 2026
0effbef
fix: handle UTF-16LE metadata in WebP files for DPO prompt matching
BitcrushedHeart Apr 3, 2026
9f57d0c
feat: DPO pair tool saves picks immediately with resume support
BitcrushedHeart Apr 4, 2026
32cdf47
feat: background scan + prefetch queue for DPO pair tool
BitcrushedHeart Apr 4, 2026
708acad
feat: option to keep scoring remaining images after accepting a pair
BitcrushedHeart Apr 4, 2026
8a458ee
feat: DPO best checkpoint, pair review mode, and UI improvements
BitcrushedHeart Apr 5, 2026
46ad292
feat: support Forge/A1111 and ComfyUI PNG metadata in DPO curation
BitcrushedHeart Apr 5, 2026
ec33568
fix: flatten multi-line prompts in DPO caption export
BitcrushedHeart Apr 5, 2026
26e3b7a
feat: DPO patience tiebreaker, pair review window, and focus fixes
BitcrushedHeart Apr 6, 2026
4aeb2fb
Handle empty prompts as UNCONDITIONAL in DPO Curation Tool
BitcrushedHeart Apr 9, 2026
750619e
DPO tool: fast PNG scanning, multiline caption fix, orphan pruning
BitcrushedHeart Apr 12, 2026
cd21f59
DPO curation: normalize prompts before grouping
BitcrushedHeart Apr 14, 2026
ac42d25
Change default RLHF DPO beta from 5000 to 300
BitcrushedHeart Apr 14, 2026
613505a
fix(dpo): dedup keeps latest mtime instead of first seen
BitcrushedHeart Apr 19, 2026
ad42b57
feat(dpo): add bucket/batch-size analyzer window
BitcrushedHeart Apr 24, 2026
068b74b
style(dpo): split long import in RLHFTab to satisfy ruff
BitcrushedHeart Apr 24, 2026
4fb4815
fix(dpo): average accuracy across grad-accum window
BitcrushedHeart Apr 24, 2026
a1111c9
fix(dpo): filter already-paired sources from groups; tighten uncondit…
BitcrushedHeart May 1, 2026
fbd079e
fix(dpo): derive aspect ratio from image dimensions when metadata lac…
BitcrushedHeart May 1, 2026
d4eea79
feat(dpo): switch dedup from perceptual dhash to BLAKE3 content hash
BitcrushedHeart May 1, 2026
e566131
style(dpo): use dict.fromkeys for accumulator init (C420)
BitcrushedHeart May 10, 2026
669b64b
perf(rlhf): batch chosen+rejected forwards to cut DPO from 4 to 2 passes
BitcrushedHeart May 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modules/dataLoader/BaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
config = copy.copy(config)
config.batch_size = 1
config.multi_gpu = False
config.rlhf_enabled = config.rlhf_dpo_validation and config.rlhf_enabled

self.__ds = self._create_dataset(
config=config,
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/ChromaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _cache_modules(self, config: TrainConfig, model: ChromaModel, model_setup: B
text_caching = not config.train_text_encoder_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: ChromaModel, model_setup: BaseChromaSetup):
def _output_modules(self, config: TrainConfig, model: ChromaModel, model_setup: BaseChromaSetup, is_validation: bool = False):
pad_masked_tokens = PadMaskedTokens(tokens_name='tokens', tokens_mask_name='tokens_mask', hidden_state_name='text_encoder_hidden_state', max_length=model.tokenizer.model_max_length)

output_names = [
Expand All @@ -108,6 +108,7 @@ def _output_modules(self, config: TrainConfig, model: ChromaModel, model_setup:
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

if config.latent_caching and not config.train_text_encoder_or_embedding():
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/Flux2BaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _cache_modules(self, config: TrainConfig, model: Flux2Model, model_setup: Ba
text_caching=True,
)

def _output_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup):
def _output_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt',
Expand All @@ -115,6 +115,7 @@ def _output_modules(self, config: TrainConfig, model: Flux2Model, model_setup: B
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: Flux2Model):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/FluxBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _cache_modules(self, config: TrainConfig, model: FluxModel, model_setup: Bas
text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: FluxModel, model_setup: BaseFluxSetup):
def _output_modules(self, config: TrainConfig, model: FluxModel, model_setup: BaseFluxSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt_1', 'prompt_2',
Expand Down Expand Up @@ -135,6 +135,7 @@ def _output_modules(self, config: TrainConfig, model: FluxModel, model_setup: Ba
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: FluxModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/HiDreamBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _cache_modules(self, config: TrainConfig, model: HiDreamModel, model_setup:
or not config.train_text_encoder_4_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: HiDreamModel, model_setup: BaseHiDreamSetup):
def _output_modules(self, config: TrainConfig, model: HiDreamModel, model_setup: BaseHiDreamSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt_1', 'prompt_2', 'prompt_3', 'prompt_4',
Expand Down Expand Up @@ -173,6 +173,7 @@ def _output_modules(self, config: TrainConfig, model: HiDreamModel, model_setup:
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: HiDreamModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/HunyuanVideoBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _cache_modules(self, config: TrainConfig, model: HunyuanVideoModel, model_se
text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: HunyuanVideoModel, model_setup: BaseHunyuanVideoSetup):
def _output_modules(self, config: TrainConfig, model: HunyuanVideoModel, model_setup: BaseHunyuanVideoSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt_1', 'prompt_2',
Expand Down Expand Up @@ -129,6 +129,7 @@ def _output_modules(self, config: TrainConfig, model: HunyuanVideoModel, model_s
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: HunyuanVideoModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/PixArtAlphaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _cache_modules(self, config: TrainConfig, model: PixArtAlphaModel, model_set
text_caching=not config.train_text_encoder_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: PixArtAlphaModel, model_setup: BasePixArtAlphaSetup):
def _output_modules(self, config: TrainConfig, model: PixArtAlphaModel, model_setup: BasePixArtAlphaSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt',
Expand All @@ -113,6 +113,7 @@ def _output_modules(self, config: TrainConfig, model: PixArtAlphaModel, model_se
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: PixArtAlphaModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/QwenBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _cache_modules(self, config: TrainConfig, model: QwenModel, model_setup: Bas
text_caching=not config.train_text_encoder_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: QwenModel, model_setup: BaseQwenSetup):
def _output_modules(self, config: TrainConfig, model: QwenModel, model_setup: BaseQwenSetup, is_validation: bool = False):
pad_masked_tokens = PadMaskedTokens(tokens_name='tokens', tokens_mask_name='tokens_mask', hidden_state_name='text_encoder_hidden_state', max_length=PROMPT_MAX_LENGTH)

output_names = [
Expand All @@ -112,6 +112,7 @@ def _output_modules(self, config: TrainConfig, model: QwenModel, model_setup: Ba
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

if config.latent_caching and not config.train_text_encoder_or_embedding():
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/SanaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _cache_modules(self, config: TrainConfig, model: SanaModel, model_setup: Bas
text_caching=not config.train_text_encoder_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: SanaModel, model_setup: BaseSanaSetup):
def _output_modules(self, config: TrainConfig, model: SanaModel, model_setup: BaseSanaSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt',
Expand All @@ -106,6 +106,7 @@ def _output_modules(self, config: TrainConfig, model: SanaModel, model_setup: Ba
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: SanaModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/StableDiffusion3BaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _cache_modules(self, config: TrainConfig, model: StableDiffusion3Model, mode
text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding() or not config.train_text_encoder_3_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: StableDiffusion3Model, model_setup: BaseStableDiffusion3Setup):
def _output_modules(self, config: TrainConfig, model: StableDiffusion3Model, model_setup: BaseStableDiffusion3Setup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt_1', 'prompt_2', 'prompt_3',
Expand Down Expand Up @@ -152,6 +152,7 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusion3Model, mod
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: StableDiffusion3Model):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/StableDiffusionBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _cache_modules(self, config: TrainConfig, model: StableDiffusionModel, model
text_caching=not config.train_text_encoder_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: StableDiffusionModel, model_setup: BaseStableDiffusionSetup):
def _output_modules(self, config: TrainConfig, model: StableDiffusionModel, model_setup: BaseStableDiffusionSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt',
Expand All @@ -116,6 +116,7 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusionModel, mode
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: StableDiffusionModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/StableDiffusionXLBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _cache_modules(self, config: TrainConfig, model: StableDiffusionXLModel, mod
text_caching=not config.train_text_encoder_or_embedding() or not config.train_text_encoder_2_or_embedding(),
)

def _output_modules(self, config: TrainConfig, model: StableDiffusionXLModel, model_setup: BaseStableDiffusionXLSetup):
def _output_modules(self, config: TrainConfig, model: StableDiffusionXLModel, model_setup: BaseStableDiffusionXLSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt_1', 'prompt_2',
Expand Down Expand Up @@ -130,6 +130,7 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusionXLModel, mo
vae=model.vae,
autocast_context=[model.autocast_context, model.vae_autocast_context],
train_dtype=model.vae_train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: StableDiffusionXLModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/WuerstchenBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def before_cache_image_fun():
before_cache_image_fun=before_cache_image_fun
)

def _output_modules(self, config: TrainConfig, model: WuerstchenModel, model_setup: BaseWuerstchenSetup):
def _output_modules(self, config: TrainConfig, model: WuerstchenModel, model_setup: BaseWuerstchenSetup, is_validation: bool = False):
output_names = [
'image_path', 'latent_image',
'prompt',
Expand Down Expand Up @@ -119,6 +119,7 @@ def before_cache_image_fun():
before_cache_image_fun=before_cache_image_fun,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

def _debug_modules(self, config: TrainConfig, model: WuerstchenModel):
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/ZImageBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _cache_modules(self, config: TrainConfig, model: ZImageModel, model_setup: B
text_caching=True,
)

def _output_modules(self, config: TrainConfig, model: ZImageModel, model_setup: BaseZImageSetup):
def _output_modules(self, config: TrainConfig, model: ZImageModel, model_setup: BaseZImageSetup, is_validation: bool = False):
pad_masked_tokens = PadMaskedTokens(tokens_name='tokens', tokens_mask_name='tokens_mask', hidden_state_name='text_encoder_hidden_state', max_length=PROMPT_MAX_LENGTH)

output_names = [
Expand All @@ -107,6 +107,7 @@ def _output_modules(self, config: TrainConfig, model: ZImageModel, model_setup:
vae=model.vae,
autocast_context=[model.autocast_context],
train_dtype=model.train_dtype,
is_validation=is_validation,
)

if config.latent_caching:
Expand Down
118 changes: 118 additions & 0 deletions modules/dataLoader/dpo/PairByFilename.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os

from modules.util.dpo_curation_util import dpo_pair_key

from mgds.PipelineModule import PipelineModule
from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule

import torch


class PairByFilename(
PipelineModule,
RandomAccessPipelineModule,
):
def __init__(
self,
concept_pairs: list[tuple[str, str]],
chosen_names: list[str | tuple[str, str]],
rejected_names: list[str | tuple[str, str]],
):
super().__init__()

self.chosen_names = [x if isinstance(x, tuple) else (x, x) for x in chosen_names]
self.rejected_names = [x if isinstance(x, tuple) else (x, x) for x in rejected_names]

self.concept_lookup = {}
for pair_id, (chosen_path, rejected_path) in enumerate(concept_pairs):
self.concept_lookup[self.__canonical_path(chosen_path)] = (pair_id, True)
self.concept_lookup[self.__canonical_path(rejected_path)] = (pair_id, False)

self._pair_indices: list[tuple[int, int]] | None = None

@staticmethod
def __canonical_path(path: str) -> str:
return os.path.normcase(os.path.abspath(path))

def __build_pair_indices(self):
chosen_indices = {}
rejected_indices = {}

for index in range(self._get_previous_length('image_path')):
concept_path = self.__canonical_path(self._get_previous_item(0, 'concept.path', index))
pair_info = self.concept_lookup.get(concept_path)
if pair_info is None:
continue

pair_id, is_chosen = pair_info
image_path = self._get_previous_item(0, 'image_path', index)
key = (pair_id, dpo_pair_key(image_path, concept_path))

if is_chosen:
chosen_indices[key] = index
else:
rejected_indices[key] = index

pair_indices = []
missing_rejected = sorted(set(chosen_indices) - set(rejected_indices))
missing_chosen = sorted(set(rejected_indices) - set(chosen_indices))
if missing_rejected or missing_chosen:
details = []
if missing_rejected:
details.append(f"{len(missing_rejected)} chosen files are missing rejected matches")
if missing_chosen:
details.append(f"{len(missing_chosen)} rejected files are missing chosen matches")
raise RuntimeError("RLHF DPO concept pairs must match exactly by filename: " + ", ".join(details) + ".")

for key, chosen_index in chosen_indices.items():
rejected_index = rejected_indices.get(key)
if rejected_index is not None:
pair_indices.append((chosen_index, rejected_index))

pair_indices.sort(key=lambda x: x[0])
if not pair_indices:
raise RuntimeError("No DPO pairs could be matched by filename between the configured chosen/rejected concepts.")

self._pair_indices = pair_indices

def __get_pair_indices(self) -> list[tuple[int, int]]:
if self._pair_indices is None:
self.__build_pair_indices()
return self._pair_indices

def length(self) -> int:
return len(self.__get_pair_indices())

def get_inputs(self) -> list[str]:
names = ['concept.path', 'image_path', 'prompt', 'crop_resolution']
names += [in_name for in_name, _ in self.chosen_names]
names += [in_name for in_name, _ in self.rejected_names]
return list(dict.fromkeys(names))

def get_outputs(self) -> list[str]:
return [out_name for _, out_name in self.chosen_names] + [out_name for _, out_name in self.rejected_names]

def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
chosen_index, rejected_index = self.__get_pair_indices()[index]

chosen_prompt = self._get_previous_item(variation, 'prompt', chosen_index)
rejected_prompt = self._get_previous_item(variation, 'prompt', rejected_index)
if chosen_prompt != rejected_prompt:
raise RuntimeError("RLHF DPO paired samples must use identical prompts/captions in chosen and rejected concepts.")

chosen_crop_resolution = self._get_previous_item(variation, 'crop_resolution', chosen_index)
rejected_crop_resolution = self._get_previous_item(variation, 'crop_resolution', rejected_index)
if isinstance(chosen_crop_resolution, torch.Tensor) and isinstance(rejected_crop_resolution, torch.Tensor):
same_resolution = torch.equal(chosen_crop_resolution, rejected_crop_resolution)
else:
same_resolution = chosen_crop_resolution == rejected_crop_resolution
if not same_resolution:
raise RuntimeError("RLHF DPO paired samples must have matching crop resolutions in chosen and rejected concepts.")

item = {}
for in_name, out_name in self.chosen_names:
item[out_name] = self._get_previous_item(variation, in_name, chosen_index)
for in_name, out_name in self.rejected_names:
item[out_name] = self._get_previous_item(variation, in_name, rejected_index)

return item
Loading