@@ -677,6 +677,26 @@ def _composite(original_np: np.ndarray,
677677 return result , composite_mask , stats
678678
679679
680+ def get_token_count (clip , text ):
681+ """
682+ Robustly tokenizes a text segment and returns the number of its content tokens.
683+ """
684+ if not text :
685+ return 0
686+
687+ tokens = clip .tokenize (text )
688+
689+ max_content_len = 0
690+ for key in tokens :
691+ if len (tokens [key ]) > 0 and len (tokens [key ][0 ]) > 0 :
692+
693+ content_len = len (tokens [key ][0 ]) - 2
694+ if content_len > max_content_len :
695+ max_content_len = content_len
696+
697+ return max (0 , max_content_len )
698+
699+
680700class LlamaTokenizerOptions (io .ComfyNode ):
681701 @classmethod
682702 def define_schema (cls ) -> io .Schema :
@@ -1663,6 +1683,175 @@ def execute(cls, preset: str) -> io.NodeOutput:
16631683 return io .NodeOutput (preset )
16641684
16651685
1686+ class MultiTypeDemo (io .ComfyNode ):
1687+ @classmethod
1688+ def define_schema (cls ) -> io .Schema :
1689+ return io .Schema (
1690+ node_id = "MultiTypeDemo" ,
1691+ display_name = "Multi-Type Demo (Primitive)" ,
1692+ category = "advanced/primitives" ,
1693+ inputs = [
1694+ io .MultiType .Input (
1695+ "type_list" ,
1696+ types = [io .String , io .Int , io .Float ],
1697+ tooltip = "Demo input that can accept multiple types. Output will just forward the value regardless of type." ,
1698+ ),
1699+ ],
1700+ outputs = [
1701+ io .AnyType .Output (display_name = "output" , is_output_list = True ),
1702+ ],
1703+ )
1704+ @classmethod
1705+ def execute (cls , type_list ) -> io .NodeOutput :
1706+ """Simply forward the input value as output, demonstrating multi-type handling."""
1707+ if type_list [0 ] is not None :
1708+ assert isinstance (type_list [0 ], str ), f"Expected string type for first input, got { type (type_list [0 ])} "
1709+ input_string = type_list [0 ]
1710+ else :
1711+ input_string = ""
1712+ if type_list [1 ] is not None :
1713+ assert isinstance (type_list [1 ], int ), f"Expected int type for second input, got { type (type_list [1 ])} "
1714+ input_int = type_list [1 ]
1715+ else :
1716+ input_int = 0
1717+ if type_list [2 ] is not None :
1718+ assert isinstance (type_list [2 ], float ), f"Expected float type for third input, got { type (type_list [2 ])} "
1719+ input_float = type_list [2 ]
1720+ else :
1721+ input_float = 0.0
1722+
1723+ return io .NodeOutput ([input_string , input_int , input_float ])
1724+
1725+
1726+ class AnyList (io .ComfyNode ):
1727+ @classmethod
1728+ def define_schema (cls ) -> io .Schema :
1729+ autogrow_template = io .Autogrow .TemplatePrefix (
1730+ io .AnyType .Input ("input" ),
1731+ prefix = "input" ,
1732+ min = 1 ,
1733+ max = 100
1734+ )
1735+ return io .Schema (
1736+ node_id = "AnyList" ,
1737+ display_name = "Any List (Primitive)" ,
1738+ category = "advanced" ,
1739+ inputs = [
1740+ io .Autogrow .Input (
1741+ "input" ,
1742+ template = autogrow_template ,
1743+ tooltip = "Add items of any type. The list will grow to accommodate all items added."
1744+ ),
1745+ ],
1746+ outputs = [
1747+ io .AnyType .Output (display_name = "output" ),
1748+ ],
1749+ )
1750+
1751+ @classmethod
1752+ def execute (cls , input ) -> io .NodeOutput :
1753+ """Simply forward the input list as output, demonstrating autogrow list creation."""
1754+ # Convert dict to list of values (autogrow inputs come as dict with keys like "input0", "input1", etc)
1755+ items = list (input .values ()) if isinstance (input , dict ) else input
1756+ print (f"Received list with { len (items )} items. Item types:" )
1757+ output = []
1758+ for item in items :
1759+ output += [item ] if not isinstance (item , list ) else item
1760+ return io .NodeOutput (output )
1761+
1762+
1763+ class AttentionBiasTextEncode (io .ComfyNode ):
1764+ @classmethod
1765+ def define_schema (cls ) -> io .Schema :
1766+ return io .Schema (
1767+ node_id = "AttentionBiasTextEncode" ,
1768+ category = "advanced/conditioning" ,
1769+ display_name = "CLIP Text Encode with Attention Bias (Experimental)" ,
1770+ inputs = [
1771+ io .Clip .Input ("clip" ),
1772+ io .String .Input ("text" , multiline = True , dynamic_prompts = True ),
1773+ ],
1774+ outputs = [
1775+ io .Conditioning .Output (display_name = "conditioning" ),
1776+ ]
1777+ )
1778+
1779+ @classmethod
1780+ def execute (cls , clip , text ) -> io .NodeOutput :
1781+ if clip is None :
1782+ raise RuntimeError ("ERROR: clip input is invalid: None\n \n If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model." )
1783+
1784+ if '<' not in text and '>' not in text and '=' not in text :
1785+ tokens = clip .tokenize (text )
1786+ cond , pooled = clip .encode_from_tokens (tokens , return_pooled = True )
1787+ return ([[cond , {"pooled_output" : pooled }]], )
1788+
1789+ bias_pattern = re .compile (r"<([^>]+)=([0-9.-]+)>" )
1790+ split_pattern = re .compile (r"(<[^>]+=[0-9.-]+>)" )
1791+ segments = split_pattern .split (text )
1792+
1793+ clean_text = ""
1794+ biases_to_apply = []
1795+
1796+ current_token_index = 1
1797+
1798+ for segment in segments :
1799+ if not segment :
1800+ continue
1801+
1802+ match = bias_pattern .fullmatch (segment )
1803+ if match :
1804+ bias_text , strength_str = match .groups ()
1805+ strength = float (strength_str )
1806+ clean_text += bias_text
1807+ num_tokens = get_token_count (clip , bias_text )
1808+
1809+ if num_tokens > 0 :
1810+ start_index = current_token_index
1811+ end_index = current_token_index + num_tokens
1812+ biases_to_apply .append ({"start" : start_index , "end" : end_index , "strength" : strength })
1813+
1814+ current_token_index += num_tokens
1815+ else :
1816+ clean_text += segment
1817+ num_tokens = get_token_count (clip , segment )
1818+ current_token_index += num_tokens
1819+
1820+ tokens = clip .tokenize (clean_text )
1821+ cond , pooled = clip .encode_from_tokens (tokens , return_pooled = True )
1822+
1823+ if not biases_to_apply :
1824+ return io .NodeOutput ([[cond , {"pooled_output" : pooled }]])
1825+
1826+ cond_dict = {"pooled_output" : pooled }
1827+ n_text_tokens = cond .shape [1 ]
1828+ device = cond .device
1829+ dtype = torch .float16
1830+
1831+
1832+ final_seq_len = n_text_tokens + 1
1833+ attn_mask = torch .zeros ((1 , final_seq_len , final_seq_len ), dtype = dtype , device = device )
1834+
1835+ pooled_offset = 1
1836+
1837+ for bias in biases_to_apply :
1838+ strength = bias ["strength" ]
1839+ attn_bias_value = torch .log (torch .tensor (strength , dtype = dtype , device = device ))
1840+
1841+ start = min (bias ["start" ] + pooled_offset , final_seq_len )
1842+ end = min (bias ["end" ] + pooled_offset , final_seq_len )
1843+
1844+ if start >= end :
1845+ continue
1846+
1847+ attn_mask [:, :, start :end ] += attn_bias_value
1848+ attn_mask [:, start :end , :] += attn_bias_value
1849+
1850+ cond_dict ["attention_mask" ] = attn_mask
1851+ cond_dict ["attention_mask_img_shape" ] = (1 , 1 )
1852+
1853+ return io .NodeOutput ([[cond , cond_dict ]])
1854+
16661855class TextEncodeFlux2SystemPrompt (io .ComfyNode ):
16671856 @classmethod
16681857 def define_schema (cls ):
@@ -3764,6 +3953,9 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
37643953 VLMSysQueryAddPresets ,
37653954 VLMSysInstrAdvPresets ,
37663955 UnifiedPresets ,
3956+ MultiTypeDemo ,
3957+ AnyList ,
3958+ AttentionBiasTextEncode ,
37673959 FrakturPadNode ,
37683960 UnFrakturPadNode ,
37693961 JoinerPadding ,
0 commit comments