Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseChromaSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -196,6 +196,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseErnieSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -110,6 +110,7 @@ def predict(
scaled_latent_image.shape[0],
config,
shift=shift if config.dynamic_timestep_shifting else config.timestep_shift,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseFlux2Setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -119,6 +119,7 @@ def predict(
scaled_latent_image.shape[0],
config,
shift = shift if config.dynamic_timestep_shifting else config.timestep_shift,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseFluxSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -249,6 +249,7 @@ def predict(
scaled_latent_image.shape[0],
config,
shift = shift if config.dynamic_timestep_shifting else config.timestep_shift,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseHiDreamSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -348,6 +348,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseHunyuanVideoSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -243,6 +243,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BasePixArtAlphaSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -193,6 +193,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseQwenSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -114,6 +114,7 @@ def predict(
scaled_latent_image.shape[0],
config,
shift = shift if config.dynamic_timestep_shifting else config.timestep_shift,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseSanaSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -203,6 +203,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseStableDiffusion3Setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -300,6 +300,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseStableDiffusionSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -182,6 +182,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

latent_noise = self._create_noise(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseStableDiffusionXLSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand Down Expand Up @@ -233,6 +233,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

latent_noise = self._create_noise(
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseWuerstchenSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def predict(
elif model.model_type.is_stable_cascade():
scaled_latent_image = latent_image

batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand All @@ -222,6 +222,7 @@ def predict(
generator,
scaled_latent_image.shape[0],
config,
validation_override=batch.get("__val_timestep_unit__"),
)

if model.model_type.is_wuerstchen_v2():
Expand Down
3 changes: 2 additions & 1 deletion modules/modelSetup/BaseZImageSetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def predict(
deterministic: bool = False,
) -> dict:
with model.autocast_context:
batch_seed = 0 if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
batch_seed = int(batch.get("__val_noise_seed__", 0)) if deterministic else train_progress.global_step * multi.world_size() + multi.rank()
generator = torch.Generator(device=config.train_device)
generator.manual_seed(batch_seed)
rand = Random(batch_seed)
Expand All @@ -114,6 +114,7 @@ def predict(
scaled_latent_image.shape[0],
config,
shift = shift if config.dynamic_timestep_shifting else config.timestep_shift,
validation_override=batch.get("__val_timestep_unit__"),
)

scaled_noisy_latent_image, sigma = self._add_noise_discrete(
Expand Down
15 changes: 12 additions & 3 deletions modules/modelSetup/mixin/ModelSetupNoiseMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,21 @@ def _get_timestep_discrete(
batch_size: int,
config: TrainConfig,
shift: float = None,
validation_override: float | None = None,
) -> Tensor:
if shift is None:
shift = config.timestep_shift

if deterministic:
# -1 is for zero-based indexing
if validation_override is not None:
# Already-shifted unit position in [0, 1]; map to integer timestep.
t = int(validation_override * num_train_timesteps)
t = max(0, min(num_train_timesteps - 1, t))
else:
# -1 is for zero-based indexing
t = int(num_train_timesteps * 0.5) - 1
return torch.tensor(
int(num_train_timesteps * 0.5) - 1,
t,
dtype=torch.long,
device=generator.device,
).unsqueeze(0)
Expand Down Expand Up @@ -217,11 +224,13 @@ def _get_timestep_continuous(
generator: Generator,
batch_size: int,
config: TrainConfig,
validation_override: float | None = None,
) -> Tensor:
if deterministic:
fill_value = float(validation_override) if validation_override is not None else 0.5
return torch.full(
size=(batch_size,),
fill_value=0.5,
fill_value=fill_value,
device=generator.device,
)
else:
Expand Down
44 changes: 39 additions & 5 deletions modules/trainer/GenericTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from modules.util.bf16_stochastic_rounding import set_seed as bf16_stochastic_rounding_set_seed
from modules.util.callbacks.TrainCallbacks import TrainCallbacks
from modules.util.commands.TrainCommands import TrainCommands
from modules.util.config.ConceptConfig import ConceptConfig
from modules.util.config.SampleConfig import SampleConfig
from modules.util.config.TrainConfig import TrainConfig
from modules.util.dtype_util import create_grad_scaler, enable_grad_scaling
Expand All @@ -33,6 +34,11 @@
from modules.util.time_util import get_string_timestamp
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress
from modules.util.validation_timestep import (
apply_timestep_shift_unit,
stratified_unit_position,
validation_noise_seed,
)

import torch
from torch import Tensor, nn
Expand Down Expand Up @@ -364,20 +370,35 @@ def __validate(self, train_progress: TrainProgress):
mapping_seed_to_label = {}
mapping_label_to_seed = {}

for validation_batch in step_tqdm_validation:
concept_shift_by_seed = self.__validation_concept_shifts()
global_validation_shift = self.config.validation_timestep_shift
n_validation = current_epoch_length_validation

for i, validation_batch in enumerate(step_tqdm_validation):
if self.__needs_gc(train_progress):
torch_gc()

# since validation batch size = 1
concept_name = validation_batch["concept_name"][0]
concept_path = validation_batch["concept_path"][0]
concept_seed = validation_batch["concept_seed"].item()

shift_for_sample = concept_shift_by_seed.get(concept_seed)
if shift_for_sample is None:
shift_for_sample = global_validation_shift

pos = stratified_unit_position(i, n_validation)
validation_batch["__val_timestep_unit__"] = apply_timestep_shift_unit(
pos, shift_for_sample
)
validation_batch["__val_noise_seed__"] = validation_noise_seed(i)

with torch.no_grad():
model_output_data = self.model_setup.predict(
self.model, validation_batch, self.config, train_progress, deterministic=True)
loss_validation = self.model_setup.calculate_loss(
self.model, validation_batch, model_output_data, self.config)

# since validation batch size = 1
concept_name = validation_batch["concept_name"][0]
concept_path = validation_batch["concept_path"][0]
concept_seed = validation_batch["concept_seed"].item()
loss = loss_validation.item()

label = concept_name if concept_name else os.path.basename(concept_path)
Expand Down Expand Up @@ -413,6 +434,19 @@ def __validate(self, train_progress: TrainProgress):
total_average_loss,
train_progress.global_step)

def __validation_concept_shifts(self) -> dict[int, float]:
concepts = self.config.concepts
if concepts is None:
with open(self.config.concept_file_name, 'r') as f:
concepts = [ConceptConfig.default_values().from_dict(c) for c in json.load(f)]

return {
concept.seed: concept.validation_timestep_shift
for concept in concepts
if ConceptType(concept.type) == ConceptType.VALIDATION
and concept.validation_timestep_shift is not None
}

def __save_backup_config(self, backup_path):
config_path = os.path.join(backup_path, "onetrainer_config")
args_path = path_util.canonical_join(config_path, "args.json")
Expand Down
6 changes: 6 additions & 0 deletions modules/ui/ConceptWindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def set_prompt_path_entry_enabled(option: str):
tooltip="The loss multiplyer for this concept.")
components.entry(frame, 9, 1, self.ui_state, "loss_weight")

# validation timestep shift (only meaningful for VALIDATION concepts)
components.label(frame, 10, 0, "Validation Timestep Shift",
tooltip="Per-concept override for the global Validation Timestep Shift. Leave blank to inherit. Only used for VALIDATION concepts.",
wide_tooltip=True)
components.entry(frame, 10, 1, self.ui_state, "validation_timestep_shift")

frame.pack(fill="both", expand=1)
return frame

Expand Down
9 changes: 7 additions & 2 deletions modules/ui/TrainingTab.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,16 @@ def __create_noise_frame(self, master, row, supports_generalized_offset_noise: b
tooltip="Shift the timestep distribution. Use the preview to see more details.")
components.entry(frame, 8, 1, self.ui_state, "timestep_shift", required=True)

# validation timestep shift
components.label(frame, 9, 0, "Validation Timestep Shift",
tooltip="Shift the validation timestep distribution. Concepts can override this value individually.")
components.entry(frame, 9, 1, self.ui_state, "validation_timestep_shift", required=True)

if supports_dynamic_timestep_shifting:
# dynamic timestep shifting
components.label(frame, 9, 0, "Dynamic Timestep Shifting",
components.label(frame, 10, 0, "Dynamic Timestep Shifting",
tooltip="Dynamically shift the timestep distribution based on resolution. If enabled, the shifting parameters are taken from the model's scheduler configuration and Timestep Shift is ignored. Note: For Z-Image and Flux2, the dynamic shifting parameters are likely wrong and unknown. Use with care or set your own, fixed shift.", wide_tooltip=True)
components.switch(frame, 9, 1, self.ui_state, "dynamic_timestep_shifting")
components.switch(frame, 10, 1, self.ui_state, "dynamic_timestep_shifting")



Expand Down
Loading