@@ -1562,3 +1562,314 @@ def execute(cls, base_model, lora_count,
15621562 skip_patterns_str = skip_patterns ,
15631563 )
15641564 return io .NodeOutput (path )
1565+
1566+
1567+ # =============================================================================
1568+ # LoRA Merge To Model (Save merged model, skip extraction)
1569+ # =============================================================================
1570+
def merge_loras_to_model(
    lora_paths: List[str],
    lora_weights: List[float],
    base_model_path: str,
    device: str,
    save_dtype: torch.dtype,
    output_filename: str,
    skip_patterns_str: str = "",
    verbose: bool = True,
) -> str:
    """
    Merge multiple LoRAs into a base model and save the result directly.

    Unlike merge_multi_loras_via_base, this function:
    - Does NOT extract the result back to a LoRA
    - Saves the merged full model to the base model's directory

    Every key of the base model is written to the output. Keys matching a
    skip pattern are copied without any LoRA applied (they are never dropped,
    otherwise the saved checkpoint would be incomplete). Floating-point
    tensors are stored as ``save_dtype``; non-floating tensors keep their
    original dtype.

    Args:
        lora_paths: List of paths to LoRA files
        lora_weights: List of weight strengths, one per entry of lora_paths
        base_model_path: Path to base model
        device: Processing device
        save_dtype: Output dtype for floating-point tensors
        output_filename: Output filename (without extension)
        skip_patterns_str: Regex patterns for base keys that must stay unmerged
        verbose: Print progress info

    Returns:
        Path to saved merged model

    Raises:
        ValueError: If lora_paths and lora_weights differ in length, or if the
            output file would overwrite the base model or one of the LoRAs.
    """
    # --- Validate inputs before any expensive work --------------------------
    if len(lora_paths) != len(lora_weights):
        raise ValueError(
            f"Got {len(lora_paths)} LoRA paths but {len(lora_weights)} weights; "
            "each LoRA needs exactly one weight"
        )

    # dirname() is "" for a bare filename; fall back to the current directory
    # so that makedirs()/join() below keep working.
    base_dir = os.path.dirname(base_model_path) or os.curdir
    output_path = os.path.join(base_dir, f"{output_filename.strip()}.safetensors")

    def _canonical(path: str) -> str:
        return os.path.normcase(os.path.abspath(path))

    # Writing over an input would destroy it (and it is still open for reading).
    if _canonical(output_path) in {_canonical(p) for p in (base_model_path, *lora_paths)}:
        raise ValueError(
            f"Output file '{output_path}' would overwrite an input file; "
            "choose a different output_filename"
        )

    # --- Estimate memory and prepare ----------------------------------------
    total_size_gb = estimate_model_size(base_model_path)
    for lp in lora_paths:
        total_size_gb += estimate_model_size(lp)

    if verbose:
        print(f"[LoRA Merge To Model] Preparing memory for {total_size_gb:.2f} GB operation...")
        print(f"[LoRA Merge To Model] Merging {len(lora_paths)} LoRAs with weights: {lora_weights}")
    prepare_for_large_operation(total_size_gb * 1.5, torch.device(device))

    # Prefixes stripped from key names so base keys and LoRA block names can be matched
    BASE_PREFIXES = ["model.diffusion_model.", "diffusion_model.", "transformer.", "model."]
    LORA_PREFIXES = [
        "lora_unet_", "lora_transformer_", "lora_te1_", "lora_te2_", "lora_te_",
        "lycoris_", "diffusion_model.", "transformer.", "unet."
    ]

    def extract_core_layer_base(key: str) -> str:
        """Strip the .weight/.bias suffix and the first matching base prefix."""
        result = key
        if result.endswith(".weight"):
            result = result[:-7]
        elif result.endswith(".bias"):
            result = result[:-5]
        for prefix in BASE_PREFIXES:
            if result.startswith(prefix):
                result = result[len(prefix):]
                break
        return result

    def extract_core_layer_lora(block_name: str) -> str:
        """Strip the first matching LoRA prefix and convert dots to underscores."""
        result = block_name
        for prefix in LORA_PREFIXES:
            if result.startswith(prefix):
                result = result[len(prefix):]
                break
        return result.replace(".", "_")

    use_pinned = device == 'cuda'

    def load_fp32(handler, key):
        """Read one tensor from handler and move it to `device` as float32."""
        cpu_tensor = handler.get_tensor(key)
        if use_pinned:
            return transfer_to_gpu_pinned(cpu_tensor, device, torch.float32)
        return cpu_tensor.to(device=device, dtype=torch.float32)

    def to_save_format(tensor):
        """Return a contiguous CPU tensor; only floating tensors are cast to save_dtype."""
        if tensor.is_floating_point():
            tensor = tensor.to(save_dtype)
        return tensor.cpu().contiguous()

    def compute_delta(info, block_keys):
        """Return (delta, effective_scale) for one LoRA block, or None if it is incomplete."""
        handler = info["handler"]

        if info["format_info"].get("is_full_diff", False):
            # Full diff format: the stored tensor already is the weight delta
            if "diff" not in block_keys:
                return None
            return load_fp32(handler, block_keys["diff"]), info["weight"]

        # Standard LoRA format: delta = up @ down, scaled by alpha / rank
        if "down" not in block_keys or "up" not in block_keys:
            return None
        lora_down = load_fp32(handler, block_keys["down"])
        lora_up = load_fp32(handler, block_keys["up"])

        # Use this layer's own rank: LoRAs may carry a different rank per
        # layer, so a single network-wide dim would mis-scale those layers.
        rank = lora_down.shape[0]
        if "alpha" in block_keys and rank > 0:
            layer_alpha = float(handler.get_tensor(block_keys["alpha"]).item())
            layer_scale = layer_alpha / rank
        else:
            # No alpha stored means alpha == rank, i.e. a scale of 1.0
            layer_scale = 1.0

        if lora_down.dim() == 4:
            # Conv layer: flatten to matrices, multiply, restore the kernel
            # shape (height and width kept separate for non-square kernels).
            rank_dim, in_size, kernel_h, kernel_w = lora_down.shape
            out_size = lora_up.shape[0]
            delta = lora_up.reshape(out_size, -1) @ lora_down.reshape(rank_dim, -1)
            delta = delta.reshape(out_size, in_size, kernel_h, kernel_w)
        else:
            delta = lora_up @ lora_down
        return delta, layer_scale * info["weight"]

    # Handlers are registered as they are opened so that the finally block
    # closes exactly those that exist, even if a later open fails.
    open_handlers = []

    try:
        base_handler = MemoryEfficientSafeOpen(base_model_path)
        open_handlers.append(base_handler)
        lora_handlers = []
        for lp in lora_paths:
            handler = MemoryEfficientSafeOpen(lp)
            open_handlers.append(handler)
            lora_handlers.append(handler)

        # Detect format and extract pairs for each LoRA
        lora_infos = []
        for i, (handler, weight) in enumerate(zip(lora_handlers, lora_weights)):
            keys = handler.keys()
            format_info = detect_lora_format(keys)
            pairs = extract_lora_pairs(keys, format_info)
            network_dim, network_alpha = detect_lora_rank(handler, pairs)
            lora_infos.append({
                "handler": handler,
                "format_info": format_info,
                "pairs": pairs,
                "network_dim": network_dim,
                "network_alpha": network_alpha,
                "weight": weight,
            })
            if verbose:
                print(f"[LoRA Merge To Model] LoRA {i + 1}: {format_info['format']}, {len(pairs)} layers, dim={network_dim}")

        # Build LoRA lookup: core layer name (underscored) -> list of (info, block_keys)
        lora_lookup = {}
        for info in lora_infos:
            for block_name, block_keys in info["pairs"].items():
                core = extract_core_layer_lora(block_name)
                lora_lookup.setdefault(core, []).append((info, block_keys))

        # Compile skip patterns
        skip_patterns = _compile_patterns(skip_patterns_str)

        # Preserve metadata from base model
        base_metadata = dict(base_handler.metadata() or {})
        base_metadata["merge_comment"] = f"Merged {len(lora_paths)} LoRAs with weights: {lora_weights}"

        output_sd = {}
        stats = {"merged": 0, "copied": 0, "skipped": 0}
        base_keys = list(base_handler.keys())
        pbar = comfy.utils.ProgressBar(len(base_keys))

        if verbose:
            print(f"[LoRA Merge To Model] Processing {len(base_keys)} base model keys...")

        with torch.no_grad():
            for base_key in tqdm(base_keys, desc="Merging to model", unit="keys"):
                cpu_base = base_handler.get_tensor(base_key)

                # Skipped keys stay in the model, just without any LoRA applied
                if _matches_any_pattern(base_key, skip_patterns):
                    output_sd[base_key] = to_save_format(cpu_base)
                    stats["skipped"] += 1
                    pbar.update(1)
                    continue

                # Only weight tensors are candidates for LoRA merging
                contributors = None
                if base_key.endswith(".weight"):
                    core_underscored = extract_core_layer_base(base_key).replace(".", "_")
                    contributors = lora_lookup.get(core_underscored)

                if not contributors:
                    # Bias, norm, or a weight no LoRA touches: copy as-is
                    output_sd[base_key] = to_save_format(cpu_base)
                    stats["copied"] += 1
                    pbar.update(1)
                    continue

                # Accumulate in float32 on the processing device
                if use_pinned:
                    base_weight = transfer_to_gpu_pinned(cpu_base, device, torch.float32)
                else:
                    base_weight = cpu_base.to(device=device, dtype=torch.float32)
                del cpu_base

                for info, block_keys in contributors:
                    result = compute_delta(info, block_keys)
                    if result is None:
                        continue
                    delta, effective_scale = result

                    # Same element count but different layout (e.g. a 2-D
                    # delta for a 1x1 conv weight): reshape so the addition
                    # is element-wise instead of broadcasting to a wrong shape.
                    if delta.shape != base_weight.shape and delta.numel() == base_weight.numel():
                        delta = delta.reshape(base_weight.shape)

                    base_weight = base_weight + effective_scale * delta
                    del delta

                output_sd[base_key] = to_save_format(base_weight)
                del base_weight
                stats["merged"] += 1
                pbar.update(1)

        if verbose:
            print(f"[LoRA Merge To Model] Done: {stats['merged']} merged, {stats['copied']} copied, {stats['skipped']} skipped")

        # Save to base model directory
        os.makedirs(base_dir, exist_ok=True)
        save_file(output_sd, output_path, base_metadata)
        print(f"[LoRA Merge To Model] Saved to {output_path}")

        return output_path

    finally:
        for handler in open_handlers:
            handler.__exit__(None, None, None)
        cleanup_after_operation()
