Skip to content

Commit 65f0713

Browse files
committed
add nodes and update presets
1 parent 45ff972 commit 65f0713

2 files changed

Lines changed: 193 additions & 1 deletion

File tree

__init__.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,26 @@ def _composite(original_np: np.ndarray,
677677
return result, composite_mask, stats
678678

679679

680+
def get_token_count(clip, text):
681+
"""
682+
Robustly tokenizes a text segment and returns the number of its content tokens.
683+
"""
684+
if not text:
685+
return 0
686+
687+
tokens = clip.tokenize(text)
688+
689+
max_content_len = 0
690+
for key in tokens:
691+
if len(tokens[key]) > 0 and len(tokens[key][0]) > 0:
692+
693+
content_len = len(tokens[key][0]) - 2
694+
if content_len > max_content_len:
695+
max_content_len = content_len
696+
697+
return max(0, max_content_len)
698+
699+
680700
class LlamaTokenizerOptions(io.ComfyNode):
681701
@classmethod
682702
def define_schema(cls) -> io.Schema:
@@ -1663,6 +1683,175 @@ def execute(cls, preset: str) -> io.NodeOutput:
16631683
return io.NodeOutput(preset)
16641684

16651685

1686+
class MultiTypeDemo(io.ComfyNode):
1687+
@classmethod
1688+
def define_schema(cls) -> io.Schema:
1689+
return io.Schema(
1690+
node_id="MultiTypeDemo",
1691+
display_name="Multi-Type Demo (Primitive)",
1692+
category="advanced/primitives",
1693+
inputs=[
1694+
io.MultiType.Input(
1695+
"type_list",
1696+
types=[io.String, io.Int, io.Float],
1697+
tooltip="Demo input that can accept multiple types. Output will just forward the value regardless of type.",
1698+
),
1699+
],
1700+
outputs=[
1701+
io.AnyType.Output(display_name="output", is_output_list=True),
1702+
],
1703+
)
1704+
@classmethod
1705+
def execute(cls, type_list) -> io.NodeOutput:
1706+
"""Simply forward the input value as output, demonstrating multi-type handling."""
1707+
if type_list[0] is not None:
1708+
assert isinstance(type_list[0], str), f"Expected string type for first input, got {type(type_list[0])}"
1709+
input_string = type_list[0]
1710+
else:
1711+
input_string = ""
1712+
if type_list[1] is not None:
1713+
assert isinstance(type_list[1], int), f"Expected int type for second input, got {type(type_list[1])}"
1714+
input_int = type_list[1]
1715+
else:
1716+
input_int = 0
1717+
if type_list[2] is not None:
1718+
assert isinstance(type_list[2], float), f"Expected float type for third input, got {type(type_list[2])}"
1719+
input_float = type_list[2]
1720+
else:
1721+
input_float = 0.0
1722+
1723+
return io.NodeOutput([input_string, input_int, input_float])
1724+
1725+
1726+
class AnyList(io.ComfyNode):
1727+
@classmethod
1728+
def define_schema(cls) -> io.Schema:
1729+
autogrow_template = io.Autogrow.TemplatePrefix(
1730+
io.AnyType.Input("input"),
1731+
prefix="input",
1732+
min=1,
1733+
max=100
1734+
)
1735+
return io.Schema(
1736+
node_id="AnyList",
1737+
display_name="Any List (Primitive)",
1738+
category="advanced",
1739+
inputs=[
1740+
io.Autogrow.Input(
1741+
"input",
1742+
template=autogrow_template,
1743+
tooltip="Add items of any type. The list will grow to accommodate all items added."
1744+
),
1745+
],
1746+
outputs=[
1747+
io.AnyType.Output(display_name="output"),
1748+
],
1749+
)
1750+
1751+
@classmethod
1752+
def execute(cls, input) -> io.NodeOutput:
1753+
"""Simply forward the input list as output, demonstrating autogrow list creation."""
1754+
# Convert dict to list of values (autogrow inputs come as dict with keys like "input0", "input1", etc)
1755+
items = list(input.values()) if isinstance(input, dict) else input
1756+
print(f"Received list with {len(items)} items. Item types:")
1757+
output = []
1758+
for item in items:
1759+
output += [item] if not isinstance(item, list) else item
1760+
return io.NodeOutput(output)
1761+
1762+
1763+
class AttentionBiasTextEncode(io.ComfyNode):
1764+
@classmethod
1765+
def define_schema(cls) -> io.Schema:
1766+
return io.Schema(
1767+
node_id="AttentionBiasTextEncode",
1768+
category="advanced/conditioning",
1769+
display_name="CLIP Text Encode with Attention Bias (Experimental)",
1770+
inputs=[
1771+
io.Clip.Input("clip"),
1772+
io.String.Input("text", multiline=True, dynamic_prompts=True),
1773+
],
1774+
outputs=[
1775+
io.Conditioning.Output(display_name="conditioning"),
1776+
]
1777+
)
1778+
1779+
@classmethod
1780+
def execute(cls, clip, text) -> io.NodeOutput:
1781+
if clip is None:
1782+
raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
1783+
1784+
if '<' not in text and '>' not in text and '=' not in text:
1785+
tokens = clip.tokenize(text)
1786+
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
1787+
return ([[cond, {"pooled_output": pooled}]], )
1788+
1789+
bias_pattern = re.compile(r"<([^>]+)=([0-9.-]+)>")
1790+
split_pattern = re.compile(r"(<[^>]+=[0-9.-]+>)")
1791+
segments = split_pattern.split(text)
1792+
1793+
clean_text = ""
1794+
biases_to_apply = []
1795+
1796+
current_token_index = 1
1797+
1798+
for segment in segments:
1799+
if not segment:
1800+
continue
1801+
1802+
match = bias_pattern.fullmatch(segment)
1803+
if match:
1804+
bias_text, strength_str = match.groups()
1805+
strength = float(strength_str)
1806+
clean_text += bias_text
1807+
num_tokens = get_token_count(clip, bias_text)
1808+
1809+
if num_tokens > 0:
1810+
start_index = current_token_index
1811+
end_index = current_token_index + num_tokens
1812+
biases_to_apply.append({"start": start_index, "end": end_index, "strength": strength})
1813+
1814+
current_token_index += num_tokens
1815+
else:
1816+
clean_text += segment
1817+
num_tokens = get_token_count(clip, segment)
1818+
current_token_index += num_tokens
1819+
1820+
tokens = clip.tokenize(clean_text)
1821+
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
1822+
1823+
if not biases_to_apply:
1824+
return io.NodeOutput([[cond, {"pooled_output": pooled}]])
1825+
1826+
cond_dict = {"pooled_output": pooled}
1827+
n_text_tokens = cond.shape[1]
1828+
device = cond.device
1829+
dtype = torch.float16
1830+
1831+
1832+
final_seq_len = n_text_tokens + 1
1833+
attn_mask = torch.zeros((1, final_seq_len, final_seq_len), dtype=dtype, device=device)
1834+
1835+
pooled_offset = 1
1836+
1837+
for bias in biases_to_apply:
1838+
strength = bias["strength"]
1839+
attn_bias_value = torch.log(torch.tensor(strength, dtype=dtype, device=device))
1840+
1841+
start = min(bias["start"] + pooled_offset, final_seq_len)
1842+
end = min(bias["end"] + pooled_offset, final_seq_len)
1843+
1844+
if start >= end:
1845+
continue
1846+
1847+
attn_mask[:, :, start:end] += attn_bias_value
1848+
attn_mask[:, start:end, :] += attn_bias_value
1849+
1850+
cond_dict["attention_mask"] = attn_mask
1851+
cond_dict["attention_mask_img_shape"] = (1, 1)
1852+
1853+
return io.NodeOutput([[cond, cond_dict]])
1854+
16661855
class TextEncodeFlux2SystemPrompt(io.ComfyNode):
16671856
@classmethod
16681857
def define_schema(cls):
@@ -3764,6 +3953,9 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
37643953
VLMSysQueryAddPresets,
37653954
VLMSysInstrAdvPresets,
37663955
UnifiedPresets,
3956+
MultiTypeDemo,
3957+
AnyList,
3958+
AttentionBiasTextEncode,
37673959
FrakturPadNode,
37683960
UnFrakturPadNode,
37693961
JoinerPadding,

0 commit comments

Comments
 (0)