Skip to content

Commit ce8c6aa

Browse files
committed
feat: Add LoRAMergeToModel node for merging LoRAs directly into model
- New merge_loras_to_model() function that merges LoRAs into the base model weights and saves the result without extraction
- New LoRAMergeToModel node with simplified inputs (no extraction params)
- Outputs the merged model to the base model's directory
1 parent 6f3fd09 commit ce8c6aa

2 files changed

Lines changed: 314 additions & 1 deletion

File tree

__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
LoRAResizeFrobenius, LoRAResizeCumulative,
2929
LoRAResizeViaBaseFixed, LoRAResizeViaBaseRatio,
3030
LoRAResizeViaBaseFrobenius, LoRAResizeViaBaseCumulative,
31-
LoRAMultiMerge
31+
LoRAMultiMerge, LoRAMergeToModel
3232
)
3333

3434

@@ -62,6 +62,8 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
6262
LoRAResizeViaBaseFrobenius, LoRAResizeViaBaseCumulative,
6363
# LoRA Multi-Merge
6464
LoRAMultiMerge,
65+
# LoRA Merge To Model
66+
LoRAMergeToModel,
6567
]
6668

6769

nodes/lora_resize.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,3 +1562,314 @@ def execute(cls, base_model, lora_count,
15621562
skip_patterns_str=skip_patterns,
15631563
)
15641564
return io.NodeOutput(path)
1565+
1566+
1567+
# =============================================================================
1568+
# LoRA Merge To Model (Save merged model, skip extraction)
1569+
# =============================================================================
1570+
1571+
def merge_loras_to_model(
1572+
lora_paths: List[str],
1573+
lora_weights: List[float],
1574+
base_model_path: str,
1575+
device: str,
1576+
save_dtype: torch.dtype,
1577+
output_filename: str,
1578+
skip_patterns_str: str = "",
1579+
verbose: bool = True,
1580+
) -> str:
1581+
"""
1582+
Merge multiple LoRAs into a base model and save the result directly.
1583+
1584+
Unlike merge_multi_loras_via_base, this function:
1585+
- Does NOT extract the result back to a LoRA
1586+
- Saves the merged full model to the base model's directory
1587+
1588+
Args:
1589+
lora_paths: List of paths to LoRA files
1590+
lora_weights: List of weight strengths (0.0-2.0) for each LoRA
1591+
base_model_path: Path to base model
1592+
device: Processing device
1593+
save_dtype: Output dtype
1594+
output_filename: Output filename (without extension)
1595+
skip_patterns_str: Regex patterns for layers to skip
1596+
verbose: Print progress info
1597+
1598+
Returns:
1599+
Path to saved merged model
1600+
"""
1601+
# Estimate memory and prepare
1602+
total_size_gb = estimate_model_size(base_model_path)
1603+
for lp in lora_paths:
1604+
total_size_gb += estimate_model_size(lp)
1605+
1606+
if verbose:
1607+
print(f"[LoRA Merge To Model] Preparing memory for {total_size_gb:.2f}GB operation...")
1608+
print(f"[LoRA Merge To Model] Merging {len(lora_paths)} LoRAs with weights: {lora_weights}")
1609+
prepare_for_large_operation(total_size_gb * 1.5, torch.device(device))
1610+
1611+
# Open all files
1612+
base_handler = MemoryEfficientSafeOpen(base_model_path)
1613+
lora_handlers = [MemoryEfficientSafeOpen(lp) for lp in lora_paths]
1614+
1615+
try:
1616+
# Detect format and extract pairs for each LoRA
1617+
lora_infos = []
1618+
1619+
for i, handler in enumerate(lora_handlers):
1620+
keys = handler.keys()
1621+
format_info = detect_lora_format(keys)
1622+
pairs = extract_lora_pairs(keys, format_info)
1623+
network_dim, network_alpha = detect_lora_rank(handler, pairs)
1624+
lora_infos.append({
1625+
"handler": handler,
1626+
"format_info": format_info,
1627+
"pairs": pairs,
1628+
"network_dim": network_dim,
1629+
"network_alpha": network_alpha,
1630+
"weight": lora_weights[i],
1631+
})
1632+
if verbose:
1633+
print(f"[LoRA Merge To Model] LoRA {i+1}: {format_info['format']}, {len(pairs)} layers, dim={network_dim}")
1634+
1635+
# Common prefixes
1636+
BASE_PREFIXES = ["model.diffusion_model.", "diffusion_model.", "transformer.", "model."]
1637+
LORA_PREFIXES = [
1638+
"lora_unet_", "lora_transformer_", "lora_te1_", "lora_te2_", "lora_te_",
1639+
"lycoris_", "diffusion_model.", "transformer.", "unet."
1640+
]
1641+
1642+
def extract_core_layer_base(key: str) -> str:
1643+
result = key
1644+
if result.endswith(".weight"):
1645+
result = result[:-7]
1646+
elif result.endswith(".bias"):
1647+
result = result[:-5]
1648+
for prefix in BASE_PREFIXES:
1649+
if result.startswith(prefix):
1650+
result = result[len(prefix):]
1651+
break
1652+
return result
1653+
1654+
def extract_core_layer_lora(block_name: str) -> str:
1655+
result = block_name
1656+
for prefix in LORA_PREFIXES:
1657+
if result.startswith(prefix):
1658+
result = result[len(prefix):]
1659+
break
1660+
return result.replace(".", "_")
1661+
1662+
# Build LoRA lookup: core layer name (underscored) -> list of (info, block_keys)
1663+
lora_lookup = {}
1664+
for info in lora_infos:
1665+
for block_name, block_keys in info["pairs"].items():
1666+
core = extract_core_layer_lora(block_name)
1667+
if core not in lora_lookup:
1668+
lora_lookup[core] = []
1669+
lora_lookup[core].append((info, block_keys))
1670+
1671+
# Compile skip patterns
1672+
skip_patterns = _compile_patterns(skip_patterns_str)
1673+
1674+
# Preserve metadata from base model
1675+
base_metadata = base_handler.metadata().copy() if base_handler.metadata() else {}
1676+
base_metadata["merge_comment"] = f"Merged {len(lora_paths)} LoRAs with weights: {lora_weights}"
1677+
1678+
output_sd = {}
1679+
stats = {"merged": 0, "copied": 0, "skipped": 0}
1680+
base_keys = list(base_handler.keys())
1681+
pbar = comfy.utils.ProgressBar(len(base_keys))
1682+
1683+
if verbose:
1684+
print(f"[LoRA Merge To Model] Processing {len(base_keys)} base model keys...")
1685+
1686+
with torch.no_grad():
1687+
for base_key in tqdm(base_keys, desc="Merging to model", unit="keys"):
1688+
# Check skip patterns
1689+
if _matches_any_pattern(base_key, skip_patterns):
1690+
stats["skipped"] += 1
1691+
pbar.update(1)
1692+
continue
1693+
1694+
# Load base weight
1695+
cpu_base = base_handler.get_tensor(base_key)
1696+
1697+
# Only process weight tensors for LoRA merging
1698+
if base_key.endswith(".weight"):
1699+
core = extract_core_layer_base(base_key)
1700+
core_underscored = core.replace(".", "_")
1701+
1702+
# Check if any LoRA contributes to this layer
1703+
if core_underscored in lora_lookup:
1704+
# Transfer to GPU for computation
1705+
if device == 'cuda':
1706+
base_weight = transfer_to_gpu_pinned(cpu_base, device, torch.float32)
1707+
else:
1708+
base_weight = cpu_base.to(device=device, dtype=torch.float32)
1709+
del cpu_base
1710+
1711+
# Accumulate deltas from all contributing LoRAs
1712+
for info, block_keys in lora_lookup[core_underscored]:
1713+
is_full_diff = info["format_info"].get("is_full_diff", False)
1714+
1715+
if is_full_diff:
1716+
# Full diff format
1717+
if "diff" not in block_keys:
1718+
continue
1719+
cpu_diff = info["handler"].get_tensor(block_keys["diff"])
1720+
if device == 'cuda':
1721+
delta = transfer_to_gpu_pinned(cpu_diff, device, torch.float32)
1722+
else:
1723+
delta = cpu_diff.to(device=device, dtype=torch.float32)
1724+
del cpu_diff
1725+
effective_scale = info["weight"]
1726+
else:
1727+
# Standard LoRA format
1728+
if "down" not in block_keys or "up" not in block_keys:
1729+
continue
1730+
1731+
cpu_down = info["handler"].get_tensor(block_keys["down"])
1732+
cpu_up = info["handler"].get_tensor(block_keys["up"])
1733+
if device == 'cuda':
1734+
lora_down = transfer_to_gpu_pinned(cpu_down, device, torch.float32)
1735+
lora_up = transfer_to_gpu_pinned(cpu_up, device, torch.float32)
1736+
else:
1737+
lora_down = cpu_down.to(device=device, dtype=torch.float32)
1738+
lora_up = cpu_up.to(device=device, dtype=torch.float32)
1739+
del cpu_down, cpu_up
1740+
1741+
# Get alpha
1742+
if "alpha" in block_keys:
1743+
alpha_tensor = info["handler"].get_tensor(block_keys["alpha"])
1744+
layer_alpha = float(alpha_tensor.item())
1745+
else:
1746+
layer_alpha = float(info["network_dim"])
1747+
layer_scale = layer_alpha / info["network_dim"] if info["network_dim"] > 0 else 1.0
1748+
effective_scale = layer_scale * info["weight"]
1749+
1750+
# Compute delta
1751+
is_conv = len(lora_down.shape) == 4
1752+
if is_conv:
1753+
in_rank, in_size, kernel_size, k_ = lora_down.shape
1754+
out_size, out_rank, _, _ = lora_up.shape
1755+
delta = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
1756+
delta = delta.reshape(out_size, in_size, kernel_size, kernel_size)
1757+
else:
1758+
delta = lora_up @ lora_down
1759+
del lora_down, lora_up
1760+
1761+
# Apply delta to base weight
1762+
base_weight = base_weight + effective_scale * delta
1763+
del delta
1764+
1765+
# Store merged weight
1766+
output_sd[base_key] = base_weight.to(save_dtype).cpu().contiguous()
1767+
del base_weight
1768+
stats["merged"] += 1
1769+
else:
1770+
# No LoRA contribution, copy as-is
1771+
output_sd[base_key] = cpu_base.to(save_dtype).contiguous()
1772+
stats["copied"] += 1
1773+
else:
1774+
# Non-weight tensor (bias, norm, etc.), copy as-is
1775+
output_sd[base_key] = cpu_base.to(save_dtype).contiguous()
1776+
stats["copied"] += 1
1777+
1778+
pbar.update(1)
1779+
1780+
if verbose:
1781+
print(f"[LoRA Merge To Model] Done: {stats['merged']} merged, {stats['copied']} copied, {stats['skipped']} skipped")
1782+
1783+
# Save to base model directory
1784+
base_dir = os.path.dirname(base_model_path)
1785+
os.makedirs(base_dir, exist_ok=True)
1786+
output_path = os.path.join(base_dir, f"{output_filename.strip()}.safetensors")
1787+
1788+
save_file(output_sd, output_path, base_metadata)
1789+
print(f"[LoRA Merge To Model] Saved to {output_path}")
1790+
1791+
return output_path
1792+
1793+
finally:
1794+
base_handler.__exit__(None, None, None)
1795+
for handler in lora_handlers:
1796+
handler.__exit__(None, None, None)
1797+
cleanup_after_operation()
1798+
1799+
1800+
class LoRAMergeToModel(io.ComfyNode):
1801+
"""Merge multiple LoRAs into base model and save as full model."""
1802+
1803+
@classmethod
1804+
def define_schema(cls):
1805+
return io.Schema(
1806+
node_id="LoRAMergeToModel",
1807+
display_name="LoRA Merge To Model",
1808+
category="ModelUtils/LoRA/Merge",
1809+
description="Merge 1-4 LoRAs into a base model and save the result. Saves to base model directory.",
1810+
inputs=[
1811+
io.Combo.Input("base_model", options=folder_paths.get_filename_list("diffusion_models"),
1812+
tooltip="Base model the LoRAs were trained on"),
1813+
io.Combo.Input("lora_count", options=["1", "2", "3", "4"], default="2",
1814+
tooltip="Number of LoRAs to merge"),
1815+
# LoRA 1
1816+
io.Combo.Input("lora_1", options=folder_paths.get_filename_list("loras"),
1817+
tooltip="First LoRA"),
1818+
io.Float.Input("weight_1", default=1.0, min=0.0, max=2.0, step=0.05,
1819+
tooltip="Weight strength for LoRA 1"),
1820+
# LoRA 2
1821+
io.Combo.Input("lora_2", options=["None"] + folder_paths.get_filename_list("loras"),
1822+
default="None", tooltip="Second LoRA"),
1823+
io.Float.Input("weight_2", default=1.0, min=0.0, max=2.0, step=0.05,
1824+
tooltip="Weight strength for LoRA 2"),
1825+
# LoRA 3
1826+
io.Combo.Input("lora_3", options=["None"] + folder_paths.get_filename_list("loras"),
1827+
default="None", tooltip="Third LoRA"),
1828+
io.Float.Input("weight_3", default=1.0, min=0.0, max=2.0, step=0.05,
1829+
tooltip="Weight strength for LoRA 3"),
1830+
# LoRA 4
1831+
io.Combo.Input("lora_4", options=["None"] + folder_paths.get_filename_list("loras"),
1832+
default="None", tooltip="Fourth LoRA"),
1833+
io.Float.Input("weight_4", default=1.0, min=0.0, max=2.0, step=0.05,
1834+
tooltip="Weight strength for LoRA 4"),
1835+
# Settings
1836+
io.String.Input("skip_patterns", default="", multiline=True,
1837+
tooltip="Regex patterns for layers to skip"),
1838+
io.String.Input("output_filename", default="merged_model"),
1839+
io.Combo.Input("save_dtype", options=["fp16", "bf16", "fp32"], default="fp16"),
1840+
io.Combo.Input("device", options=["cuda", "cpu"], default="cuda"),
1841+
],
1842+
outputs=[io.String.Output(display_name="output_path")],
1843+
is_output_node=True,
1844+
)
1845+
1846+
@classmethod
1847+
def execute(cls, base_model, lora_count,
1848+
lora_1, weight_1, lora_2, weight_2, lora_3, weight_3, lora_4, weight_4,
1849+
skip_patterns, output_filename, save_dtype, device) -> io.NodeOutput:
1850+
1851+
# Build LoRA list based on count
1852+
count = int(lora_count)
1853+
lora_names = [lora_1, lora_2, lora_3, lora_4][:count]
1854+
lora_weights = [weight_1, weight_2, weight_3, weight_4][:count]
1855+
1856+
# Filter out "None" entries
1857+
valid_loras = [(name, weight) for name, weight in zip(lora_names, lora_weights) if name != "None"]
1858+
if not valid_loras:
1859+
raise ValueError("At least one LoRA must be selected")
1860+
1861+
lora_names, lora_weights = zip(*valid_loras)
1862+
lora_paths = [folder_paths.get_full_path_or_raise("loras", name) for name in lora_names]
1863+
base_path = folder_paths.get_full_path_or_raise("diffusion_models", base_model)
1864+
dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[save_dtype]
1865+
1866+
path = merge_loras_to_model(
1867+
lora_paths=list(lora_paths),
1868+
lora_weights=list(lora_weights),
1869+
base_model_path=base_path,
1870+
device=device,
1871+
save_dtype=dtype,
1872+
output_filename=output_filename,
1873+
skip_patterns_str=skip_patterns,
1874+
)
1875+
return io.NodeOutput(path)

0 commit comments

Comments
 (0)