Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions lightx2v_train/configs/train/dmd/flux2_klein_dmd_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Model selection, checkpoint location and runtime precision.
# Child keys are nested under `model`; the flattened form left them as top-level keys.
model:
  name: flux2_klein
  pretrained_model_name_or_path: /path/to/FLUX.2-klein-4B
  max_sequence_length: 1024
  running_dtype: bf16

# Dataset definitions. `train` and `val` are separate sub-mappings so that their
# identically named keys (`name`, `num_workers`, `shuffle`, `data_path`) do not collide.
data:
  train:
    name: image_dataset
    num_workers: 8
    prompt_dropout_rate: 0.0
    target_area: 1048576 # 1024 * 1024
    shuffle: true
    data_path:
      - /path/to/LightX2V_train_data_examples/dataset_v1/train.jsonl
  # The validation split sets no `prompt_dropout_rate` or `target_area`.
  val:
    name: image_dataset
    num_workers: 8
    shuffle: false
    data_path:
      - /path/to/LightX2V_train_data_examples/dataset_v1/val.jsonl

# Noise-scheduler settings.
scheduler:
  num_train_timesteps: 1000
  # NOTE(review): `dynamic_shift` and `time_shift_mu` are placed under
  # `time_shift_settings` based on their names; confirm against the scheduler code.
  time_shift_settings:
    do_time_shift: true
    shift_type: exponential
    time_shift_power: 1.0
    dynamic_shift: false
    time_shift_mu: 3.0

# Training section: `method: dmd` with `train_type: lora`.
# NOTE(review): indentation was flattened in this view, so the parent of each key
# below cannot be read off the text. `inference.infer_every_iters` interpolates
# `${training.save_every_iters}`, which places `save_every_iters` directly under
# `training`. Whether `lora`, `student`, `fake` and `teacher` sit under `dmd` or
# directly under `training` must be confirmed against the config loader before
# this section is re-indented.
training:
method: dmd
train_type: lora
max_train_iters: 1000
gradient_accumulation_iters: 1
gradient_checkpointing: true
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_iters: 10
save_every_iters: 100
save_total_limit: 10
# DMD-specific options.
dmd:
# Same value as `inference.num_inference_steps` (4).
num_inference_steps: 4
fake_update_ratio: 2
# Three size pairs. NOTE(review): whether a pair is (height, width) or
# (width, height) is not visible here — confirm against the consumer.
image_sizes:
- [1024, 1024]
- [768, 1344]
- [1344, 768]
renoise_sigma_min: 0.02
renoise_sigma_max: 1.0
renoise_discrete_samples: 1000
renoise_shift: 5.0
# LoRA adapter settings; `target_modules` lists the module names to adapt.
lora:
rank: 32
alpha: 32
target_modules:
- to_k
- to_q
- to_v
- to_out.0
# Optimizer settings for `student`.
student:
optimizer:
learning_rate: 0.0001
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.001
adam_epsilon: 0.00000001
# Optimizer settings for `fake`; identical to `student` except for the
# lower learning rate (0.00002 vs 0.0001).
fake:
optimizer:
learning_rate: 0.00002
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.001
adam_epsilon: 0.00000001
# Teacher settings; `guidance_scale` equals `inference.cfg_guidance_scale` (4.0).
teacher:
guidance_scale: 4.0
negative_prompt: " "
cfg_norm: layer_norm
output_dir: ./output_train/flux2_klein_dmd_lora

# Inference settings used during training runs.
inference:
  method: image_infer
  default_width: 1024
  default_height: 1024
  num_inference_steps: 4
  # NOTE(review): `enable_cfg` is false while a guidance scale and negative prompt
  # are still set; presumably those two are unused in that case — confirm.
  cfg_guidance_scale: 4.0
  negative_prompt: " "
  enable_cfg: false
  output_dir: ./output_infer/flux2_klein_dmd_lora
  # Reuses the checkpoint interval by interpolating `training.save_every_iters`.
  infer_every_iters: ${training.save_every_iters}

# Logging cadence.
# NOTE(review): one key counts "iters" and the other "steps"; confirm both names
# match what the logger reads before renaming either.
logging:
  rank_zero_only: true
  train_log_every_iters: 10
  infer_log_every_steps: 10

# Checkpoint resume behaviour.
resume:
  auto_resume: true
