Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 95 additions & 3 deletions src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,104 @@
from omegaconf import DictConfig, ListConfig, OmegaConf
from ..utils.model_registry import MODEL_CLASSES

# Imports for dynamic VRAM patching (ComfyUI memory-management ops)
import comfy
import torch
import torch.nn as nn
from ..models.video_vae_v3.modules.causal_inflation_lib import InflatedCausalConv3d

# Register an "${eval:...}" resolver so config files can embed Python
# expressions. NOTE(review): `eval` executes arbitrary code — safe only for
# trusted, locally-authored configs; never feed untrusted YAML through this.
try:
    OmegaConf.register_new_resolver("eval", eval)
except Exception as e:
    # Re-registering (e.g. on module re-import) raises; tolerate only that
    # specific failure and re-raise anything else.
    if "already registered" not in str(e):
        raise

def _adopt_params(source, replacement):
    """
    Move ``source``'s parameters onto its ``replacement`` layer.

    Freshly constructed replacement layers (re)initialize their own
    parameters, which would silently discard any weights the original layer
    already holds. Re-using the original ``Parameter`` objects keeps them
    intact (and is a no-op cost: no copy is made).
    """
    for pname, param in source.named_parameters(recurse=False):
        setattr(replacement, pname, param)


def swap_layers_recursively(model, target_ops):
    """
    Recursively replaces standard torch.nn modules with custom ComfyUI
    operations from the provided target_ops class (e.g., manual_cast, fp8_ops).

    Only layer types that ``target_ops`` actually provides are swapped;
    anything else is left in place and recursed into.

    Args:
        model: ``torch.nn.Module`` patched in place.
        target_ops: Class/namespace providing drop-in layer replacements
            (``Linear``, ``Conv1d``/``Conv2d``/``Conv3d``, ``LayerNorm``,
            ``GroupNorm``, ``Embedding``).

    Returns:
        The same ``model`` instance, modified in place.
    """
    for name, child in model.named_children():
        new_layer = None

        # 1. Handle Linear Layers
        if isinstance(child, nn.Linear) and hasattr(target_ops, "Linear"):
            new_layer = target_ops.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )

        # 2. Handle Convolutional Layers (1D, 2D)
        elif isinstance(child, (nn.Conv1d, nn.Conv2d)):
            dim = 2 if isinstance(child, nn.Conv2d) else 1
            op_name = f"Conv{dim}d"

            if hasattr(target_ops, op_name):
                target_cls = getattr(target_ops, op_name)
                new_layer = target_cls(
                    child.in_channels,
                    child.out_channels,
                    child.kernel_size,
                    stride=child.stride,
                    padding=child.padding,
                    dilation=child.dilation,
                    groups=child.groups,
                    bias=child.bias is not None,
                    padding_mode=child.padding_mode,
                    device=child.weight.device,
                    dtype=child.weight.dtype,
                )

        # 2b. Handle InflatedCausalConv3d by re-basing its class onto the
        # target Conv3d op instead of rebuilding the module.
        elif isinstance(child, InflatedCausalConv3d) and hasattr(target_ops, "Conv3d"):
            # BUG FIX: the original code assigned ``child.__bases__`` on the
            # *instance*, which merely stores an inert instance attribute
            # (``__bases__`` lives on the class) — the op swap silently never
            # happened. Patch the class itself instead.
            # NOTE(review): this mutates InflatedCausalConv3d globally (all
            # instances, current and future) — confirm that is the intended
            # scope.
            type(child).__bases__ = (target_ops.Conv3d,)
            new_layer = child

        # 3. Handle Normalization Layers
        elif isinstance(child, nn.LayerNorm) and hasattr(target_ops, "LayerNorm"):
            new_layer = target_ops.LayerNorm(
                child.normalized_shape,
                eps=child.eps,
                elementwise_affine=child.elementwise_affine,
                device=child.weight.device if child.elementwise_affine else None,
                dtype=child.weight.dtype if child.elementwise_affine else None,
            )

        elif isinstance(child, nn.GroupNorm) and hasattr(target_ops, "GroupNorm"):
            new_layer = target_ops.GroupNorm(
                child.num_groups,
                child.num_channels,
                eps=child.eps,
                affine=child.affine,
                device=child.weight.device if child.affine else None,
                dtype=child.weight.dtype if child.affine else None,
            )

        # 4. Handle Embeddings
        elif isinstance(child, nn.Embedding) and hasattr(target_ops, "Embedding"):
            new_layer = target_ops.Embedding(
                child.num_embeddings,
                child.embedding_dim,
                padding_idx=child.padding_idx,
                max_norm=child.max_norm,
                norm_type=child.norm_type,
                scale_grad_by_freq=child.scale_grad_by_freq,
                sparse=child.sparse,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )

        # 5. Apply replacement or recurse
        if new_layer is not None:
            if new_layer is not child:
                # BUG FIX: hand over the original parameters; without this the
                # swapped-in layer keeps its own (re)initialized weights and
                # the original values are lost.
                _adopt_params(child, new_layer)
            setattr(model, name, new_layer)
        else:
            # If no replacement happened, recurse into children
            swap_layers_recursively(child, target_ops)

    return model

def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
"""
Expand Down Expand Up @@ -125,10 +216,11 @@ def create_object(config: DictConfig) -> Any:
name=config.__object__.name,
)
args = config.__object__.get("args", "as_config")
default_ops = comfy.ops.manual_cast
if args == "as_config":
return item(config)
return swap_layers_recursively(item(config), default_ops)
if args == "as_params":
config = OmegaConf.to_object(config)
config.pop("__object__")
return item(**config)
raise NotImplementedError(f"Unknown args type: {args}")
return swap_layers_recursively(item(**config), default_ops)
raise NotImplementedError(f"Unknown args type: {args}")
7 changes: 4 additions & 3 deletions src/core/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import torch
from omegaconf import OmegaConf
from typing import Dict, Any, Optional, Tuple, Union, Callable
import comfy

# Import SafeTensors with fallback
try:
Expand Down Expand Up @@ -118,7 +119,7 @@ def load_quantized_state_dict(checkpoint_path: str, device: torch.device = torch

# Try direct device loading first (optimal path)
try:
state = load_safetensors_file(checkpoint_path, device=device_str)
state = comfy.utils.load_torch_file(checkpoint_path, device=device_str)
except RuntimeError as e:
# MPS allocator fallback: some PyTorch/macOS versions have issues with
# direct MPS loading (allocation failures, watermark errors, etc.)
Expand All @@ -132,7 +133,7 @@ def load_quantized_state_dict(checkpoint_path: str, device: torch.device = torch
if debug:
debug.log("Using CPU intermediate loading for MPS compatibility",
category="info", indent_level=1)
state = load_safetensors_file(checkpoint_path, device="cpu")
state = comfy.utils.load_torch_file(checkpoint_path, device=device_str)
# Tensors will be moved to MPS during model.load_state_dict()
else:
# Re-raise if it's a different error (file corruption, etc.)
Expand Down Expand Up @@ -962,4 +963,4 @@ def dequantize(device: Optional[torch.device] = None,
debug.log(f"Warning: Could not dequantize tensor: {e}", level="WARNING", category="dit", force=True)
return tensor.to(device or tensor.device, dtype)

return dequantize
return dequantize