@@ -46,14 +46,14 @@ def _index_sv_ratio(S: torch.Tensor, ratio: float) -> int:
4646
4747def _index_sv_cumulative (S : torch .Tensor , target : float , max_rank : int = None ) -> int :
4848 """Cumulative mode - keep enough SVs to reach target % of total.
49-
49+
5050 Calculates relative to max_rank if provided, otherwise relative to full.
5151 """
5252 if max_rank is not None and max_rank < len (S ):
5353 total = torch .sum (S [:max_rank ])
5454 else :
5555 total = torch .sum (S )
56-
56+
5757 if total < 1e-8 :
5858 return 1
5959 cumsum = torch .cumsum (S , dim = 0 ) / total
@@ -63,7 +63,7 @@ def _index_sv_cumulative(S: torch.Tensor, target: float, max_rank: int = None) -
6363
6464def _index_sv_fro (S : torch .Tensor , target : float , max_rank : int = None ) -> int :
6565 """Frobenius norm mode - preserve target fraction of Frobenius norm.
66-
66+
6767 Calculates relative to max_rank if provided, otherwise relative to full.
6868 This means "retain target% of what's achievable within max_rank".
6969 """
@@ -75,10 +75,10 @@ def _index_sv_fro(S: torch.Tensor, target: float, max_rank: int = None) -> int:
7575 else :
7676 S_sq = S .pow (2 )
7777 total_sq = torch .sum (S_sq )
78-
78+
7979 if total_sq < 1e-8 :
8080 return 1
81-
81+
8282 # Cumsum of all S (not capped) to find where we reach target
8383 cumsum = torch .cumsum (S .pow (2 ), dim = 0 ) / total_sq
8484 rank = int (torch .searchsorted (cumsum , target ** 2 ).item ()) + 1
@@ -172,50 +172,50 @@ def _svd_extract_linear_lowrank(
172172) -> tuple :
173173 """
174174 Low-rank SVD decomposition using svd_lowrank for fixed rank extraction.
175-
175+
176176 Much faster than full SVD when rank << min(m, n) because it only computes
177177 the top-k singular values using randomized algorithms.
178-
178+
179179 Args:
180180 niter: Power iterations for SVD accuracy (higher = more accurate but slower)
181-
181+
182182 Returns:
183183 ((lora_down, lora_up, diff), "low rank") or (weight, "full")
184184 """
185185 weight = weight .to (device = device , dtype = torch .float32 )
186186 out_dim , in_dim = weight .shape
187-
187+
188188 # Clamp rank to valid range
189189 max_possible_rank = min (out_dim , in_dim )
190- rank = max (1 , min (rank , max_possible_rank - 1 ))
191-
190+ rank = max (1 , min (rank , max_possible_rank ))
191+
192192 # Check if decomposition is worthwhile
193- if rank >= max_possible_rank // 2 :
193+ if rank >= max_possible_rank :
194194 return weight , "full"
195-
195+
196196 # Use svd_lowrank for faster low-rank approximation
197197 try :
198198 U , S , V = torch .svd_lowrank (weight , q = rank , niter = niter )
199199 # U: [out_dim, rank], S: [rank], V: [in_dim, rank]
200200 Vh = V .T # [rank, in_dim]
201201 except Exception as e :
202202 raise RuntimeError (f"svd_lowrank failed: { e } " )
203-
203+
204204 # Clamp outliers
205205 if clamp_quantile < 1.0 and len (S ) > 0 :
206206 try :
207207 max_val = torch .quantile (S , clamp_quantile )
208208 S = S .clamp (max = max_val )
209209 except RuntimeError :
210210 pass
211-
211+
212212 # Construct LoRA matrices
213213 lora_up = U @ torch .diag (S ) # [out_dim, rank]
214214 lora_down = Vh # [rank, in_dim]
215-
215+
216216 # Compute reconstruction diff
217217 diff = weight - (lora_up @ lora_down )
218-
218+
219219 return (lora_down , lora_up , diff ), "low rank"
220220
221221
@@ -239,7 +239,7 @@ def _svd_extract_linear(
239239 """
240240 weight = weight .to (device = device , dtype = torch .float32 )
241241 out_dim , in_dim = weight .shape
242-
242+
243243 # For fixed rank, use the faster svd_lowrank
244244 if mode == "fixed" and mode_param > 0 :
245245 target_rank = int (mode_param )
@@ -257,7 +257,7 @@ def _svd_extract_linear(
257257 rank = _compute_rank (S , mode , mode_param , max_rank )
258258
259259 # Check if decomposition is worthwhile
260- if rank >= min (out_dim , in_dim ) // 2 :
260+ if rank >= min (out_dim , in_dim ):
261261 return weight , "full"
262262
263263 # Truncate to rank
@@ -303,7 +303,7 @@ def _svd_extract_conv(
303303
304304 # Flatten for SVD
305305 if is_1x1 :
306- mat = weight .squeeze ( ) # [out_ch, in_ch]
306+ mat = weight .view ( out_ch , in_ch ) # [out_ch, in_ch]
307307 else :
308308 mat = weight .reshape (out_ch , - 1 ) # [out_ch, in_ch*k*k]
309309
@@ -382,8 +382,11 @@ def _extract_chunked_layer(
382382 return None , None , 0
383383
384384 # Combine chunks
385+ # Note: Using the first chunk's down matrix as the basis for all chunks
386+ # is a better approximation than a simple mean if the chunks share the same input space (like QKV).
387+ # Ideally, we would use a more sophisticated joint SVD, but this is a reasonable fallback.
385388 combined_lora_up = torch .cat (all_lora_up , dim = 0 )
386- combined_lora_down = torch . stack ( all_lora_down , dim = 0 ). mean ( dim = 0 )
389+ combined_lora_down = all_lora_down [ 0 ]
387390 rank = combined_lora_down .shape [0 ]
388391
389392 return combined_lora_up , combined_lora_down , rank
@@ -476,39 +479,40 @@ def extract_lora_from_files(
476479 try :
477480 keys_a = set (handler_a .keys ())
478481 keys_b = set (handler_b .keys ())
479-
480482 weight_keys = [k for k in keys_a if k .endswith (".weight" )]
481-
482483 pbar = comfy .utils .ProgressBar (len (weight_keys ))
483-
484484 stats = {"extracted" : 0 , "full" : 0 , "skipped" : 0 , "chunked" : 0 }
485485
486- for key in tqdm ( weight_keys , desc = "Extracting LoRA" , unit = "layers" ):
486+ def _process_layer ( key ):
487487 lora_name = _format_lora_key (key )
488488
489489 if _matches_any_pattern (key , skip_patterns ):
490- stats ["skipped" ] += 1
491- pbar .update (1 )
492- continue
490+ return "skipped" , None
493491
494492 # Load tensors with pinned memory for CUDA
495493 use_pinned = device == 'cuda'
496-
494+
497495 if key not in keys_b :
498496 if mismatch_mode == "skip" :
499- stats ["skipped" ] += 1
500- pbar .update (1 )
501- continue
497+ return "skipped" , None
498+ if mismatch_mode == "error" :
499+ raise ValueError (f"Key { key } not found in model B" )
500+
502501 cpu_a = handler_a .get_tensor (key )
503502 if use_pinned :
504503 weight_diff = transfer_to_gpu_pinned (cpu_a , device , torch .float32 )
505504 else :
506505 weight_diff = cpu_a .to (device = device , dtype = torch .float32 )
507506 del cpu_a
507+
508+ if mismatch_mode == "zeros" :
509+ # For zeros mode, we treat the missing base as a zeroed tensor of the same shape,
510+ # so the difference is simply model A's tensor, which weight_diff already holds.
511+ pass
508512 else :
509513 cpu_a = handler_a .get_tensor (key )
510514 cpu_b = handler_b .get_tensor (key )
511-
515+
512516 if use_pinned :
513517 tensor_a = transfer_to_gpu_pinned (cpu_a , device , torch .float32 )
514518 tensor_b = transfer_to_gpu_pinned (cpu_b , device , torch .float32 )
@@ -519,10 +523,12 @@ def extract_lora_from_files(
519523
520524 if tensor_a .shape != tensor_b .shape :
521525 if mismatch_mode == "skip" :
522- stats ["skipped" ] += 1
523- pbar .update (1 )
524526 del tensor_a , tensor_b
525- continue
527+ return "skipped" , None
528+ if mismatch_mode == "error" :
529+ raise ValueError (f"Shape mismatch for { key } : { tensor_a .shape } vs { tensor_b .shape } " )
530+
531+ # For zeros/fallback, use tensor_a as the difference
526532 weight_diff = tensor_a
527533 del tensor_b
528534 else :
@@ -531,19 +537,16 @@ def extract_lora_from_files(
531537
532538 # Skip small differences
533539 if min_diff > 0 and weight_diff .abs ().max () < min_diff :
534- stats ["skipped" ] += 1
535540 del weight_diff
536- pbar .update (1 )
537- continue
541+ return "skipped" , None
538542
539543 # Skip 1D tensors
540544 if weight_diff .ndim < 2 :
541- stats ["skipped" ] += 1
542545 del weight_diff
543- pbar .update (1 )
544- continue
546+ return "skipped" , None
545547
546548 is_conv = weight_diff .ndim == 4
549+ layer_results = {}
547550
548551 try :
549552 if is_conv :
@@ -564,32 +567,36 @@ def extract_lora_from_files(
564567 weight_diff , num_chunks , mode , linear_param , device , linear_max_rank
565568 )
566569 if lora_up is not None :
567- output_sd [f"{ lora_name } .lora_up.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
568- output_sd [f"{ lora_name } .lora_down.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
569- output_sd [f"{ lora_name } .alpha" ] = torch .tensor (rank ).to (save_torch_dtype )
570- stats ["chunked" ] += 1
570+ layer_results [f"{ lora_name } .lora_up.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
571+ layer_results [f"{ lora_name } .lora_down.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
572+ layer_results [f"{ lora_name } .alpha" ] = torch .tensor (rank ).to (save_torch_dtype )
571573 del weight_diff
572- pbar .update (1 )
573- continue
574+ return "chunked" , layer_results
574575
575576 print (f"[LoRA Extract] Failed: { key } : { e } " )
576- stats ["skipped" ] += 1
577577 del weight_diff
578- pbar .update (1 )
579- continue
578+ return "skipped" , None
580579
581580 # Store result
582581 if mode_str == "full" :
583- output_sd [f"{ lora_name } .diff" ] = weight_diff .to (save_torch_dtype ).cpu ().contiguous ()
584- stats [ "full" ] += 1
582+ layer_results [f"{ lora_name } .diff" ] = weight_diff .to (save_torch_dtype ).cpu ().contiguous ()
583+ status = "full"
585584 else :
586585 lora_down , lora_up , _ = result
587- output_sd [f"{ lora_name } .lora_down.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
588- output_sd [f"{ lora_name } .lora_up.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
589- output_sd [f"{ lora_name } .alpha" ] = torch .tensor (lora_down .shape [0 ]).to (save_torch_dtype )
590- stats [ "extracted" ] += 1
586+ layer_results [f"{ lora_name } .lora_down.weight" ] = lora_down .to (save_torch_dtype ).cpu ().contiguous ()
587+ layer_results [f"{ lora_name } .lora_up.weight" ] = lora_up .to (save_torch_dtype ).cpu ().contiguous ()
588+ layer_results [f"{ lora_name } .alpha" ] = torch .tensor (lora_down .shape [0 ]).to (save_torch_dtype )
589+ status = "extracted"
591590
592591 del weight_diff
592+ return status , layer_results
593+
594+ for key in tqdm (weight_keys , desc = "Extracting LoRA" , unit = "layers" ):
595+ status , layer_sd = _process_layer (key )
596+ stats [status ] += 1
597+ if layer_sd :
598+ output_sd .update (layer_sd )
599+
593600 if force_clear_cache :
594601 import gc
595602 gc .collect ()
@@ -706,7 +713,7 @@ def _format_lora_key(key: str) -> str:
706713 # Handle already correct prefixes
707714 if key .startswith ("diffusion_model." ) or key .startswith ("transformer." ):
708715 return key
709-
716+
710717 # Handle known Unet blocks (Diffusers format UNet)
711718 if any (key .startswith (p ) for p in ["down_blocks" , "up_blocks" , "mid_block" , "conv_in" , "conv_out" , "time_embedding" , "class_embedding" ]):
712719 return f"diffusion_model.{ key } "
@@ -746,9 +753,9 @@ def define_schema(cls):
746753 description = "Extract LoRA with specified fixed rank for each layer type." ,
747754 inputs = [
748755 * _get_model_inputs (),
749- io .Int .Input ("linear_dim" , default = 64 , min = 1 , max = 4096 ,
756+ io .Int .Input ("linear_dim" , default = 64 , min = 1 , max = 16384 ,
750757 tooltip = "Rank for linear/attention layers" ),
751- io .Int .Input ("conv_dim" , default = 32 , min = 1 , max = 4096 ,
758+ io .Int .Input ("conv_dim" , default = 32 , min = 1 , max = 16384 ,
752759 tooltip = "Rank for conv layers" ),
753760 io .Int .Input ("svd_niter" , default = 2 , min = 0 , max = 10 ,
754761 tooltip = "SVD power iterations (higher = more accurate but slower)" ),
@@ -768,7 +775,7 @@ def execute(cls, model_a, model_b, linear_dim, conv_dim, svd_niter, chunk_large_
768775
769776 output_sd = extract_lora_from_files (
770777 model_a_path , model_b_path , "fixed" , linear_dim , conv_dim ,
771- "lora_unet_" , device , save_dtype , linear_dim , conv_dim ,
778+ device , save_dtype , linear_dim , conv_dim ,
772779 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers , svd_niter ,
773780 lazy_load , force_clear_cache
774781 )
@@ -792,9 +799,9 @@ def define_schema(cls):
792799 tooltip = "Ratio threshold for linear layers (higher = more SVs kept)" ),
793800 io .Float .Input ("conv_ratio" , default = 2.0 , min = 1.0 , max = 100.0 , step = 0.1 ,
794801 tooltip = "Ratio threshold for conv layers (higher = more SVs kept)" ),
795- io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 3072 ,
802+ io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 16384 ,
796803 tooltip = "Maximum rank for linear layers" ),
797- io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 3072 ,
804+ io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 16384 ,
798805 tooltip = "Maximum rank for conv layers" ),
799806 * _get_common_inputs (),
800807 ],
@@ -812,7 +819,7 @@ def execute(cls, model_a, model_b, linear_ratio, conv_ratio, linear_max_rank, co
812819
813820 output_sd = extract_lora_from_files (
814821 model_a_path , model_b_path , "ratio" , linear_ratio , conv_ratio ,
815- "lora_unet_" , device , save_dtype , linear_max_rank , conv_max_rank ,
822+ device , save_dtype , linear_max_rank , conv_max_rank ,
816823 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
817824 lazy_load = lazy_load , force_clear_cache = force_clear_cache
818825 )
@@ -836,9 +843,9 @@ def define_schema(cls):
836843 tooltip = "Target cumulative % for linear layers" ),
837844 io .Float .Input ("conv_quantile" , default = 0.9 , min = 0.0 , max = 1.0 , step = 0.01 ,
838845 tooltip = "Target cumulative % for conv layers" ),
839- io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 3072 ,
846+ io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 16384 ,
840847 tooltip = "Maximum rank for linear layers" ),
841- io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 3072 ,
848+ io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 16384 ,
842849 tooltip = "Maximum rank for conv layers" ),
843850 * _get_common_inputs (),
844851 ],
@@ -856,7 +863,7 @@ def execute(cls, model_a, model_b, linear_quantile, conv_quantile, linear_max_ra
856863
857864 output_sd = extract_lora_from_files (
858865 model_a_path , model_b_path , "quantile" , linear_quantile , conv_quantile ,
859- "lora_unet_" , device , save_dtype , linear_max_rank , conv_max_rank ,
866+ device , save_dtype , linear_max_rank , conv_max_rank ,
860867 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
861868 lazy_load = lazy_load , force_clear_cache = force_clear_cache
862869 )
@@ -878,9 +885,9 @@ def define_schema(cls):
878885 * _get_model_inputs (),
879886 io .Combo .Input ("knee_method" , options = ["sv_knee" , "sv_cumulative_knee" ],
880887 default = "sv_knee" , tooltip = "Knee detection method" ),
881- io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 3072 ,
888+ io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 16384 ,
882889 tooltip = "Maximum rank for linear layers" ),
883- io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 3072 ,
890+ io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 16384 ,
884891 tooltip = "Maximum rank for conv layers" ),
885892 * _get_common_inputs (),
886893 ],
@@ -898,7 +905,7 @@ def execute(cls, model_a, model_b, knee_method, linear_max_rank, conv_max_rank,
898905
899906 output_sd = extract_lora_from_files (
900907 model_a_path , model_b_path , knee_method , 0 , 0 ,
901- "lora_unet_" , device , save_dtype , linear_max_rank , conv_max_rank ,
908+ device , save_dtype , linear_max_rank , conv_max_rank ,
902909 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
903910 lazy_load = lazy_load , force_clear_cache = force_clear_cache
904911 )
@@ -922,9 +929,9 @@ def define_schema(cls):
922929 tooltip = "Target Frobenius norm fraction for linear" ),
923930 io .Float .Input ("conv_target" , default = 0.9 , min = 0.0 , max = 1.0 , step = 0.01 ,
924931 tooltip = "Target Frobenius norm fraction for conv" ),
925- io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 3072 ,
932+ io .Int .Input ("linear_max_rank" , default = 128 , min = 1 , max = 16384 ,
926933 tooltip = "Maximum rank for linear layers" ),
927- io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 3072 ,
934+ io .Int .Input ("conv_max_rank" , default = 128 , min = 1 , max = 16384 ,
928935 tooltip = "Maximum rank for conv layers" ),
929936 * _get_common_inputs (),
930937 ],
@@ -942,7 +949,7 @@ def execute(cls, model_a, model_b, linear_target, conv_target, linear_max_rank,
942949
943950 output_sd = extract_lora_from_files (
944951 model_a_path , model_b_path , "sv_fro" , linear_target , conv_target ,
945- "lora_unet_" , device , save_dtype , linear_max_rank , conv_max_rank ,
952+ device , save_dtype , linear_max_rank , conv_max_rank ,
946953 clamp_quantile , min_diff , skip_patterns , mismatch_mode , chunk_large_layers ,
947954 lazy_load = lazy_load , force_clear_cache = force_clear_cache
948955 )
0 commit comments