-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathnodes.py
More file actions
161 lines (137 loc) · 4.91 KB
/
Copy pathnodes.py
File metadata and controls
161 lines (137 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
import folder_paths
from nodes import EmptyLatentImage
from .conf import sana_conf, sana_res
from .loader import load_sana
# Precision choices: "auto" plus three explicit floating-point formats.
# NOTE(review): not referenced by any node visible in this chunk; it may be
# imported by other modules of the package — confirm before removing.
dtypes = ["auto", "FP32", "FP16", "BF16"]
class SanaCheckpointLoader:
    """Node that loads a Sana model from a file in the "checkpoints" folder.

    The user picks a checkpoint file and a configuration name; the
    configuration is looked up in ``sana_conf`` and both are handed to
    ``load_sana``.
    """

    CATEGORY = "ExtraModels/Sana"
    TITLE = "Sana Checkpoint Loader"
    FUNCTION = "load_checkpoint"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)

    @classmethod
    def INPUT_TYPES(cls):
        # Choices are computed each time the node's inputs are queried.
        checkpoint_files = folder_paths.get_filename_list("checkpoints")
        config_names = list(sana_conf.keys())
        return {
            "required": {
                "ckpt_name": (checkpoint_files,),
                "model": (config_names,),
            }
        }

    def load_checkpoint(self, ckpt_name, model):
        """Build the Sana model.

        Args:
            ckpt_name: file name within the "checkpoints" folder.
            model: key into ``sana_conf`` selecting the model configuration.

        Returns:
            A 1-tuple holding whatever ``load_sana`` returns.
        """
        full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        config = sana_conf[model]
        loaded = load_sana(model_path=full_path, model_conf=config)
        return (loaded,)
