CogVideoX #13351
Changes from all commits
3e4d15a
bf2c582
a16fc7e
63739c3
955ce48
8cf387b
ff916d8
845da1b
```diff
@@ -52,6 +52,7 @@
 import comfy.ldm.kandinsky5.model
 import comfy.ldm.anima.model
 import comfy.ldm.ace.ace_step15
+import comfy.ldm.cogvideo.model
 import comfy.ldm.rt_detr.rtdetr_v4

 import comfy.model_management
@@ -79,6 +80,7 @@ class ModelType(Enum):
     IMG_TO_IMG = 9
     FLOW_COSMOS = 10
     IMG_TO_IMG_FLOW = 11
+    V_PREDICTION_DDPM = 12


 def model_sampling(model_config, model_type):
@@ -113,6 +115,8 @@ def model_sampling(model_config, model_type):
         s = comfy.model_sampling.ModelSamplingCosmosRFlow
     elif model_type == ModelType.IMG_TO_IMG_FLOW:
         c = comfy.model_sampling.IMG_TO_IMG_FLOW
+    elif model_type == ModelType.V_PREDICTION_DDPM:
+        c = comfy.model_sampling.V_PREDICTION_DDPM

     class ModelSampling(s, c):
         pass
@@ -1962,3 +1966,59 @@ def concat_cond(self, **kwargs):
 class RT_DETR_v4(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
+
+
+class CogVideoX(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
+        self.image_to_video = image_to_video
+
+    def concat_cond(self, **kwargs):
+        noise = kwargs.get("noise", None)
+        # Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
+        extra_channels = self.diffusion_model.in_channels - noise.shape[1]
+        if extra_channels == 0:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        device = kwargs["device"]
+
+        if image is None:
+            shape = list(noise.shape)
+            shape[1] = extra_channels
+            return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+
+        latent_dim = self.latent_format.latent_channels
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        if noise.ndim == 5 and image.ndim == 5:
+            if image.shape[-3] < noise.shape[-3]:
+                image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
+            elif image.shape[-3] > noise.shape[-3]:
+                image = image[:, :, :noise.shape[-3]]
+
+        for i in range(0, image.shape[1], latent_dim):
+            image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        if image.shape[1] > extra_channels:
+            image = image[:, :extra_channels]
+        elif image.shape[1] < extra_channels:
+            repeats = extra_channels // image.shape[1]
+            remainder = extra_channels % image.shape[1]
+            parts = [image] * repeats
+            if remainder > 0:
+                parts.append(image[:, :remainder])
+            image = torch.cat(parts, dim=1)
+
+        return image
```
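A minimal standalone sketch (not PR code; the channel counts are made up) of the trim/repeat arithmetic at the end of `concat_cond`, which crops or tiles the conditioning latent until it has exactly the extra channels the transformer expects:

```python
import torch

def match_channels(image: torch.Tensor, extra_channels: int) -> torch.Tensor:
    # Same arithmetic as concat_cond above: trim when too wide,
    # tile full copies plus a partial remainder when too narrow.
    if image.shape[1] > extra_channels:
        return image[:, :extra_channels]
    if image.shape[1] < extra_channels:
        repeats = extra_channels // image.shape[1]
        remainder = extra_channels % image.shape[1]
        parts = [image] * repeats
        if remainder > 0:
            parts.append(image[:, :remainder])
        return torch.cat(parts, dim=1)
    return image

x = torch.randn(1, 16, 4, 8, 8)        # hypothetical 16-channel ref latent
print(match_channels(x, 32).shape[1])  # 32: two full copies
print(match_channels(x, 24).shape[1])  # 24: one full copy + 8-channel remainder
print(match_channels(x, 8).shape[1])   # 8: trimmed
```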
```diff
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        # OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
+        if self.diffusion_model.ofs_proj_dim is not None:
+            ofs = kwargs.get("ofs", None)
+            if ofs is None:
+                noise = kwargs.get("noise", None)
+                ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
+            out['ofs'] = comfy.conds.CONDRegular(ofs)
```
**Comment on lines +2018 to +2023**

Convert caller-provided `ofs` values to a tensor.

The default path creates a tensor, but the `ofs = kwargs.get("ofs", None)` line accepts the raw kwarg unchanged. If a workflow passes a plain Python float or int, `comfy.conds.CONDRegular` ends up wrapping a non-tensor value.

Proposed fix:

```diff
 if self.diffusion_model.ofs_proj_dim is not None:
+    noise = kwargs.get("noise", None)
     ofs = kwargs.get("ofs", None)
     if ofs is None:
-        noise = kwargs.get("noise", None)
         ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
+    else:
+        ofs = torch.as_tensor(ofs, device=noise.device, dtype=noise.dtype).reshape(-1)
     out['ofs'] = comfy.conds.CONDRegular(ofs)
```
```diff
+        return out
```
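For context on the proposed fix: `torch.as_tensor(...).reshape(-1)` maps a scalar, list, or tensor `ofs` to the same 1-D form that the `torch.full` default path produces. A small standalone illustration:

```python
import torch

dtype = torch.float32
for raw in (2.0, [2.0, 1.5], torch.tensor(2.0)):
    # Normalizes any of these inputs to a 1-D float tensor.
    ofs = torch.as_tensor(raw, dtype=dtype).reshape(-1)
    print(type(raw).__name__, tuple(ofs.shape), ofs.dtype)
# float (1,) torch.float32
# list (2,) torch.float32
# Tensor (1,) torch.float32
```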
```diff
@@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

         return dit_config

+    if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
+        dit_config = {}
+        dit_config["image_model"] = "cogvideox"
+
+        # Extract config from weight shapes
+        norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
+        time_embed_dim = norm1_weight.shape[1]
+        dim = norm1_weight.shape[0] // 6
+
+        dit_config["num_attention_heads"] = dim // 64
+        dit_config["attention_head_dim"] = 64
+        dit_config["time_embed_dim"] = time_embed_dim
+        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
+
+        # Detect in_channels from patch_embed
+        patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
+        if patch_proj_key in state_dict_keys:
+            w = state_dict[patch_proj_key]
+            if w.ndim == 4:
+                # Conv2d: [out, in, kh, kw] — CogVideoX 1.0
+                dit_config["in_channels"] = w.shape[1]
+                dit_config["patch_size"] = w.shape[2]
+            elif w.ndim == 2:
+                # Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
+                dit_config["patch_size"] = 2
+                dit_config["patch_size_t"] = 2
+                dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2)  # 256 // 8 = 32
+
+        text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
+        if text_proj_key in state_dict_keys:
+            dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]
+
+        # Detect OFS embedding
+        ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
+        if ofs_key in state_dict_keys:
+            dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]
+
+        # Detect positional embedding type
+        pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
+        if pos_key in state_dict_keys:
+            dit_config["use_learned_positional_embeddings"] = True
+            dit_config["use_rotary_positional_embeddings"] = False
+        else:
+            dit_config["use_learned_positional_embeddings"] = False
+            dit_config["use_rotary_positional_embeddings"] = True
```
**Comment on lines +531 to +537**

Don't infer RoPE from a missing `pos_embedding` key.

This fallback conflates two different cases: a checkpoint that genuinely uses rotary positional embeddings, and one whose positional embeddings simply were not serialized into the state dict. A safer split here is to enable RoPE only when another detected signal supports it (the `patch_size_t` set for CogVideoX 1.5 above), and to disable both flags otherwise.

Suggested fix:

```diff
 pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
 if pos_key in state_dict_keys:
     dit_config["use_learned_positional_embeddings"] = True
     dit_config["use_rotary_positional_embeddings"] = False
-else:
+elif dit_config.get("patch_size_t") is not None:
     dit_config["use_learned_positional_embeddings"] = False
     dit_config["use_rotary_positional_embeddings"] = True
+else:
+    dit_config["use_learned_positional_embeddings"] = False
+    dit_config["use_rotary_positional_embeddings"] = False
```
```diff
+
+        return dit_config
+
     if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
         dit_config = {}
         dit_config["image_model"] = "wan2.1"
```
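As an illustration of the shape-based detection above, here is a standalone sketch with hypothetical tensor shapes (dim=3072, time_embed_dim=512, roughly CogVideoX-5B scale); `blocks.0.norm1.linear.weight` alone pins down the model width, head count, and time embedding size:

```python
import torch

# Hypothetical checkpoint: the AdaLN norm1 projects time_embed_dim -> 6 * dim,
# so its weight has shape [6 * dim, time_embed_dim].
state_dict = {"blocks.0.norm1.linear.weight": torch.empty(6 * 3072, 512)}

norm1_weight = state_dict["blocks.0.norm1.linear.weight"]
time_embed_dim = norm1_weight.shape[1]       # 512
dim = norm1_weight.shape[0] // 6             # 3072: six modulation chunks
num_attention_heads = dim // 64              # 48, assuming 64-dim heads as in the PR
print(num_attention_heads, time_embed_dim)   # 48 512
```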
```diff
@@ -54,6 +54,30 @@ def calculate_denoised(self, sigma, model_output, model_input):
         sigma = reshape_sigma(sigma, model_output.ndim)
         return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

+
+class V_PREDICTION_DDPM:
+    """CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
+    x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
+        = x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
+    """
+    def calculate_input(self, sigma, noise):
+        return noise
+
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = reshape_sigma(sigma, model_output.ndim)
+        return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5
+
+    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        sigma = reshape_sigma(sigma, noise.ndim)
+        if max_denoise:
+            noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
+        else:
+            noise = noise * sigma
+        noise += latent_image
+        return noise
```
**Comment on lines +62 to +76**

Construct the DDPM `x_t` with its 1/sqrt(sigma^2 + 1) scaling in `noise_scaling`.

Suggested fix:

```diff
 def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
     sigma = reshape_sigma(sigma, noise.ndim)
-    if max_denoise:
-        noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
-    else:
-        noise = noise * sigma
-    noise += latent_image
-    return noise
+    return (latent_image + noise * sigma) / torch.sqrt(1.0 + sigma ** 2.0)
```

Because `calculate_input` feeds the latent to the model unscaled and `calculate_denoised` assumes x_t = sqrt(alpha) * x_0 + sqrt(1-alpha) * noise = (x_0 + sigma * noise) / sqrt(sigma^2 + 1), `noise_scaling` has to apply the same 1/sqrt(sigma^2 + 1) factor; without it the sampler hands the model a latent whose scale does not match the v-prediction formulas in the class docstring.
```diff
+
+    def inverse_noise_scaling(self, sigma, latent):
+        return latent
+
+
 class EDM(V_PREDICTION):
     def calculate_denoised(self, sigma, model_output, model_input):
         sigma = reshape_sigma(sigma, model_output.ndim)
```
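As a numerical sanity check on the docstring math (an illustrative test, not part of the PR): with the reviewer's corrected `noise_scaling`, scaling the ground-truth velocity target back through `calculate_denoised` recovers `x_0` exactly, since sqrt(alpha) = 1/sqrt(sigma^2 + 1) and sqrt(1 - alpha) = sigma/sqrt(sigma^2 + 1):

```python
import torch

sigma = torch.tensor(0.7)
x0 = torch.randn(4)
eps = torch.randn(4)

s = (sigma ** 2 + 1.0) ** 0.5       # 1 / sqrt(alpha)
x_t = (x0 + sigma * eps) / s        # corrected noise_scaling: sqrt(alpha)*x0 + sqrt(1-alpha)*eps
v = (eps - sigma * x0) / s          # velocity target: sqrt(alpha)*eps - sqrt(1-alpha)*x0
denoised = x_t / s - v * sigma / s  # calculate_denoised from the PR
assert torch.allclose(denoised, x0, atol=1e-6)
print("v-prediction round trip OK")
```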
```diff
@@ -1749,6 +1749,56 @@ def get_model(self, state_dict, prefix="", device=None):

     def clip_target(self, state_dict={}):
         return None

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
+class CogVideoX_T2V(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "cogvideox",
+    }
+
+    sampling_settings = {
+        "linear_start": 0.00085,
+        "linear_end": 0.012,
+        "beta_schedule": "linear",
+        "zsnr": True,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.CogVideoX
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
+        if self.unet_config.get("patch_size_t") is not None:
+            self.unet_config.setdefault("sample_height", 96)
+            self.unet_config.setdefault("sample_width", 170)
+            self.unet_config.setdefault("sample_frames", 81)
+        out = model_base.CogVideoX(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
+            def __init__(self, embedding_directory=None, tokenizer_data={}):
+                super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
+
+        return supported_models_base.ClipTarget(CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)
+
+
+class CogVideoX_I2V(CogVideoX_T2V):
+    unet_config = {
+        "image_model": "cogvideox",
+        "in_channels": 32,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        if self.unet_config.get("patch_size_t") is not None:
+            self.unet_config.setdefault("sample_height", 96)
+            self.unet_config.setdefault("sample_width", 170)
+            self.unet_config.setdefault("sample_frames", 81)
+        out = model_base.CogVideoX(self, image_to_video=True, device=device)
+        return out
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, CogVideoX_I2V, CogVideoX_T2V]
```
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @comfyanonymous @Kosinkadink I think that this list should be formatted in a way that each element goes into a separate line. In this way the diff is very clear. Right now with every change, the whole array is shown as changed. |
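The layout being proposed would look something like this (illustrative and heavily truncated; the trailing comments are an observation about detection order, not part of the PR):

```python
models = [
    LotusD,
    Stable_Zero123,
    # ... one class per line ...
    RT_DETR_v4,
    CogVideoX_I2V,  # kept ahead of CogVideoX_T2V, presumably so the more
    CogVideoX_T2V,  # specific I2V unet_config (in_channels=32) matches first
]
```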
```diff

 models += [SVD_img2vid]
```