Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ def get_qwen_prompt_embeds_edit(
).to(device)

outputs = text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs.get("input_ids"),
attention_mask=model_inputs.get("attention_mask"),
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reasoning behind this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When image=None is passed to the Qwen2VLProcessor (for example, when encoding the negative_prompt), the processor returns a BatchFeature object that does not contain the pixel_values and image_grid_thw keys.

Attempting to access them directly via dot notation (e.g. model_inputs.pixel_values) raises an AttributeError from the underlying transformers.utils.logging.FeatureExtractionUtils mapping class. Using .get() safely defaults to None, which prevents the pipeline from crashing during negative prompt generations. I also updated input_ids and attention_mask to use .get() here to maintain a consistent access pattern within the same function call.

output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down Expand Up @@ -173,15 +173,15 @@ def get_qwen_prompt_embeds_edit_plus(
return_tensors="pt",
).to(device)
outputs = text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs.get("input_ids"),
attention_mask=model_inputs.get("attention_mask"),
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down
4 changes: 1 addition & 3 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we handle negative_prompt_embeds_mask? 👀 Is this block affected?

if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask.all():
prompt_embeds_mask = None

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If negative_prompt_embeds_mask isn't provided (is None), that block is safely skipped, and encode_prompt just returns None for the mask. The transformer model later receives this None mask and handles it natively by treating all tokens as valid.

Interestingly, while working on this, I found that some other variants (like edit and inpaint) were actually missing this exact if prompt_embeds_mask is not None: check and were crashing on .repeat(). I've added the same check to those pipelines in this PR as well so they all handle None masks gracefully now.


if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,13 @@ def encode_prompt(
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask.all():
prompt_embeds_mask = None
Comment on lines +308 to +314
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like an unrelated change? Since it already has

# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt

if we run make fix-copies, the changes would be propagated automatically.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll revert the manual edits on the copied methods, run make fix-copies to sync them up properly, and update the PR.


return prompt_embeds, prompt_embeds_mask

Expand Down Expand Up @@ -353,15 +355,6 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and prompt_embeds_mask is None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reasoning behind this?

What happens when users pass the embeds and not the masks? Maybe we should warn them?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the hard ValueError here because the base QwenImagePipeline actually allows passing embeds without masks, and the underlying transformer natively handles a None mask by treating all tokens as valid. Throwing a strict exception here was completely blocking users from doing exactly what this PR is fixing (passing negative_prompt_embeds without a mask to trigger True CFG).

That said, I totally agree we shouldn't just let it pass silently, especially since the text encoder's output often relies on masks for sequence packing. I'll swap these hard exceptions out for a logger.warning to let users know they should probably pass the mask if they have it. I'll get that updated!

raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)

if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

Expand Down Expand Up @@ -739,9 +732,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
Expand Down
33 changes: 12 additions & 21 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,15 @@ def _get_qwen_prompt_embeds(
).to(device)

outputs = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs.get("input_ids"),
attention_mask=model_inputs.get("attention_mask"),
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down Expand Up @@ -306,11 +306,13 @@ def encode_prompt(
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask.all():
prompt_embeds_mask = None

return prompt_embeds, prompt_embeds_mask

Expand Down Expand Up @@ -357,15 +359,6 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)

if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

Expand Down Expand Up @@ -705,9 +698,7 @@ def __call__(
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
image = image.unsqueeze(2)

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ def _get_qwen_prompt_embeds(
).to(device)

outputs = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs.get("input_ids"),
attention_mask=model_inputs.get("attention_mask"),
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down Expand Up @@ -318,11 +318,13 @@ def encode_prompt(
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask.all():
prompt_embeds_mask = None

return prompt_embeds, prompt_embeds_mask

Expand Down Expand Up @@ -878,9 +880,7 @@ def __call__(
)
image = image.to(dtype=torch.float32)

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
33 changes: 12 additions & 21 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,15 @@ def _get_qwen_prompt_embeds(
).to(device)

outputs = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
input_ids=model_inputs.get("input_ids"),
attention_mask=model_inputs.get("attention_mask"),
pixel_values=model_inputs.get("pixel_values"),
image_grid_thw=model_inputs.get("image_grid_thw"),
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
Expand Down Expand Up @@ -320,11 +320,13 @@ def encode_prompt(
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

if prompt_embeds_mask.all():
prompt_embeds_mask = None

return prompt_embeds, prompt_embeds_mask

Expand Down Expand Up @@ -372,15 +374,6 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)

if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

Expand Down Expand Up @@ -693,9 +686,7 @@ def __call__(
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,9 +677,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,9 +822,7 @@ def __call__(

device = self._execution_device

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
13 changes: 1 addition & 12 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,6 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and prompt_embeds_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
)

if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

Expand Down Expand Up @@ -697,9 +688,7 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
Expand Down
26 changes: 26 additions & 0 deletions tests/pipelines/qwenimage/test_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,29 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
expected_diff_max,
"VAE tiling should not affect the inference results",
)

def test_true_cfg_without_negative_prompt_embeds_mask(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pipe.to("cpu")
pipe.to(torch_device)

pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs("cpu")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inputs = self.get_dummy_inputs("cpu")
inputs = self.get_dummy_inputs(torch_device)

prompt = inputs.pop("prompt")

prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
prompt=prompt,
device="cpu",
num_images_per_prompt=1,
max_sequence_length=inputs.get("max_sequence_length", 16),
)

inputs["prompt_embeds"] = prompt_embeds
inputs["prompt_embeds_mask"] = prompt_embeds_mask
inputs["negative_prompt_embeds"] = prompt_embeds
inputs["negative_prompt"] = None
inputs["negative_prompt_embeds_mask"] = None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we not be using it because it's the core thing that we're fixing here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention here is explicitly to test the pipeline without providing the negative_prompt_embeds_mask (hence the test name test_true_cfg_without_negative_prompt_embeds_mask), to verify that our fix correctly catches the missing mask, issues the logger.warning, and successfully proceeds with CFG instead of throwing a hard ValueError or silently disabling CFG.

To make this intent clearer and avoid explicitly assigning None to the inputs dict, I will update the test to use inputs.pop(...) so that it relies directly on the pipeline's default None arguments.

inputs["true_cfg_scale"] = 2.0

image = pipe(**inputs).images
self.assertIsNotNone(image)
26 changes: 26 additions & 0 deletions tests/pipelines/qwenimage/test_qwenimage_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,29 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
expected_diff_max,
"VAE tiling should not affect the inference results",
)

def test_true_cfg_without_negative_prompt_embeds_mask(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs("cpu")
prompt = inputs.pop("prompt")

prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
prompt=prompt,
device="cpu",
num_images_per_prompt=1,
max_sequence_length=inputs.get("max_sequence_length", 16),
)

inputs["prompt_embeds"] = prompt_embeds
inputs["prompt_embeds_mask"] = prompt_embeds_mask
inputs["negative_prompt_embeds"] = prompt_embeds
inputs["negative_prompt"] = None
inputs["negative_prompt_embeds_mask"] = None
inputs["true_cfg_scale"] = 2.0

image = pipe(**inputs).images
self.assertIsNotNone(image)
Loading
Loading