Skip to content

Commit 2f4750f

Browse files
committed
Update Comfy Registry Icon
1 parent dffe57e commit 2f4750f

11 files changed

Lines changed: 576 additions & 7 deletions

FS_Nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Project: FlowState Creator Nodes
1+
# Project: FlowState Creator Suite
22
# Description: A collection of custom nodes to solve problems I couldn't find existing nodes for.
33
# Author: Johnathan Chivington
44
# Contact: flowstateeng@gmail.com | youtube.com/@flowstateeng

FlowStateUnifiedModelLoader.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Project: FlowState Unified Model Loader
2+
# Description: Load checkpoints and UNETs, includes NF4 support.
3+
# Version: 1.0.0
4+
# Author: Johnathan Chivington
5+
# Contact: johnathan@flowstateengineering.com | youtube.com/@flowstateeng
6+
7+
8+
##
9+
# SYSTEM STATUS
10+
##
11+
print(f' - Load Unified Model Loader node.')
12+
13+
14+
##
15+
# FS IMPORTS
16+
##
17+
from .FS_Assets import *
18+
from .FS_Constants import *
19+
from .FS_Types import *
20+
from .FS_Utils import *
21+
22+
23+
##
24+
# OUTSIDE IMPORTS
25+
##
26+
import torch
27+
28+
import os, sys, time, io
29+
import folder_paths
30+
import warnings
31+
from contextlib import redirect_stdout
32+
33+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
34+
import comfy.sd
35+
36+
from comfy.utils import load_torch_file
37+
from nodes import UNETLoader
38+
from nodes import CheckpointLoaderSimple
39+
40+
from .NF4Loader import CheckpointLoaderNF4
41+
42+
warnings.filterwarnings('ignore', message='clean_up_tokenization_spaces')
43+
warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention")
44+
45+
46+
##
47+
# NODES
48+
##
49+
class FlowStateUnifiedModelLoader:
50+
CATEGORY = 'FlowState/loader'
51+
DESCRIPTION = 'Load checkpoints and UNETs, includes NF4 support.'
52+
FUNCTION = 'load'
53+
RETURN_TYPES = MODEL_UNIFIED
54+
RETURN_NAMES = ('model', 'clip', 'vae', 'seed', 'model_type', )
55+
OUTPUT_TOOLTIPS = (
56+
'Checkpoint or UNET model.',
57+
'The CLIP model used for encoding text prompts.',
58+
'The VAE model used for encoding and decoding images to and from latent space.',
59+
'Global seed.',
60+
'Type of model to use.',
61+
)
62+
63+
@classmethod
64+
def INPUT_TYPES(s):
65+
return {
66+
'required': {
67+
'model_file': ALL_MODEL_LISTS(),
68+
'weight_dtype': (['default', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2'], ),
69+
'model_type': (['NF4', 'UNET', 'SD'],),
70+
'clip_1': CLIP_LIST(),
71+
'clip_2': CLIP_LIST(),
72+
'clip_type': (['default', 'sdxl', 'sd3', 'flux'], ),
73+
'vae_name': VAE_LIST(),
74+
'seed': SEED,
75+
}
76+
}
77+
78+
def load_taesd(self, name):
79+
sd = {}
80+
approx_vaes = folder_paths.get_filename_list('vae_approx')
81+
82+
encoder = next(filter(lambda a: a.startswith('{}_encoder.'.format(name)), approx_vaes))
83+
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
84+
85+
enc = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
86+
for k in enc:
87+
sd["taesd_encoder.{}".format(k)] = enc[k]
88+
89+
dec = load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
90+
for k in dec:
91+
sd["taesd_decoder.{}".format(k)] = dec[k]
92+
93+
if name == "taesd":
94+
sd["vae_scale"] = torch.tensor(0.18215)
95+
sd["vae_shift"] = torch.tensor(0.0)
96+
elif name == "taesdxl":
97+
sd["vae_scale"] = torch.tensor(0.13025)
98+
sd["vae_shift"] = torch.tensor(0.0)
99+
elif name == "taesd3":
100+
sd["vae_scale"] = torch.tensor(1.5305)
101+
sd["vae_shift"] = torch.tensor(0.0609)
102+
elif name == "taef1":
103+
sd["vae_scale"] = torch.tensor(0.3611)
104+
sd["vae_shift"] = torch.tensor(0.1159)
105+
return sd
106+
107+
def load_vae(self, vae_name):
108+
vae = None
109+
vae_path = None
110+
captured_output = io.StringIO()
111+
with redirect_stdout(captured_output):
112+
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
113+
vae_path = self.load_taesd(vae_name)
114+
else:
115+
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
116+
vae_path = load_torch_file(vae_path)
117+
vae = comfy.sd.VAE(sd=vae_path)
118+
return vae
119+
120+
def load_clip(self, clip_name1, clip_name2, model_type):
121+
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
122+
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
123+
if model_type == "sdxl":
124+
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
125+
elif model_type == "sd3":
126+
clip_type = comfy.sd.CLIPType.SD3
127+
elif model_type == "flux":
128+
clip_type = comfy.sd.CLIPType.FLUX
129+
130+
clip = None
131+
captured_output = io.StringIO()
132+
with redirect_stdout(captured_output):
133+
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
134+
135+
return clip
136+
137+
def select_models(self, model_file, weight_dtype, model_type, clip_1, clip_2, clip_type, vae_name):
138+
model, clip, vae = None, None, None
139+
clip_fname, vae_fname = None, None
140+
141+
is_nf4 = model_type == 'NF4'
142+
is_sd = model_type == 'SD'
143+
is_unet = model_type == 'UNET'
144+
145+
default_clip = clip_type == 'default'
146+
default_vae = vae_name == 'default'
147+
148+
model_loader = CheckpointLoaderNF4 if is_nf4 else (CheckpointLoaderSimple if is_sd else UNETLoader)
149+
150+
loaded_model = None
151+
if is_unet:
152+
loaded_model = model = model_loader().load_unet(model_file, weight_dtype)[0]
153+
else:
154+
loaded_model = model_loader().load_checkpoint(model_file)
155+
model = loaded_model[0]
156+
157+
clip_and_vae_included = (isinstance(loaded_model, list) or isinstance(loaded_model, tuple)) and len(loaded_model) > 2
158+
159+
if not default_clip and not default_vae:
160+
clip_fname = f'{clip_1} & {clip_2}'
161+
clip_weight_type = 'flux' if default_clip else clip_type
162+
clip = self.load_clip(clip_1, clip_2, clip_weight_type)
163+
vae_fname = vae_name
164+
vae = self.load_vae(vae_fname)
165+
else:
166+
if clip_and_vae_included:
167+
if default_clip:
168+
clip_fname = 'included'
169+
clip = loaded_model[1]
170+
if default_vae:
171+
vae_fname = 'included'
172+
vae = loaded_model[2]
173+
else:
174+
clip_fname = f'{clip_1} & {clip_2}'
175+
clip_weight_type = 'flux' if default_clip else clip_type
176+
clip = self.load_clip(clip_1, clip_2, clip_weight_type)
177+
vae_fname = VAE_LIST_PATH[0]
178+
vae = self.load_vae(vae_fname)
179+
180+
return model, clip, vae, clip_fname, vae_fname
181+
182+
def load(self, model_file, weight_dtype, model_type, clip_1, clip_2, clip_type, vae_name, seed):
183+
print(
184+
f'\n\nFlowState Unified Model Loader'
185+
f'\n - Preparing loader\n'
186+
)
187+
188+
start_time = time.time()
189+
190+
model, clip, vae, clip_fname, vae_fname = self.select_models(
191+
model_file, weight_dtype, model_type, clip_1, clip_2, clip_type, vae_name
192+
)
193+
194+
loading_duration, loading_mins, loading_secs = get_mins_and_secs(start_time)
195+
vae_warn = '(Selected VAE not available)' if vae_fname != vae_name else ''
196+
197+
print(
198+
f'\nFlowState Unified Model Loader - Loading complete.'
199+
f'\n - Model Name: {model_file}'
200+
f'\n - VAE Name: {vae_fname} {vae_warn}'
201+
f'\n - CLIP Name: {clip_fname}'
202+
f'\n - Loading Time: {loading_mins}m {loading_secs}s\n'
203+
)
204+
205+
model_type_out = 'SD' if model_type == 'SD' else 'FLUX'
206+
207+
return (model, clip, vae, seed, [model_type_out], )
208+
209+

0 commit comments

Comments
 (0)