Skip to content

Commit be5985b

Browse files
leisuzztcaimm
andcommitted
Bugfix for dreambooth flux2 img2img2
Co-authored-by: tcaimm <93749364+tcaimm@users.noreply.github.com>
1 parent 86da067 commit be5985b

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,7 +1654,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16541654
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
16551655
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
16561656

1657-
noisy_len = packed_noisy_model_input.shape[1]
1657+
orig_input_shape = packed_noisy_model_input.shape
1658+
orig_input_ids_shape = model_input_ids.shape
16581659

16591660
# concatenate the model inputs with the cond inputs
16601661
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
@@ -1674,8 +1675,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16741675
img_ids=model_input_ids, # B, image_seq_len, 4
16751676
return_dict=False,
16761677
)[0]
1677-
model_pred = model_pred[:, :noisy_len:]
1678-
model_input_ids = model_input_ids[:, :noisy_len:]
1678+
model_pred = model_pred[:, : orig_input_shape[1], :]
1679+
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
16791680

16801681
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
16811682

0 commit comments

Comments
 (0)