1798+
1799+
class LoRAMergeToModel(io.ComfyNode):
    """Node that merges up to four LoRAs into a base model and saves the full model."""

    @classmethod
    def define_schema(cls):
        """Describe the node: base model, LoRA count, four LoRA/weight slots and save settings."""
        node_inputs = [
            io.Combo.Input("base_model", options=folder_paths.get_filename_list("diffusion_models"),
                           tooltip="Base model the LoRAs were trained on"),
            io.Combo.Input("lora_count", options=["1", "2", "3", "4"], default="2",
                           tooltip="Number of LoRAs to merge"),
        ]

        # One (lora_N, weight_N) pair per slot. Slot 1 is mandatory, so it has
        # no "None" choice; slots 2-4 default to "None".
        for slot, ordinal in enumerate(("First", "Second", "Third", "Fourth"), start=1):
            lora_files = folder_paths.get_filename_list("loras")
            if slot == 1:
                lora_input = io.Combo.Input("lora_1", options=lora_files,
                                            tooltip="First LoRA")
            else:
                lora_input = io.Combo.Input(f"lora_{slot}", options=["None"] + lora_files,
                                            default="None", tooltip=f"{ordinal} LoRA")
            node_inputs.append(lora_input)
            node_inputs.append(
                io.Float.Input(f"weight_{slot}", default=1.0, min=0.0, max=2.0, step=0.05,
                               tooltip=f"Weight strength for LoRA {slot}")
            )

        # Settings
        node_inputs.append(io.String.Input("skip_patterns", default="", multiline=True,
                                           tooltip="Regex patterns for layers to skip"))
        node_inputs.append(io.String.Input("output_filename", default="merged_model"))
        node_inputs.append(io.Combo.Input("save_dtype", options=["fp16", "bf16", "fp32"], default="fp16"))
        node_inputs.append(io.Combo.Input("device", options=["cuda", "cpu"], default="cuda"))

        return io.Schema(
            node_id="LoRAMergeToModel",
            display_name="LoRA Merge To Model",
            category="ModelUtils/LoRA/Merge",
            description="Merge 1-4 LoRAs into a base model and save the result. Saves to base model directory.",
            inputs=node_inputs,
            outputs=[io.String.Output(display_name="output_path")],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, base_model, lora_count,
                lora_1, weight_1, lora_2, weight_2, lora_3, weight_3, lora_4, weight_4,
                skip_patterns, output_filename, save_dtype, device) -> io.NodeOutput:
        """Resolve the selected files, run merge_loras_to_model and return the saved path.

        Only the first ``lora_count`` slots are considered, and slots set to
        "None" are ignored. Raises ValueError when no LoRA remains.
        """
        slots = (
            (lora_1, weight_1),
            (lora_2, weight_2),
            (lora_3, weight_3),
            (lora_4, weight_4),
        )

        # Keep the first `lora_count` slots, then drop the unset ones
        selected = []
        for name, strength in slots[:int(lora_count)]:
            if name != "None":
                selected.append((name, strength))
        if not selected:
            raise ValueError("At least one LoRA must be selected")

        resolved_loras = [
            folder_paths.get_full_path_or_raise("loras", name) for name, _ in selected
        ]
        strengths = [strength for _, strength in selected]
        base_path = folder_paths.get_full_path_or_raise("diffusion_models", base_model)

        dtype_by_name = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        dtype = dtype_by_name[save_dtype]

        saved_path = merge_loras_to_model(
            lora_paths=resolved_loras,
            lora_weights=strengths,
            base_model_path=base_path,
            device=device,
            save_dtype=dtype,
            output_filename=output_filename,
            skip_patterns_str=skip_patterns,
        )
        return io.NodeOutput(saved_path)
0 commit comments