44Native implementation of Low-Rank Adapter extraction with multiple rank selection modes.
55No external dependencies required.
66"""
7+ import fnmatch
78import os
89import re
910import torch
@@ -399,23 +400,38 @@ def _extract_chunked_layer(
399400# Pattern Matching
400401# =============================================================================
401402
402- def _compile_patterns (pattern_string : str ) -> list :
403- """Compile regex patterns from whitespace-separated string."""
403+ def _compile_patterns (pattern_string : str , glob_mode : bool = False ) -> list :
404+ """Compile patterns from whitespace-separated string.
405+
406+ In regex mode (default) returns compiled re.Pattern objects.
407+ In glob mode returns plain strings; fnmatch handles matching.
408+ """
404409 if not pattern_string or not pattern_string .strip ():
405410 return []
406411 patterns = []
407412 for p in pattern_string .split ():
408413 p = p .strip ()
409- if p :
414+ if not p :
415+ continue
416+ if glob_mode :
417+ patterns .append (p )
418+ else :
410419 try :
411420 patterns .append (re .compile (p ))
412421 except re .error :
413422 pass
414423 return patterns
415424
416425
417- def _matches_any_pattern (key : str , patterns : list ) -> bool :
418- """Check if key matches any pattern."""
426+ def _matches_any_pattern (key : str , patterns : list , glob_mode : bool = False ) -> bool :
427+ """Check if key matches any pattern.
428+
429+ Glob mode: each pattern is matched as a substring glob using fnmatch
430+ (dots are literal, * matches any sequence of characters).
431+ Regex mode: each pattern is a compiled re.Pattern, matched via search().
432+ """
433+ if glob_mode :
434+ return any (fnmatch .fnmatch (key , f"*{ p } *" ) for p in patterns )
419435 return any (p .search (key ) for p in patterns )
420436
421437
@@ -441,6 +457,7 @@ def extract_lora_from_files(
441457 svd_niter : int = 2 ,
442458 lazy_load : bool = True ,
443459 force_clear_cache : bool = True ,
460+ glob_skip_patterns : bool = False ,
444461) -> dict [str , torch .Tensor ]:
445462 """
446463 Extract LoRA from difference between two models.
@@ -460,6 +477,8 @@ def extract_lora_from_files(
460477 skip_patterns_str: Patterns for layers to skip
461478 mismatch_mode: How to handle mismatches
462479 chunk_large_layers: Enable chunked extraction
480+ glob_skip_patterns: When True, treat skip_patterns as glob (* wildcard).
481+ When False (default), treat as Python regex.
463482 """
464483 output_sd = {}
465484
@@ -469,7 +488,7 @@ def extract_lora_from_files(
469488 "bf16" : torch .bfloat16 ,
470489 }.get (save_dtype , torch .float16 )
471490
472- skip_patterns = _compile_patterns (skip_patterns_str )
491+ skip_patterns = _compile_patterns (skip_patterns_str , glob_mode = glob_skip_patterns )
473492
474493 # Prepare memory before heavy operation
475494 total_size_gb = estimate_model_size (model_a_path ) + estimate_model_size (model_b_path )
@@ -489,7 +508,7 @@ def extract_lora_from_files(
489508 def _process_layer (key ):
490509 lora_name = _format_lora_key (key )
491510
492- if _matches_any_pattern (key , skip_patterns ):
511+ if _matches_any_pattern (key , skip_patterns , glob_mode = glob_skip_patterns ):
493512 return "skipped" , None
494513
495514 # Load tensors with pinned memory for CUDA
@@ -570,9 +589,8 @@ def _process_layer(key):
570589 weight_diff , num_chunks , mode , linear_param , device , linear_max_rank
571590 )
572591 if lora_up is not None :
573- layer_results [f"{ lora_name } .lora_up.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
574- layer_results [f"{ lora_name } .lora_down.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
575- layer_results [f"{ lora_name } .alpha" ] = torch .tensor (rank ).to (save_torch_dtype )
592+ layer_results [f"{ lora_name } .lora_B.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
593+ layer_results [f"{ lora_name } .lora_A.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
576594 del weight_diff
577595 return "chunked" , layer_results
578596
@@ -586,9 +604,8 @@ def _process_layer(key):
586604 status = "full"
587605 else :
588606 lora_down , lora_up , _ = result
589- layer_results [f"{ lora_name } .lora_down.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
590- layer_results [f"{ lora_name } .lora_up.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
591- layer_results [f"{ lora_name } .alpha" ] = torch .tensor (lora_down .shape [0 ]).to (save_torch_dtype )
607+ layer_results [f"{ lora_name } .lora_A.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
608+ layer_results [f"{ lora_name } .lora_B.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
592609 status = "extracted"
593610
594611 del weight_diff
@@ -686,46 +703,48 @@ def _reconstruct_dots(key: str) -> str:
686703def _format_lora_key (key : str ) -> str :
687704 """
688705 Format the key for LoRA saving.
689- Standardizes to 'diffusion_model.' or 'transformer.' prefix.
706+ Always standardizes to 'diffusion_model.' prefix for maximum compatibility.
707+
708+ ComfyUI maps 'diffusion_model.<layer>' generically for all model types, making
709+ it universally compatible. The 'transformer.' prefix only works for a subset of
710+ model types via model-specific mapping code and is intentionally not used here.
690711 """
691712 if key .endswith (".weight" ):
692713 key = key [:- 7 ]
693714
694- # Handle ComfyUI Checkpoint wrapper
695- if key .startswith ("model.diffusion_model." ) or key .startswith ("net." ):
696- # Check if it's a Diffusers-format transformer inside a checkpoint
697- if key .startswith ("model.diffusion_model." ):
698- inner_key = key [22 :] # len("model.diffusion_model.")
699- else :
700- inner_key = key [4 :] # len("net.")
715+ # Handle ComfyUI Checkpoint wrapper (ldm/sgm format: model.diffusion_model.*)
716+ if key .startswith ("model.diffusion_model." ):
717+ inner_key = key [22 :] # strip len("model.diffusion_model.")
718+ return f"diffusion_model.{ inner_key } "
701719
702- if inner_key . startswith ( "transformer_blocks" ) or inner_key . startswith ( "single_transformer_blocks" ):
703- return f"transformer. { inner_key } "
704- # Standard checkpoint or Original Flux
720+ # Handle net.* wrapper
721+ if key . startswith ( "net." ):
722+ inner_key = key [ 4 :] # strip len("net.")
705723 return f"diffusion_model.{ inner_key } "
706724
707- # Handle direct Diffusers keys
725+ # Handle direct Diffusers keys (transformer_blocks.* / single_transformer_blocks.* without prefix)
708726 if key .startswith ("transformer_blocks" ) or key .startswith ("single_transformer_blocks" ):
709- return f"transformer .{ key } "
727+ return f"diffusion_model .{ key } "
710728
711729 # Handle legacy lora_unet_ prefix (for resizing without base)
712730 if key .startswith ("lora_unet_" ):
713731 core = key [10 :]
714732 dotted = _reconstruct_dots (core )
715- # Check if reconstructed key implies transformer format
716- if dotted .startswith ("transformer_blocks" ) or dotted .startswith ("single_transformer_blocks" ):
717- return f"transformer.{ dotted } "
718733 return f"diffusion_model.{ dotted } "
719734
720- # Handle already correct prefixes
721- if key .startswith ("diffusion_model." ) or key . startswith ( "transformer." ) :
735+ # Handle already-prefixed keys
736+ if key .startswith ("diffusion_model." ):
722737 return key
723738
724- # Handle known Unet blocks (Diffusers format UNet)
739+ # Absorb now-invalid transformer. prefix into diffusion_model. for compatibility
740+ if key .startswith ("transformer." ):
741+ return f"diffusion_model.{ key [12 :]} "
742+
743+ # Handle known Diffusers UNet blocks
725744 if any (key .startswith (p ) for p in ["down_blocks" , "up_blocks" , "mid_block" , "conv_in" , "conv_out" , "time_embedding" , "class_embedding" ]):
726745 return f"diffusion_model.{ key } "
727746
728- # Default fallback to diffusion_model
747+ # Default fallback
729748 return f"diffusion_model.{ key } "
730749
731750
@@ -744,7 +763,10 @@ def _get_common_inputs():
744763 io .Combo .Input ("save_dtype" , options = ["fp16" , "bf16" , "fp32" ], default = "fp16" ),
745764 io .Combo .Input ("device" , options = ["cuda" , "cpu" ], default = "cuda" ),
746765 io .String .Input ("skip_patterns" , default = "" , multiline = True ,
747- tooltip = "Regex patterns for layers to skip" ),
766+ tooltip = "Patterns for layers to skip (regex or glob depending on glob_skip_patterns)" ),
767+ io .Boolean .Input ("glob_skip_patterns" , default = False ,
768+ tooltip = "When True, skip_patterns use glob syntax (* = any sequence, ? = any char, dots are literal). "
769+ "When False (default), patterns are Python regex matched as substrings." ),
748770 ]
749771
750772
@@ -775,7 +797,7 @@ def define_schema(cls):
775797 @classmethod
776798 def execute (cls , model_a , model_b , linear_dim , conv_dim , svd_niter , chunk_large_layers ,
777799 clamp_quantile , min_diff , mismatch_mode , output_filename ,
778- save_dtype , device , skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
800+ save_dtype , device , skip_patterns , glob_skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
779801
780802 model_a_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_a )
781803 model_b_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_b )
@@ -784,7 +806,7 @@ def execute(cls, model_a, model_b, linear_dim, conv_dim, svd_niter, chunk_large_
784806 model_a_path , model_b_path , "fixed" , linear_dim , conv_dim ,
785807 device , save_dtype , linear_dim , conv_dim ,
786808 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers , svd_niter ,
787- lazy_load , force_clear_cache
809+ lazy_load , force_clear_cache , glob_skip_patterns
788810 )
789811
790812 return io .NodeOutput (_save_lora (output_sd , output_filename ))
@@ -819,7 +841,7 @@ def define_schema(cls):
819841 @classmethod
820842 def execute (cls , model_a , model_b , linear_ratio , conv_ratio , linear_max_rank , conv_max_rank ,
821843 chunk_large_layers , clamp_quantile , min_diff , mismatch_mode , output_filename ,
822- save_dtype , device , skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
844+ save_dtype , device , skip_patterns , glob_skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
823845
824846 model_a_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_a )
825847 model_b_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_b )
@@ -828,7 +850,8 @@ def execute(cls, model_a, model_b, linear_ratio, conv_ratio, linear_max_rank, co
828850 model_a_path , model_b_path , "ratio" , linear_ratio , conv_ratio ,
829851 device , save_dtype , linear_max_rank , conv_max_rank ,
830852 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
831- lazy_load = lazy_load , force_clear_cache = force_clear_cache
853+ lazy_load = lazy_load , force_clear_cache = force_clear_cache ,
854+ glob_skip_patterns = glob_skip_patterns
832855 )
833856
834857 return io .NodeOutput (_save_lora (output_sd , output_filename ))
@@ -863,7 +886,7 @@ def define_schema(cls):
863886 @classmethod
864887 def execute (cls , model_a , model_b , linear_quantile , conv_quantile , linear_max_rank , conv_max_rank ,
865888 chunk_large_layers , clamp_quantile , min_diff , mismatch_mode , output_filename ,
866- save_dtype , device , skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
889+ save_dtype , device , skip_patterns , glob_skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
867890
868891 model_a_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_a )
869892 model_b_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_b )
@@ -872,7 +895,8 @@ def execute(cls, model_a, model_b, linear_quantile, conv_quantile, linear_max_ra
872895 model_a_path , model_b_path , "quantile" , linear_quantile , conv_quantile ,
873896 device , save_dtype , linear_max_rank , conv_max_rank ,
874897 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
875- lazy_load = lazy_load , force_clear_cache = force_clear_cache
898+ lazy_load = lazy_load , force_clear_cache = force_clear_cache ,
899+ glob_skip_patterns = glob_skip_patterns
876900 )
877901
878902 return io .NodeOutput (_save_lora (output_sd , output_filename ))
@@ -905,7 +929,7 @@ def define_schema(cls):
905929 @classmethod
906930 def execute (cls , model_a , model_b , knee_method , linear_max_rank , conv_max_rank ,
907931 chunk_large_layers , clamp_quantile , min_diff , mismatch_mode , output_filename ,
908- save_dtype , device , skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
932+ save_dtype , device , skip_patterns , glob_skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
909933
910934 model_a_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_a )
911935 model_b_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_b )
@@ -914,7 +938,8 @@ def execute(cls, model_a, model_b, knee_method, linear_max_rank, conv_max_rank,
914938 model_a_path , model_b_path , knee_method , 0 , 0 ,
915939 device , save_dtype , linear_max_rank , conv_max_rank ,
916940 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
917- lazy_load = lazy_load , force_clear_cache = force_clear_cache
941+ lazy_load = lazy_load , force_clear_cache = force_clear_cache ,
942+ glob_skip_patterns = glob_skip_patterns
918943 )
919944
920945 return io .NodeOutput (_save_lora (output_sd , output_filename ))
@@ -949,7 +974,7 @@ def define_schema(cls):
949974 @classmethod
950975 def execute (cls , model_a , model_b , linear_target , conv_target , linear_max_rank , conv_max_rank ,
951976 chunk_large_layers , clamp_quantile , min_diff , mismatch_mode , output_filename ,
952- save_dtype , device , skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
977+ save_dtype , device , skip_patterns , glob_skip_patterns , lazy_load , force_clear_cache ) -> io .NodeOutput :
953978
954979 model_a_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_a )
955980 model_b_path = folder_paths .get_full_path_or_raise ("diffusion_models" , model_b )
@@ -958,7 +983,8 @@ def execute(cls, model_a, model_b, linear_target, conv_target, linear_max_rank,
958983 model_a_path , model_b_path , "sv_fro" , linear_target , conv_target ,
959984 device , save_dtype , linear_max_rank , conv_max_rank ,
960985 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
961- lazy_load = lazy_load , force_clear_cache = force_clear_cache
986+ lazy_load = lazy_load , force_clear_cache = force_clear_cache ,
987+ glob_skip_patterns = glob_skip_patterns
962988 )
963989
964990 return io .NodeOutput (_save_lora (output_sd , output_filename ))
0 commit comments