3 changes: 3 additions & 0 deletions lightx2v_train/lightx2v_train/model_zoo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(self, config):
def load_components(self, transformer_only=False, reference_model=None):
    """Load the model's components; concrete models must override this.

    Args:
        transformer_only: Flag whose handling is left to the overriding class.
        reference_model: Optional model object whose handling is left to the
            overriding class.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

def dmd_latent_shape(self, batch_size, height, width):
    """Return the latent shape used for DMD; concrete models must override this.

    NOTE(review): the review comment on this hunk reports that the same method
    is already defined further down in this class. That part of the file is not
    visible here — confirm and keep a single definition.

    Args:
        batch_size: Requested batch size.
        height: Requested height.
        width: Requested width.

    Raises:
        NotImplementedError: Always; the message names the concrete class.
    """
    raise NotImplementedError(f"{self.__class__.__name__} must define dmd_latent_shape().")
Comment on lines 29 to +31

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The method dmd_latent_shape is already defined at lines 169-170 in this file. Adding it here at lines 30-31 introduces a duplicate method definition. Please remove this duplicate definition to keep the class clean and maintainable.


def denoiser_module(self):
    """Return the denoiser module; concrete models must override this.

    Raises:
        NotImplementedError: Always; the message names the concrete class.
    """
    raise NotImplementedError(f"{self.__class__.__name__} must define denoiser_module().")

Expand Down
2 changes: 1 addition & 1 deletion lightx2v_train/lightx2v_train/model_zoo/flux2_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def encode_prompt_condition(self, prompt):
)
return {"prompt_embed": prompt_embed, "text_ids": text_ids}

def prepare_denoiser_input(self, noisy_latent):
def prepare_denoiser_input(self, noisy_latent, condition=None):
h, w = noisy_latent.shape[2], noisy_latent.shape[3]
packed = Flux2Pipeline._pack_latents(noisy_latent)
img_ids = Flux2Pipeline._prepare_latent_ids(noisy_latent).to(self.device)
Expand Down
29 changes: 25 additions & 4 deletions lightx2v_train/lightx2v_train/model_zoo/flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ class Flux2KleinDenoiserInput:
class Flux2KleinModel(BaseModel):
pipeline_cls = Flux2KleinPipeline

def load_components(self):
def load_components(self, transformer_only=False, reference_model=None):
if transformer_only:
if reference_model is not None:
self.text_pipeline = reference_model.text_pipeline
self.vae = reference_model.vae
self.image_processor = reference_model.image_processor
self.transformer = self.load_transformer()
return
model_path = self.config["model"]["pretrained_model_name_or_path"]
self.text_pipeline = Flux2KleinPipeline.from_pretrained(
model_path,
Expand Down Expand Up @@ -81,7 +88,11 @@ def encode_to_latent(self, sample):
latent = self.vae.encode(image).latent_dist.sample()
return self._normalize_patch_latents(latent)

def encode_prompt_text(self, prompt):
def encode_condition(self, sample):
    """Encode the conditioning for one sample from its text prompt.

    Args:
        sample: Mapping that must contain a ``"prompt"`` entry.

    Returns:
        Whatever ``self.encode_prompt_condition`` returns for that prompt.

    Raises:
        KeyError: If ``sample`` has no ``"prompt"`` key.
    """
    prompt = sample["prompt"]
    return self.encode_prompt_condition(prompt)

def encode_prompt_condition(self, prompt):
model_config = self.config["model"]
prompt_embed, text_ids = self.text_pipeline.encode_prompt(
prompt=prompt,
Expand All @@ -92,8 +103,18 @@ def encode_prompt_text(self, prompt):
)
return {"prompt_embed": prompt_embed, "text_ids": text_ids}

def encode_condition(self, sample):
return self.encode_prompt_text(sample["prompt"])
def encode_prompt_text(self, prompt):
    """Encode a prompt by delegating to ``encode_prompt_condition``.

    Args:
        prompt: Prompt value passed through unchanged.

    Returns:
        The result of ``self.encode_prompt_condition(prompt)``.
    """
    return self.encode_prompt_condition(prompt)

def dmd_latent_shape(self, batch_size, height, width):
    """Return the 4-tuple latent shape for a requested height and width.

    Args:
        batch_size: First element of the returned shape, passed through as-is.
        height: Requested height; converted with ``int()`` before dividing.
        width: Requested width; converted with ``int()`` before dividing.

    Returns:
        ``(batch_size, self.transformer.config.in_channels, h, w)`` where
        ``h`` and ``w`` are the height and width floor-divided by
        ``self.vae_scale_factor * 2``.
    """
    # The previous form computed 2 * (x // stride) and then floor-divided the
    # result by 2 again; for integers that is exactly x // stride.
    stride = self.vae_scale_factor * 2
    return (
        batch_size,
        self.transformer.config.in_channels,
        int(height) // stride,
        int(width) // stride,
    )

def prepare_denoiser_input(self, noisy_latent, condition=None):
h, w = noisy_latent.shape[2], noisy_latent.shape[3]
Expand Down
Loading