Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,3 +783,10 @@ class ZImagePixelSpace(ChromaRadiance):
No VAE encoding/decoding — the model operates directly on RGB pixels.
"""
pass

class CogVideoX(LatentFormat):
    """Latent space description for CogVideoX video models.

    16-channel, 3-dimensional (frames, height, width) latents produced by
    the CogVideoX VAE.
    """
    latent_channels = 16
    latent_dimensions = 3  # temporal + 2 spatial dims

    def __init__(self):
        # Multiplier applied to latents before they enter the diffusion model.
        # NOTE(review): presumably 1 / the VAE scaling_factor from the
        # reference implementation — confirm against the upstream config.
        self.scale_factor = 1.15258426
Empty file.
571 changes: 571 additions & 0 deletions comfy/ldm/cogvideo/model.py

Large diffs are not rendered by default.

570 changes: 570 additions & 0 deletions comfy/ldm/cogvideo/vae.py

Large diffs are not rendered by default.

485 changes: 485 additions & 0 deletions comfy/ldm/cogvideo/vae_backup.py

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4

import comfy.model_management
Expand Down Expand Up @@ -79,6 +80,7 @@ class ModelType(Enum):
IMG_TO_IMG = 9
FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
V_PREDICTION_DDPM = 12


def model_sampling(model_config, model_type):
Expand Down Expand Up @@ -113,6 +115,8 @@ def model_sampling(model_config, model_type):
s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
elif model_type == ModelType.V_PREDICTION_DDPM:
c = comfy.model_sampling.V_PREDICTION_DDPM

class ModelSampling(s, c):
pass
Expand Down Expand Up @@ -1962,3 +1966,59 @@ def concat_cond(self, **kwargs):
class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)

class CogVideoX(BaseModel):
    """BaseModel wrapper around the CogVideoX video diffusion transformer."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
        # I2V checkpoints concatenate a reference-image latent onto the noise.
        self.image_to_video = image_to_video

def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
# Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
extra_channels = self.diffusion_model.in_channels - noise.shape[1]
if extra_channels == 0:
return None

image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]

if image is None:
shape = list(noise.shape)
shape[1] = extra_channels
return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)

latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

if noise.ndim == 5 and image.ndim == 5:
if image.shape[-3] < noise.shape[-3]:
image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
elif image.shape[-3] > noise.shape[-3]:
image = image[:, :, :noise.shape[-3]]

for i in range(0, image.shape[1], latent_dim):
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])

if image.shape[1] > extra_channels:
image = image[:, :extra_channels]
elif image.shape[1] < extra_channels:
repeats = extra_channels // image.shape[1]
remainder = extra_channels % image.shape[1]
parts = [image] * repeats
if remainder > 0:
parts.append(image[:, :remainder])
image = torch.cat(parts, dim=1)

return image
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
# OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
if self.diffusion_model.ofs_proj_dim is not None:
ofs = kwargs.get("ofs", None)
if ofs is None:
noise = kwargs.get("noise", None)
ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
out['ofs'] = comfy.conds.CONDRegular(ofs)
Comment on lines +2018 to +2023
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Convert caller-provided ofs values to a tensor before wrapping them.

The default path creates a tensor, but Line 2019 accepts raw kwargs unchanged. If a workflow passes ofs as a Python float/int, CONDRegular later treats it like a tensor during batching/concat and that override path breaks.

Proposed fix
         if self.diffusion_model.ofs_proj_dim is not None:
+            noise = kwargs.get("noise", None)
             ofs = kwargs.get("ofs", None)
             if ofs is None:
-                noise = kwargs.get("noise", None)
                 ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
