Skip to content

Commit 71a865b

Browse files
adi776borateyiyixuxugithub-actions[bot]
authored
Fix: Cosmos2.5 Video2World frame extraction and add default negative prompt (#13018)
* fix: Extract last frames for conditioning in Cosmos Video2World * Added default negative prompt * Apply style fixes * Added default negative prompt in cosmos2 text2image pipeline --------- Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 53279ef commit 71a865b

File tree

5 files changed

+75
-10
lines changed

5 files changed

+75
-10
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ def __init__(self, *args, **kwargs):
5252

5353
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5454

55+
DEFAULT_NEGATIVE_PROMPT = (
56+
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
57+
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
58+
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
59+
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
60+
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
61+
"Overall, the video is of poor quality."
62+
)
63+
5564

5665
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
5766
def retrieve_latents(
@@ -359,7 +368,7 @@ def encode_prompt(
359368
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
360369

361370
if do_classifier_free_guidance and negative_prompt_embeds is None:
362-
negative_prompt = negative_prompt or ""
371+
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
363372
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
364373

365374
if prompt is not None and type(prompt) is not type(negative_prompt):
@@ -549,6 +558,7 @@ def __call__(
549558
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
550559
max_sequence_length: int = 512,
551560
conditional_frame_timestep: float = 0.1,
561+
num_latent_conditional_frames: int = 2,
552562
):
553563
r"""
554564
The call function to the pipeline for generation. Supports three modes:
@@ -614,6 +624,10 @@ def __call__(
614624
max_sequence_length (`int`, defaults to `512`):
615625
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
616626
the prompt is shorter than this length, it will be padded.
627+
num_latent_conditional_frames (`int`, defaults to `2`):
628+
Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames
629+
extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1
630+
for Image2World-like behavior (single frame conditioning).
617631
618632
Examples:
619633
@@ -692,19 +706,38 @@ def __call__(
692706
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
693707
num_frames_in = 0
694708
else:
695-
num_frames_in = len(video)
696-
697709
if batch_size != 1:
698710
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
699711

712+
if num_latent_conditional_frames not in [1, 2]:
713+
raise ValueError(
714+
f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}"
715+
)
716+
717+
frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1
718+
719+
total_input_frames = len(video)
720+
721+
if total_input_frames < frames_to_extract:
722+
raise ValueError(
723+
f"Input video has only {total_input_frames} frames but Video2World requires at least "
724+
f"{frames_to_extract} frames for conditioning."
725+
)
726+
727+
num_frames_in = frames_to_extract
728+
700729
assert video is not None
701730
video = self.video_processor.preprocess_video(video, height, width)
702731

703-
# pad with last frame (for video2world)
732+
# For Video2World: extract last frames_to_extract frames from input, then pad
733+
if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]:
734+
video = video[:, :, -num_frames_in:, :, :]
735+
704736
num_frames_out = num_frames
737+
705738
if video.shape[2] < num_frames_out:
706-
n_pad_frames = num_frames_out - num_frames_in
707-
last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W]
739+
n_pad_frames = num_frames_out - video.shape[2]
740+
last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W]
708741
pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W]
709742
video = torch.cat((video, pad_frames), dim=2)
710743

src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ def __init__(self, *args, **kwargs):
4949

5050
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5151

52+
DEFAULT_NEGATIVE_PROMPT = (
53+
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
54+
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
55+
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
56+
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
57+
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
58+
"Overall, the video is of poor quality."
59+
)
5260

5361
EXAMPLE_DOC_STRING = """
5462
Examples:
@@ -300,7 +308,7 @@ def encode_prompt(
300308
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
301309

302310
if do_classifier_free_guidance and negative_prompt_embeds is None:
303-
negative_prompt = negative_prompt or ""
311+
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
304312
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
305313

306314
if prompt is not None and type(prompt) is not type(negative_prompt):

src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def __init__(self, *args, **kwargs):
5050

5151
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5252

53+
DEFAULT_NEGATIVE_PROMPT = (
54+
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
55+
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
56+
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
57+
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
58+
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
59+
"Overall, the video is of poor quality."
60+
)
5361

5462
EXAMPLE_DOC_STRING = """
5563
Examples:
@@ -319,7 +327,7 @@ def encode_prompt(
319327
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
320328

321329
if do_classifier_free_guidance and negative_prompt_embeds is None:
322-
negative_prompt = negative_prompt or ""
330+
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
323331
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
324332

325333
if prompt is not None and type(prompt) is not type(negative_prompt):

src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ def __init__(self, *args, **kwargs):
4949

5050
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5151

52+
DEFAULT_NEGATIVE_PROMPT = (
53+
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
54+
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
55+
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
56+
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
57+
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
58+
"Overall, the video is of poor quality."
59+
)
5260

5361
EXAMPLE_DOC_STRING = """
5462
Examples:
@@ -285,7 +293,7 @@ def encode_prompt(
285293
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
286294

287295
if do_classifier_free_guidance and negative_prompt_embeds is None:
288-
negative_prompt = negative_prompt or ""
296+
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
289297
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
290298

291299
if prompt is not None and type(prompt) is not type(negative_prompt):

src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def __init__(self, *args, **kwargs):
5050

5151
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5252

53+
DEFAULT_NEGATIVE_PROMPT = (
54+
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
55+
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
56+
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
57+
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
58+
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
59+
"Overall, the video is of poor quality."
60+
)
5361

5462
EXAMPLE_DOC_STRING = """
5563
Examples:
@@ -331,7 +339,7 @@ def encode_prompt(
331339
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
332340

333341
if do_classifier_free_guidance and negative_prompt_embeds is None:
334-
negative_prompt = negative_prompt or ""
342+
negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT
335343
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
336344

337345
if prompt is not None and type(prompt) is not type(negative_prompt):

0 commit comments

Comments
 (0)