class EmptySanaLatentImage(EmptyLatentImage):
    """Zero-filled latent sized for Sana: 32 channels at 1/32 pixel resolution.

    The node's inputs (width, height, batch_size) are inherited unchanged from
    ``EmptyLatentImage``; only the latent shape differs.
    """

    TITLE = "Empty Sana Latent Image"
    CATEGORY = "ExtraModels/Sana"

    def generate(self, width, height, batch_size=1):
        # Spatial dims are floor-divided by 32, so sizes that are not a
        # multiple of 32 are rounded down.
        latent_shape = [batch_size, 32, height // 32, width // 32]
        # NOTE(review): self.device is presumably set by EmptyLatentImage's
        # constructor — confirm against the ComfyUI version in use.
        samples = torch.zeros(latent_shape, device=self.device)
        return ({"samples": samples}, )
class SanaResolutionSelect:
    """Node that returns a preset (width, height) pair from ``sana_res``.

    ``model`` selects a resolution tier (a top-level key of ``sana_res``) and
    ``ratio`` selects an aspect-ratio entry within that tier.
    """

    TITLE = "Sana Resolution Select"
    CATEGORY = "ExtraModels/Sana"
    FUNCTION = "get_res"
    RETURN_TYPES = ("INT", "INT")
    RETURN_NAMES = ("width", "height")

    @classmethod
    def INPUT_TYPES(cls):
        tiers = list(sana_res.keys())
        # Ratio choices come from the "1024px" tier only.
        # NOTE(review): assumes every tier shares these ratio keys — confirm in conf.
        ratios = list(sana_res["1024px"].keys())
        return {
            "required": {
                "model": (tiers,),
                "ratio": (ratios, {"default": "1.00"}),
            }
        }

    def get_res(self, model, ratio):
        """Return ``(width, height)`` for the chosen tier and ratio."""
        w, h = sana_res[model][ratio]
        return (w, h)
class SanaResolutionCond:
    """Node that attaches Sana resolution metadata to every conditioning entry.

    Each entry's options dict receives ``img_hw`` (``[[height, width]]``) and
    ``aspect_ratio`` (``[[height / width]]``).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING", ),
                # Integer defaults for INT widgets; min=1 keeps height/width finite.
                "width": ("INT", {"default": 1024, "min": 1, "max": 8192}),
                "height": ("INT", {"default": 1024, "min": 1, "max": 8192}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("cond",)
    FUNCTION = "add_cond"
    CATEGORY = "ExtraModels/Sana"
    TITLE = "Sana Resolution Conditioning"

    def add_cond(self, cond, width, height):
        """Return a copy of ``cond`` with resolution keys added.

        Args:
            cond: list of conditioning entries, each ``[embedding, options]``
                (any extra trailing elements are preserved).
            width: target image width in pixels; must be positive.
            height: target image height in pixels; must be positive.

        Returns:
            A 1-tuple with the new conditioning list. The input list and its
            options dicts are left untouched, because upstream node outputs
            may be cached and shared with other consumers.

        Raises:
            ValueError: if ``width`` or ``height`` is not positive.
        """
        if width <= 0 or height <= 0:
            raise ValueError(
                f"width and height must be positive, got {width}x{height}"
            )
        out = []
        for entry in cond:
            # Shallow-copy the options dict instead of updating it in place.
            options = dict(entry[1])
            options["img_hw"] = [[height, width]]
            options["aspect_ratio"] = [[height / width]]
            out.append([entry[0], options, *entry[2:]])
        return (out,)
class SanaTextEncode:
    """Node that encodes a prompt with a Gemma text encoder for Sana.

    Optionally prepends a "complex human instruction" (CHI) prompt to the
    user's text. The result always contains ``MAX_SEQUENCE_LENGTH`` token
    positions: position 0 plus the last ``MAX_SEQUENCE_LENGTH - 1`` positions
    of the tokenized sequence.
    """

    # Number of token positions kept in the output embedding.
    MAX_SEQUENCE_LENGTH = 300

    @classmethod
    def INPUT_TYPES(s):
        return {
            "optional": {
                "chi_prompt_string": ("STRING", {"forceInput": True})
            },
            "required": {
                "text": ("STRING", {"multiline": True}),
                "GEMMA": ("GEMMA",),
                "chi": ("BOOLEAN", {"default": True})
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "ExtraModels/Sana"
    TITLE = "Sana Text Encode"

    def encode(self, text, GEMMA=None, chi=True, chi_prompt_string=None):
        """Encode ``text`` into a single conditioning entry.

        Args:
            text: user prompt.
            GEMMA: dict with ``"tokenizer"`` and ``"text_encoder"`` entries.
            chi: whether to prepend a CHI prompt.
            chi_prompt_string: custom CHI prompt; when None and ``chi`` is
                set, the lines of ``preset_te_prompt`` joined by newlines
                are used.

        Returns:
            ``([[embs, {}]],)``, where ``embs`` has a singleton axis at dim 1
            and ``MAX_SEQUENCE_LENGTH`` positions along dim 2. Padded
            positions are zeroed using the attention mask.

        Raises:
            ValueError: if ``GEMMA`` is not supplied.
        """
        if GEMMA is None:
            raise ValueError("SanaTextEncode requires a GEMMA input")
        tokenizer = GEMMA["tokenizer"]
        text_encoder = GEMMA["text_encoder"]
        seq_len = self.MAX_SEQUENCE_LENGTH

        if not chi:
            chi_prompt = ""
        elif chi_prompt_string is not None:
            chi_prompt = chi_prompt_string
        else:
            chi_prompt = "\n".join(preset_te_prompt)

        if chi_prompt:
            # Leave room for the whole CHI prompt plus the user prompt; the
            # "- 2" matches the original formula (BOS and a separator token).
            num_chi_tokens = len(tokenizer.encode(chi_prompt))
            max_length = num_chi_tokens + seq_len - 2
        else:
            # Without a CHI prompt, pad/truncate to exactly seq_len positions.
            # This keeps every index in select_idx distinct. The old formula
            # depended on len(tokenizer.encode("")); for a tokenizer that
            # emits a BOS token for "" it yielded seq_len - 1 positions, so
            # position 0 was selected twice.
            max_length = seq_len

        with torch.no_grad():
            tokens = tokenizer(
                [chi_prompt + text],
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).to(text_encoder.device)
            # Position 0 plus the trailing seq_len - 1 positions.
            select_idx = [0] + list(range(-seq_len + 1, 0))
            hidden = text_encoder(tokens.input_ids, tokens.attention_mask)[0]
            embs = hidden[:, None][:, :, select_idx]
            emb_masks = tokens.attention_mask[:, select_idx]
            embs = embs * emb_masks.unsqueeze(-1)
        return ([[embs, {}]], )
# Default "complex human instruction" (CHI) lines. SanaTextEncode joins them
# with newlines and prepends the result to the user's text when CHI is enabled
# and no custom instruction string is supplied. Long lines are split with
# implicit string concatenation; each element is a single string.
preset_te_prompt = [
    (
        'Given a user prompt, generate an "Enhanced prompt" that provides '
        "detailed visual descriptions suitable for image generation. "
        "Evaluate the level of detail in the user prompt:"
    ),
    (
        "- If the prompt is simple, focus on adding specifics about colors, "
        "shapes, sizes, textures, and spatial relationships to create vivid "
        "and concrete scenes."
    ),
    (
        "- If the prompt is already detailed, refine and enhance the existing "
        "details slightly without overcomplicating."
    ),
    "Here are examples of how to transform or refine prompts:",
    (
        "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat "
        "curled up in a round shape, sleeping peacefully on a warm sunny "
        "windowsill, surrounded by pots of blooming red flowers."
    ),
    (
        "- User Prompt: A busy city street -> Enhanced: A bustling city street "
        "scene at dusk, featuring glowing street lamps, a diverse crowd of "
        "people in colorful clothing, and a double-decker bus passing by "
        "towering glass skyscrapers."
    ),
    (
        "Please generate only the enhanced description for the prompt below "
        "and avoid including any additional commentary or evaluations:"
    ),
    "User Prompt: ",
]
# Node type identifier -> implementing class, for every node in this module.
NODE_CLASS_MAPPINGS = dict(
    SanaCheckpointLoader=SanaCheckpointLoader,
    SanaResolutionSelect=SanaResolutionSelect,
    SanaTextEncode=SanaTextEncode,
    SanaResolutionCond=SanaResolutionCond,
    EmptySanaLatentImage=EmptySanaLatentImage,
)