+            else:
+                ofs = torch.as_tensor(ofs, device=noise.device, dtype=noise.dtype).reshape(-1)
             out['ofs'] = comfy.conds.CONDRegular(ofs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/model_base.py` around lines 2018 - 2023, The caller-provided ofs from
kwargs may be a scalar (int/float) and is passed unchanged to
comfy.conds.CONDRegular, causing later tensor operations to fail; in the branch
inside the if self.diffusion_model.ofs_proj_dim is not None block, coerce/of
convert kwargs.get("ofs") into a torch tensor matching the batch
size/device/dtype of noise (like using torch.full or torch.tensor with
device=noise.device and dtype=noise.dtype and shape (noise.shape[0],)), before
assigning out['ofs'] = comfy.conds.CONDRegular(ofs); ensure you still handle
None by falling back to the existing default creation path and preserve
dtype/device from noise.

return out
48 changes: 48 additions & 0 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

return dit_config

if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
dit_config = {}
dit_config["image_model"] = "cogvideox"

# Extract config from weight shapes
norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
time_embed_dim = norm1_weight.shape[1]
dim = norm1_weight.shape[0] // 6

dit_config["num_attention_heads"] = dim // 64
dit_config["attention_head_dim"] = 64
dit_config["time_embed_dim"] = time_embed_dim
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')

# Detect in_channels from patch_embed
patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
if patch_proj_key in state_dict_keys:
w = state_dict[patch_proj_key]
if w.ndim == 4:
# Conv2d: [out, in, kh, kw] — CogVideoX 1.0
dit_config["in_channels"] = w.shape[1]
dit_config["patch_size"] = w.shape[2]
elif w.ndim == 2:
# Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
dit_config["patch_size"] = 2
dit_config["patch_size_t"] = 2
dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32

text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
if text_proj_key in state_dict_keys:
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]

# Detect OFS embedding
ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
if ofs_key in state_dict_keys:
dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]

# Detect positional embedding type
pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
if pos_key in state_dict_keys:
dit_config["use_learned_positional_embeddings"] = True
dit_config["use_rotary_positional_embeddings"] = False
else:
dit_config["use_learned_positional_embeddings"] = False
dit_config["use_rotary_positional_embeddings"] = True
Comment on lines +531 to +537
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don’t infer RoPE from a missing patch_embed.pos_embedding key.

This fallback conflates two different cases. In comfy/ldm/cogvideo/model.py, CogVideoXPatchEmbed only persists pos_embedding when use_learned_positional_embeddings is true, so fixed sin/cos checkpoints also omit that key. With the current else, those 1.0-style checkpoints get loaded as use_rotary_positional_embeddings=True, which changes token positions at inference.

A safer split here is:

  • pos_embedding present → learned positional embeddings
  • temporal patching / 1.5-style layout detected → rotary positional embeddings
  • otherwise → fixed sin/cos (use_rotary_positional_embeddings=False, use_learned_positional_embeddings=False)
Suggested fix
         pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
         if pos_key in state_dict_keys:
             dit_config["use_learned_positional_embeddings"] = True
             dit_config["use_rotary_positional_embeddings"] = False
-        else:
+        elif dit_config.get("patch_size_t") is not None:
             dit_config["use_learned_positional_embeddings"] = False
             dit_config["use_rotary_positional_embeddings"] = True
+        else:
+            dit_config["use_learned_positional_embeddings"] = False
+            dit_config["use_rotary_positional_embeddings"] = False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/model_detection.py` around lines 531 - 537, The current logic treats a
missing pos_embedding as evidence to enable rotary embeddings; instead implement
a three-way decision: if pos_key
('{}patch_embed.pos_embedding'.format(key_prefix)) is present set
dit_config["use_learned_positional_embeddings"]=True and use_rotary=False; else
detect a temporal-patching / 1.5-style layout by scanning state_dict_keys for
known temporal/1.5 markers (e.g. any key under the same patch_embed prefix
containing 'temporal', 'time', or 'temporal_patch' or other
CogVideoXPatchEmbed-specific temporal names) and in that case set
use_rotary_positional_embeddings=True and
use_learned_positional_embeddings=False; otherwise set both
use_rotary_positional_embeddings=False and
use_learned_positional_embeddings=False. Reference pos_key, state_dict_keys,
dit_config and CogVideoXPatchEmbed when making the change.


return dit_config

if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
Expand Down
24 changes: 24 additions & 0 deletions comfy/model_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

class V_PREDICTION_DDPM:
    """CogVideoX v-prediction in variance-preserving (DDPM) space.

    The sampler-space latent IS the raw DDPM x_t (no rescaling in
    calculate_input), where, with sigma = sqrt((1 - alpha_bar) / alpha_bar):
        x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
            = (x_0 + sigma * eps) / sqrt(sigma^2 + 1)
    and the model predicts velocity v, so:
        x_0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v
            = x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
    """
    def calculate_input(self, sigma, noise):
        # The model consumes x_t directly; no rescaling here.
        return noise

    def calculate_denoised(self, sigma, model_output, model_input):
        # model_input is the DDPM x_t; recover x_0 from the predicted v.
        sigma = reshape_sigma(sigma, model_output.ndim)
        return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        # Construct x_t in normalized (variance-preserving) space. The
        # previous form returned latent + sigma * noise (K-diffusion space),
        # which is sqrt(1 + sigma^2) times larger than the x_t that
        # calculate_input() passes through unchanged, so high-sigma steps
        # were badly mis-scaled. The exact VP forward below equals
        # sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps and also covers
        # max_denoise (it tends to pure noise as sigma grows).
        sigma = reshape_sigma(sigma, noise.ndim)
        return (latent_image + noise * sigma) / (1.0 + sigma ** 2.0) ** 0.5

    def inverse_noise_scaling(self, sigma, latent):
        # At the final (near-zero) sigma, x_t ~= x_0: nothing to undo.
        return latent

class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
Expand Down
12 changes: 12 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
Expand Down Expand Up @@ -650,6 +651,17 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)

