Skip to content

Commit 21c0972

Browse files
committed
Refine LoRA extraction to output accurate layer formats across the board and add experimental new DARE+Ties merge operation variant
1 parent 3bcd2bf commit 21c0972

3 files changed

Lines changed: 90 additions & 63 deletions

File tree

nodes/lora_extract_svd.py

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Native implementation of Low-Rank Adapter extraction with multiple rank selection modes.
55
No external dependencies required.
66
"""
7+
import fnmatch
78
import os
89
import re
910
import torch
@@ -399,23 +400,38 @@ def _extract_chunked_layer(
399400
# Pattern Matching
400401
# =============================================================================
401402

402-
def _compile_patterns(pattern_string: str) -> list:
403-
"""Compile regex patterns from whitespace-separated string."""
403+
def _compile_patterns(pattern_string: str, glob_mode: bool = False) -> list:
    """Split a whitespace-separated pattern string into usable patterns.

    Args:
        pattern_string: Patterns separated by any whitespace (spaces, tabs,
            newlines). ``None``, empty, or whitespace-only input yields ``[]``.
        glob_mode: When True the tokens are returned unchanged as plain
            strings (matching is left to fnmatch). When False (default) each
            token is compiled with ``re.compile``.

    Returns:
        A list of ``str`` (glob mode) or compiled ``re.Pattern`` objects
        (regex mode), in the order the tokens appeared.

    Invalid regular expressions are skipped rather than raised, so one bad
    token does not abort the whole run; each skipped token is reported with a
    warning so the user can tell that the pattern is not being applied.
    """
    if not pattern_string or not pattern_string.strip():
        return []

    # str.split() with no argument already trims and never yields empty
    # tokens, so no per-token strip/emptiness check is needed.
    tokens = pattern_string.split()
    if glob_mode:
        return tokens

    compiled = []
    for token in tokens:
        try:
            compiled.append(re.compile(token))
        except re.error as exc:
            import logging

            logging.getLogger(__name__).warning(
                "Ignoring invalid regex pattern %r: %s", token, exc
            )
    return compiled
415424

416425

417-
def _matches_any_pattern(key: str, patterns: list) -> bool:
418-
"""Check if key matches any pattern."""
426+
def _matches_any_pattern(key: str, patterns: list, glob_mode: bool = False) -> bool:
    """Return True if ``key`` matches at least one of ``patterns``.

    Args:
        key: The string to test (e.g. a state-dict layer name).
        patterns: Plain strings in glob mode, compiled ``re.Pattern`` objects
            in regex mode. An empty list never matches.
        glob_mode: Selects how ``patterns`` are interpreted.

    Glob mode: each pattern is wrapped as ``*pattern*`` so it matches anywhere
    inside ``key``. Dots are literal; ``*``, ``?`` and ``[...]`` are fnmatch
    wildcards. Matching uses ``fnmatch.fnmatchcase`` so it is case-sensitive
    on every platform (``fnmatch.fnmatch`` normalizes case on Windows).

    Regex mode: each pattern is matched with ``search()``, i.e. as a substring
    match unless the pattern anchors itself.
    """
    if glob_mode:
        return any(fnmatch.fnmatchcase(key, f"*{p}*") for p in patterns)
    return any(p.search(key) for p in patterns)
420436

421437

@@ -441,6 +457,7 @@ def extract_lora_from_files(
441457
svd_niter: int = 2,
442458
lazy_load: bool = True,
443459
force_clear_cache: bool = True,
460+
glob_skip_patterns: bool = False,
444461
) -> dict[str, torch.Tensor]:
445462
"""
446463
Extract LoRA from difference between two models.
@@ -460,6 +477,8 @@ def extract_lora_from_files(
460477
skip_patterns_str: Patterns for layers to skip
461478
mismatch_mode: How to handle mismatches
462479
chunk_large_layers: Enable chunked extraction
480+
glob_skip_patterns: When True, treat skip_patterns as glob (* wildcard).
481+
When False (default), treat as Python regex.
463482
"""
464483
output_sd = {}
465484

@@ -469,7 +488,7 @@ def extract_lora_from_files(
469488
"bf16": torch.bfloat16,
470489
}.get(save_dtype, torch.float16)
471490

472-
skip_patterns = _compile_patterns(skip_patterns_str)
491+
skip_patterns = _compile_patterns(skip_patterns_str, glob_mode=glob_skip_patterns)
473492

474493
# Prepare memory before heavy operation
475494
total_size_gb = estimate_model_size(model_a_path) + estimate_model_size(model_b_path)
@@ -489,7 +508,7 @@ def extract_lora_from_files(
489508
def _process_layer(key):
490509
lora_name = _format_lora_key(key)
491510

492-
if _matches_any_pattern(key, skip_patterns):
511+
if _matches_any_pattern(key, skip_patterns, glob_mode=glob_skip_patterns):
493512
return "skipped", None
494513

495514
# Load tensors with pinned memory for CUDA
@@ -570,9 +589,8 @@ def _process_layer(key):
570589
weight_diff, num_chunks, mode, linear_param, device, linear_max_rank
571590
)
572591
if lora_up is not None:
573-
layer_results[f"{lora_name}.lora_up.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
574-
layer_results[f"{lora_name}.lora_down.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
575-
layer_results[f"{lora_name}.alpha"] = torch.tensor(rank).to(save_torch_dtype)
592+
layer_results[f"{lora_name}.lora_B.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
593+
layer_results[f"{lora_name}.lora_A.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
576594
del weight_diff
577595
return "chunked", layer_results
578596

@@ -586,9 +604,8 @@ def _process_layer(key):
586604
status = "full"
587605
else:
588606
lora_down, lora_up, _ = result
589-
layer_results[f"{lora_name}.lora_down.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
590-
layer_results[f"{lora_name}.lora_up.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
591-
layer_results[f"{lora_name}.alpha"] = torch.tensor(lora_down.shape[0]).to(save_torch_dtype)
607+
layer_results[f"{lora_name}.lora_A.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
608+
layer_results[f"{lora_name}.lora_B.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
592609
status = "extracted"
593610

594611
del weight_diff
@@ -686,46 +703,48 @@ def _reconstruct_dots(key: str) -> str:
686703
def _format_lora_key(key: str) -> str:
    """Format a model weight key as a LoRA layer name.

    The result always carries the 'diffusion_model.' prefix. Per the original
    author's note, ComfyUI maps 'diffusion_model.<layer>' generically for all
    model types, while the 'transformer.' prefix only works for a subset of
    model types, so 'transformer.' is intentionally never emitted.

    Rules, applied after dropping a trailing '.weight':
      * 'diffusion_model.*'       -> returned unchanged
      * 'lora_unet_*'             -> legacy name; the remainder is converted
                                     back to dotted form by _reconstruct_dots
      * 'model.diffusion_model.*' -> wrapper replaced by 'diffusion_model.'
      * 'net.*'                   -> wrapper replaced by 'diffusion_model.'
      * 'transformer.*'           -> wrapper replaced by 'diffusion_model.'
      * anything else (bare Diffusers transformer/UNet block names included)
                                  -> 'diffusion_model.' is prepended

    Args:
        key: State-dict key, with or without the '.weight' suffix.

    Returns:
        The LoRA layer name without any '.weight' suffix.
    """
    key = key.removesuffix(".weight")

    # Already in the target format.
    if key.startswith("diffusion_model."):
        return key

    # Legacy lora_unet_ prefix (for resizing without base).
    if key.startswith("lora_unet_"):
        dotted = _reconstruct_dots(key.removeprefix("lora_unet_"))
        return f"diffusion_model.{dotted}"

    # Wrapper prefixes that are swapped for 'diffusion_model.'. They are
    # mutually exclusive, so the iteration order does not affect the result.
    for wrapper in ("model.diffusion_model.", "net.", "transformer."):
        if key.startswith(wrapper):
            return f"diffusion_model.{key[len(wrapper):]}"

    # Bare Diffusers keys (transformer_blocks.*, down_blocks.*, conv_in, ...)
    # and anything unrecognized simply gain the prefix.
    return f"diffusion_model.{key}"
730749

731750

@@ -744,7 +763,10 @@ def _get_common_inputs():
744763
io.Combo.Input("save_dtype", options=["fp16", "bf16", "fp32"], default="fp16"),
745764
io.Combo.Input("device", options=["cuda", "cpu"], default="cuda"),
746765
io.String.Input("skip_patterns", default="", multiline=True,
747-
tooltip="Regex patterns for layers to skip"),
766+
tooltip="Patterns for layers to skip (regex or glob depending on glob_skip_patterns)"),
767+
io.Boolean.Input("glob_skip_patterns", default=False,
768+
tooltip="When True, skip_patterns use glob syntax (* = any sequence, ? = any char, dots are literal). "
769+
"When False (default), patterns are Python regex matched as substrings."),
748770
]
749771

750772

@@ -775,7 +797,7 @@ def define_schema(cls):
775797
@classmethod
776798
def execute(cls, model_a, model_b, linear_dim, conv_dim, svd_niter, chunk_large_layers,
777799
clamp_quantile, min_diff, mismatch_mode, output_filename,
778-
save_dtype, device, skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
800+
save_dtype, device, skip_patterns, glob_skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
779801

780802
model_a_path = folder_paths.get_full_path_or_raise("diffusion_models", model_a)
781803
model_b_path = folder_paths.get_full_path_or_raise("diffusion_models", model_b)
@@ -784,7 +806,7 @@ def execute(cls, model_a, model_b, linear_dim, conv_dim, svd_niter, chunk_large_
784806
model_a_path, model_b_path, "fixed", linear_dim, conv_dim,
785807
device, save_dtype, linear_dim, conv_dim,
786808
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers, svd_niter,
787-
lazy_load, force_clear_cache
809+
lazy_load, force_clear_cache, glob_skip_patterns
788810
)
789811

790812
return io.NodeOutput(_save_lora(output_sd, output_filename))
@@ -819,7 +841,7 @@ def define_schema(cls):
819841
@classmethod
820842
def execute(cls, model_a, model_b, linear_ratio, conv_ratio, linear_max_rank, conv_max_rank,
821843
chunk_large_layers, clamp_quantile, min_diff, mismatch_mode, output_filename,
822-
save_dtype, device, skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
844+
save_dtype, device, skip_patterns, glob_skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
823845

824846
model_a_path = folder_paths.get_full_path_or_raise("diffusion_models", model_a)
825847
model_b_path = folder_paths.get_full_path_or_raise("diffusion_models", model_b)
@@ -828,7 +850,8 @@ def execute(cls, model_a, model_b, linear_ratio, conv_ratio, linear_max_rank, co
828850
model_a_path, model_b_path, "ratio", linear_ratio, conv_ratio,
829851
device, save_dtype, linear_max_rank, conv_max_rank,
830852
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
831-
lazy_load=lazy_load, force_clear_cache=force_clear_cache
853+
lazy_load=lazy_load, force_clear_cache=force_clear_cache,
854+
glob_skip_patterns=glob_skip_patterns
832855
)
833856

834857
return io.NodeOutput(_save_lora(output_sd, output_filename))
@@ -863,7 +886,7 @@ def define_schema(cls):
863886
@classmethod
864887
def execute(cls, model_a, model_b, linear_quantile, conv_quantile, linear_max_rank, conv_max_rank,
865888
chunk_large_layers, clamp_quantile, min_diff, mismatch_mode, output_filename,
866-
save_dtype, device, skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
889+
save_dtype, device, skip_patterns, glob_skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
867890

868891
model_a_path = folder_paths.get_full_path_or_raise("diffusion_models", model_a)
869892
model_b_path = folder_paths.get_full_path_or_raise("diffusion_models", model_b)
@@ -872,7 +895,8 @@ def execute(cls, model_a, model_b, linear_quantile, conv_quantile, linear_max_ra
872895
model_a_path, model_b_path, "quantile", linear_quantile, conv_quantile,
873896
device, save_dtype, linear_max_rank, conv_max_rank,
874897
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
875-
lazy_load=lazy_load, force_clear_cache=force_clear_cache
898+
lazy_load=lazy_load, force_clear_cache=force_clear_cache,
899+
glob_skip_patterns=glob_skip_patterns
876900
)
877901

878902
return io.NodeOutput(_save_lora(output_sd, output_filename))
@@ -905,7 +929,7 @@ def define_schema(cls):
905929
@classmethod
906930
def execute(cls, model_a, model_b, knee_method, linear_max_rank, conv_max_rank,
907931
chunk_large_layers, clamp_quantile, min_diff, mismatch_mode, output_filename,
908-
save_dtype, device, skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
932+
save_dtype, device, skip_patterns, glob_skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
909933

910934
model_a_path = folder_paths.get_full_path_or_raise("diffusion_models", model_a)
911935
model_b_path = folder_paths.get_full_path_or_raise("diffusion_models", model_b)
@@ -914,7 +938,8 @@ def execute(cls, model_a, model_b, knee_method, linear_max_rank, conv_max_rank,
914938
model_a_path, model_b_path, knee_method, 0, 0,
915939
device, save_dtype, linear_max_rank, conv_max_rank,
916940
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
917-
lazy_load=lazy_load, force_clear_cache=force_clear_cache
941+
lazy_load=lazy_load, force_clear_cache=force_clear_cache,
942+
glob_skip_patterns=glob_skip_patterns
918943
)
919944

920945
return io.NodeOutput(_save_lora(output_sd, output_filename))
@@ -949,7 +974,7 @@ def define_schema(cls):
949974
@classmethod
950975
def execute(cls, model_a, model_b, linear_target, conv_target, linear_max_rank, conv_max_rank,
951976
chunk_large_layers, clamp_quantile, min_diff, mismatch_mode, output_filename,
952-
save_dtype, device, skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
977+
save_dtype, device, skip_patterns, glob_skip_patterns, lazy_load, force_clear_cache) -> io.NodeOutput:
953978

954979
model_a_path = folder_paths.get_full_path_or_raise("diffusion_models", model_a)
955980
model_b_path = folder_paths.get_full_path_or_raise("diffusion_models", model_b)
@@ -958,7 +983,8 @@ def execute(cls, model_a, model_b, linear_target, conv_target, linear_max_rank,
958983
model_a_path, model_b_path, "sv_fro", linear_target, conv_target,
959984
device, save_dtype, linear_max_rank, conv_max_rank,
960985
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
961-
lazy_load=lazy_load, force_clear_cache=force_clear_cache
986+
lazy_load=lazy_load, force_clear_cache=force_clear_cache,
987+
glob_skip_patterns=glob_skip_patterns
962988
)
963989

964990
return io.NodeOutput(_save_lora(output_sd, output_filename))

nodes/lora_merger.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -735,18 +735,11 @@ def merge_multi_loras_dare_enhanced(
735735
continue
736736

737737
def process_ties_dare_enhanced(tensors, weights, dim_to_pad):
738-
padded = []
738+
processed = []
739739
for t, w in tensors:
740-
if t.shape[dim_to_pad] < max_rank:
741-
padding = [0] * (len(t.shape) * 2)
742-
rev_dim = len(t.shape) - 1 - dim_to_pad
743-
padding[rev_dim*2 + 1] = max_rank - t.shape[dim_to_pad]
744-
t = torch.nn.functional.pad(t, tuple(padding))
745-
padded.append(t * w)
740+
t_val = t * w
746741

747-
# Enhanced DARE
748-
for i in range(len(padded)):
749-
t_val = padded[i]
742+
# Enhanced DARE
750743
abs_t = torch.abs(t_val)
751744
max_val = torch.max(abs_t)
752745
if max_val > 0:
@@ -756,18 +749,26 @@ def process_ties_dare_enhanced(tensors, weights, dim_to_pad):
756749
random_mask = torch.bernoulli(prob, generator=rng)
757750
interpolated_mask = torch.lerp(random_mask, prob, mask_smooth)
758751

759-
padded[i] = (t_val * interpolated_mask) / prob
752+
t_val = t_val * interpolated_mask
760753

761-
# TIES
762-
if trim_quantile > 0:
763-
for i in range(len(padded)):
764-
flat = padded[i].abs().flatten()
754+
# TIES
755+
if trim_quantile > 0:
756+
flat = t_val.abs().flatten()
765757
k = max(1, int(len(flat) * trim_quantile))
766758
if k > 0:
767759
threshold = torch.kthvalue(flat, k).values
768-
padded[i] = torch.where(padded[i].abs() < threshold, torch.zeros_like(padded[i]), padded[i])
760+
t_val = torch.where(t_val.abs() < threshold, torch.zeros_like(t_val), t_val)
761+
762+
# Pad
763+
if t_val.shape[dim_to_pad] < max_rank:
764+
padding = [0] * (len(t_val.shape) * 2)
765+
rev_dim = len(t_val.shape) - 1 - dim_to_pad
766+
padding[rev_dim*2 + 1] = max_rank - t_val.shape[dim_to_pad]
767+
t_val = torch.nn.functional.pad(t_val, tuple(padding))
768+
769+
processed.append(t_val)
769770

770-
stacked = torch.stack(padded)
771+
stacked = torch.stack(processed)
771772
signs = torch.sign(stacked)
772773
sum_signs = signs.sum(dim=0)
773774
dominant_sign = torch.sign(sum_signs)

nodes/merger_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ def oper(self, a, b):
816816
random_mask = torch.bernoulli(prob, generator=rng)
817817
interpolated_mask = torch.lerp(random_mask, prob, self.mask_smooth)
818818

819-
rescaled_delta = (delta_val * interpolated_mask) / prob
819+
rescaled_delta = delta_val * interpolated_mask
820820
return rescaled_delta
821821

822822

@@ -845,7 +845,7 @@ def oper(self, a, b):
845845
random_mask = torch.bernoulli(prob, generator=rng)
846846
interpolated_mask = torch.lerp(random_mask, prob, self.mask_smooth)
847847

848-
delta_val = (delta_val * interpolated_mask) / prob
848+
delta_val = delta_val * interpolated_mask
849849

850850
if self.beta > 0:
851851
flat_abs = delta_val.abs().flatten()

0 commit comments

Comments
 (0)