Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions PixArt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def __init__(self, model_conf):
def model_type(self, state_dict, prefix=""):
return comfy.model_base.ModelType.EPS

def load_pixart(model_path, model_conf):
def load_pixart(model_path, model_conf, target_dtype: torch.dtype | None=None):
Comment thread
SLAPaper marked this conversation as resolved.
Outdated
state_dict = comfy.utils.load_torch_file(model_path)
state_dict = state_dict.get("model", state_dict)


# prefix
for prefix in ["model.diffusion_model.",]:
if any(True for x in state_dict if x.startswith(prefix)):
Expand All @@ -36,8 +37,11 @@ def load_pixart(model_path, model_conf):
if "adaln_single.linear.weight" in state_dict:
state_dict = convert_state_dict(state_dict) # Diffusers

parameters = comfy.utils.calculate_parameters(state_dict)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if target_dtype is None:
parameters = comfy.utils.calculate_parameters(state_dict)
unet_dtype = model_management.unet_dtype(model_params=parameters)
else:
unet_dtype = target_dtype

model_conf = EXM_PixArt(model_conf) # convert to object
model = comfy.model_base.BaseModel(
Expand Down
27 changes: 26 additions & 1 deletion PixArt/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,22 @@
from .loader import load_pixart
from .sampler import sample_pixart

dtypes = [
"default",
"auto (comfy)",
"float32",
"float16",
"bfloat16",
]

class PixArtCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
"model": (list(pixart_conf.keys()),),
"dtype": (dtypes,),
}
}
RETURN_TYPES = ("MODEL",)
Expand All @@ -24,12 +33,28 @@ def INPUT_TYPES(s):
CATEGORY = "ExtraModels/PixArt"
TITLE = "PixArt Checkpoint Loader"

def load_checkpoint(self, ckpt_name, model):
def load_checkpoint(self, ckpt_name, model, dtype: str):
Comment thread
SLAPaper marked this conversation as resolved.
Outdated
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
model_conf = pixart_conf[model]

target_dtype: torch.dtype | None = None
Comment thread
SLAPaper marked this conversation as resolved.
Outdated
if dtype == "default":
target_dtype = torch.float16
elif dtype == 'auto (comfy)':
target_dtype = None
elif dtype == 'float32':
target_dtype = torch.float32
elif dtype == 'float16':
target_dtype = torch.float16
elif dtype == 'bfloat16':
target_dtype = torch.bfloat16
else:
raise ValueError(f"Invalid dtype: {dtype}")

model = load_pixart(
model_path = ckpt_path,
model_conf = model_conf,
target_dtype = target_dtype,
)
return (model,)

Expand Down