Skip to content

Commit ef91301

Browse files
authored
[QwenImage] fix prompt isolation tests (#13042)
* up
* up
* up
* fix
1 parent 53d8a1e commit ef91301

File tree

7 files changed

+38
-74
lines changed

7 files changed

+38
-74
lines changed

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1467,7 +1467,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
14671467
else:
14681468
num_repeat_elements = len(prompts)
14691469
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
1470-
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
1470+
if prompt_embeds_mask is not None:
1471+
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
14711472
# Convert images to latent space
14721473
if args.cache_latents:
14731474
model_input = latents_cache[step].sample()

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -254,16 +254,17 @@ def encode_prompt(
254254
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
255255

256256
prompt_embeds = prompt_embeds[:, :max_sequence_length]
257-
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
258-
259257
_, seq_len, _ = prompt_embeds.shape
260258
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
261259
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
262-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
263-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
264260

265-
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
266-
prompt_embeds_mask = None
261+
if prompt_embeds_mask is not None:
262+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
263+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
264+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
265+
266+
if prompt_embeds_mask.all():
267+
prompt_embeds_mask = None
267268

268269
return prompt_embeds, prompt_embeds_mask
269270

@@ -310,15 +311,6 @@ def check_inputs(
310311
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
311312
)
312313

313-
if prompt_embeds is not None and prompt_embeds_mask is None:
314-
raise ValueError(
315-
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
316-
)
317-
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
318-
raise ValueError(
319-
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
320-
)
321-
322314
if max_sequence_length is not None and max_sequence_length > 1024:
323315
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
324316

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,13 @@ def encode_prompt(
321321
_, seq_len, _ = prompt_embeds.shape
322322
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
323323
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
324-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
325-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
326324

327-
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
328-
prompt_embeds_mask = None
325+
if prompt_embeds_mask is not None:
326+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
327+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
328+
329+
if prompt_embeds_mask.all():
330+
prompt_embeds_mask = None
329331

330332
return prompt_embeds, prompt_embeds_mask
331333

@@ -372,15 +374,6 @@ def check_inputs(
372374
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
373375
)
374376

375-
if prompt_embeds is not None and prompt_embeds_mask is None:
376-
raise ValueError(
377-
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
378-
)
379-
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
380-
raise ValueError(
381-
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
382-
)
383-
384377
if max_sequence_length is not None and max_sequence_length > 1024:
385378
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
386379

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,6 @@ def check_inputs(
378378
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
379379
)
380380

381-
if prompt_embeds is not None and prompt_embeds_mask is None:
382-
raise ValueError(
383-
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
384-
)
385-
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
386-
raise ValueError(
387-
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
388-
)
389381
if padding_mask_crop is not None:
390382
if not isinstance(image, PIL.Image.Image):
391383
raise ValueError(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
265265

266266
return timesteps, num_inference_steps - t_start
267267

268-
# Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
268+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
269269
def encode_prompt(
270270
self,
271271
prompt: Union[str, List[str]],
@@ -297,16 +297,17 @@ def encode_prompt(
297297
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
298298

299299
prompt_embeds = prompt_embeds[:, :max_sequence_length]
300-
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
301-
302300
_, seq_len, _ = prompt_embeds.shape
303301
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
304302
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
305-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
306-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
307303

308-
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
309-
prompt_embeds_mask = None
304+
if prompt_embeds_mask is not None:
305+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
306+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
307+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
308+
309+
if prompt_embeds_mask.all():
310+
prompt_embeds_mask = None
310311

311312
return prompt_embeds, prompt_embeds_mask
312313

@@ -357,15 +358,6 @@ def check_inputs(
357358
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
358359
)
359360

360-
if prompt_embeds is not None and prompt_embeds_mask is None:
361-
raise ValueError(
362-
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
363-
)
364-
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
365-
raise ValueError(
366-
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
367-
)
368-
369361
if max_sequence_length is not None and max_sequence_length > 1024:
370362
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
371363

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
276276

277277
return timesteps, num_inference_steps - t_start
278278

279-
# Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
279+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
280280
def encode_prompt(
281281
self,
282282
prompt: Union[str, List[str]],
@@ -308,16 +308,17 @@ def encode_prompt(
308308
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
309309

310310
prompt_embeds = prompt_embeds[:, :max_sequence_length]
311-
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
312-
313311
_, seq_len, _ = prompt_embeds.shape
314312
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
315313
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
316-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
317-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
318314

319-
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
320-
prompt_embeds_mask = None
315+
if prompt_embeds_mask is not None:
316+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
317+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
318+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
319+
320+
if prompt_embeds_mask.all():
321+
prompt_embeds_mask = None
321322

322323
return prompt_embeds, prompt_embeds_mask
323324

@@ -372,14 +373,6 @@ def check_inputs(
372373
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
373374
)
374375

375-
if prompt_embeds is not None and prompt_embeds_mask is None:
376-
raise ValueError(
377-
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
378-
)
379-
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
380-
raise ValueError(
381-
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
382-
)
383376
if padding_mask_crop is not None:
384377
if not isinstance(image, PIL.Image.Image):
385378
raise ValueError(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,16 +320,17 @@ def encode_prompt(
320320
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
321321

322322
prompt_embeds = prompt_embeds[:, :max_sequence_length]
323-
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
324-
325323
_, seq_len, _ = prompt_embeds.shape
326324
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
327325
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
328-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
329-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
330326

331-
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
332-
prompt_embeds_mask = None
327+
if prompt_embeds_mask is not None:
328+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
329+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
330+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
331+
332+
if prompt_embeds_mask.all():
333+
prompt_embeds_mask = None
333334

334335
return prompt_embeds, prompt_embeds_mask
335336

0 commit comments

Comments (0)