self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
        elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
            # 4x temporal compression with the first frame kept whole
            # (t latent frames -> 4*t - 3 pixel frames), 8x spatial.
            self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
            self.upscale_index_formula = (4, 8, 8)
            self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
            self.downscale_index_formula = (4, 8, 8)
            # Video latents: (t, h, w).
            self.latent_dim = 3
            # Encoder emits mean and logvar concatenated, so channels = out // 2.
            self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
            self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
            # NOTE(review): the 2800/1400 constants look empirically tuned like
            # the other video-VAE memory estimates — confirm against real runs.
            self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
            self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
            self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
Expand Down
52 changes: 51 additions & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,56 @@ def get_model(self, state_dict, prefix="", device=None):
def clip_target(self, state_dict={}):
return None

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
class CogVideoX_T2V(supported_models_base.BASE):
    """Model config for CogVideoX text-to-video checkpoints."""

    unet_config = {
        "image_model": "cogvideox",
    }

    # DDPM linear beta schedule; zsnr enables the zero-terminal-SNR handling.
    # NOTE(review): values presumably mirror the reference training config —
    # confirm against the upstream scheduler settings.
    sampling_settings = {
        "linear_start": 0.00085,
        "linear_end": 0.012,
        "beta_schedule": "linear",
        "zsnr": True,
    }

    unet_extra_config = {}
    latent_format = latent_formats.CogVideoX

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
        if self.unet_config.get("patch_size_t") is not None:
            self.unet_config.setdefault("sample_height", 96)
            self.unet_config.setdefault("sample_width", 170)
            self.unet_config.setdefault("sample_frames", 81)
        out = model_base.CogVideoX(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        # T5-XXL text encoder with prompts padded to at least 226 tokens.
        class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
            def __init__(self, embedding_directory=None, tokenizer_data={}):
                super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)

        return supported_models_base.ClipTarget(CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)

class CogVideoX_I2V(CogVideoX_T2V):
    """CogVideoX image-to-video variant: 32 input channels (noise + ref latent)."""

    unet_config = {
        "image_model": "cogvideox",
        "in_channels": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        # CogVideoX 1.5 checkpoints (temporal patching) need their RoPE base
        # dimensions seeded before model construction.
        if self.unet_config.get("patch_size_t") is not None:
            for key, value in (("sample_height", 96), ("sample_width", 170), ("sample_frames", 81)):
                self.unet_config.setdefault(key, value)
        return model_base.CogVideoX(self, image_to_video=True, device=device)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, CogVideoX_I2V, CogVideoX_T2V]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@comfyanonymous @Kosinkadink I think this list should be reformatted so that each element is on its own line. That way the diff stays clear — right now, with every change, the whole array is shown as modified.
Why do I mention this? When rebasing or merging with master, one could easily drop a model that was already added if not careful.
What do you think?


models += [SVD_img2vid]
2 changes: 1 addition & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,7 +2457,7 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
"nodes_rtdetr.py"
"nodes_rtdetr.py",
]

import_failed = []
Expand Down
Loading