Skip to content

Commit 1e53a87

Browse files
authored
Merge pull request #8 from silveroxides/feature/expand-lora-functionality
Expand functionality and reduce restrictions when extracting LoRA
2 parents 5df8c67 + b497022 commit 1e53a87

4 files changed

Lines changed: 161 additions & 133 deletions

File tree

nodes/lora_extract_svd.py

Lines changed: 80 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ def _index_sv_ratio(S: torch.Tensor, ratio: float) -> int:
4646

4747
def _index_sv_cumulative(S: torch.Tensor, target: float, max_rank: int = None) -> int:
4848
"""Cumulative mode - keep enough SVs to reach target % of total.
49-
49+
5050
The total is the sum of the first max_rank singular values when max_rank is provided and smaller than len(S); otherwise it is the sum of all singular values.
5151
"""
5252
if max_rank is not None and max_rank < len(S):
5353
total = torch.sum(S[:max_rank])
5454
else:
5555
total = torch.sum(S)
56-
56+
5757
if total < 1e-8:
5858
return 1
5959
cumsum = torch.cumsum(S, dim=0) / total
@@ -63,7 +63,7 @@ def _index_sv_cumulative(S: torch.Tensor, target: float, max_rank: int = None) -
6363

6464
def _index_sv_fro(S: torch.Tensor, target: float, max_rank: int = None) -> int:
6565
"""Frobenius norm mode - preserve target fraction of Frobenius norm.
66-
66+
6767
Calculates relative to max_rank if provided, otherwise relative to the full set of singular values.
6868
This means "retain target% of what's achievable within max_rank".
6969
"""
@@ -75,10 +75,10 @@ def _index_sv_fro(S: torch.Tensor, target: float, max_rank: int = None) -> int:
7575
else:
7676
S_sq = S.pow(2)
7777
total_sq = torch.sum(S_sq)
78-
78+
7979
if total_sq < 1e-8:
8080
return 1
81-
81+
8282
# Cumsum of all S (not capped) to find where we reach target
8383
cumsum = torch.cumsum(S.pow(2), dim=0) / total_sq
8484
rank = int(torch.searchsorted(cumsum, target ** 2).item()) + 1
@@ -172,50 +172,50 @@ def _svd_extract_linear_lowrank(
172172
) -> tuple:
173173
"""
174174
Low-rank SVD decomposition using svd_lowrank for fixed rank extraction.
175-
175+
176176
Much faster than full SVD when rank << min(m, n) because it only computes
177177
the top-k singular values using randomized algorithms.
178-
178+
179179
Args:
180180
niter: Power iterations for SVD accuracy (higher = more accurate but slower)
181-
181+
182182
Returns:
183183
((lora_down, lora_up, diff), "low rank") or (weight, "full")
184184
"""
185185
weight = weight.to(device=device, dtype=torch.float32)
186186
out_dim, in_dim = weight.shape
187-
187+
188188
# Clamp rank to valid range
189189
max_possible_rank = min(out_dim, in_dim)
190-
rank = max(1, min(rank, max_possible_rank - 1))
191-
190+
rank = max(1, min(rank, max_possible_rank))
191+
192192
# Check if decomposition is worthwhile
193-
if rank >= max_possible_rank // 2:
193+
if rank >= max_possible_rank:
194194
return weight, "full"
195-
195+
196196
# Use svd_lowrank for faster low-rank approximation
197197
try:
198198
U, S, V = torch.svd_lowrank(weight, q=rank, niter=niter)
199199
# U: [out_dim, rank], S: [rank], V: [in_dim, rank]
200200
Vh = V.T # [rank, in_dim]
201201
except Exception as e:
202202
raise RuntimeError(f"svd_lowrank failed: {e}")
203-
203+
204204
# Clamp outliers
205205
if clamp_quantile < 1.0 and len(S) > 0:
206206
try:
207207
max_val = torch.quantile(S, clamp_quantile)
208208
S = S.clamp(max=max_val)
209209
except RuntimeError:
210210
pass
211-
211+
212212
# Construct LoRA matrices
213213
lora_up = U @ torch.diag(S) # [out_dim, rank]
214214
lora_down = Vh # [rank, in_dim]
215-
215+
216216
# Compute reconstruction diff
217217
diff = weight - (lora_up @ lora_down)
218-
218+
219219
return (lora_down, lora_up, diff), "low rank"
220220

221221

@@ -239,7 +239,7 @@ def _svd_extract_linear(
239239
"""
240240
weight = weight.to(device=device, dtype=torch.float32)
241241
out_dim, in_dim = weight.shape
242-
242+
243243
# For fixed rank, use the faster svd_lowrank
244244
if mode == "fixed" and mode_param > 0:
245245
target_rank = int(mode_param)
@@ -257,7 +257,7 @@ def _svd_extract_linear(
257257
rank = _compute_rank(S, mode, mode_param, max_rank)
258258

259259
# Check if decomposition is worthwhile
260-
if rank >= min(out_dim, in_dim) // 2:
260+
if rank >= min(out_dim, in_dim):
261261
return weight, "full"
262262

263263
# Truncate to rank
@@ -303,7 +303,7 @@ def _svd_extract_conv(
303303

304304
# Flatten for SVD
305305
if is_1x1:
306-
mat = weight.squeeze() # [out_ch, in_ch]
306+
mat = weight.view(out_ch, in_ch) # [out_ch, in_ch]
307307
else:
308308
mat = weight.reshape(out_ch, -1) # [out_ch, in_ch*k*k]
309309

@@ -382,8 +382,11 @@ def _extract_chunked_layer(
382382
return None, None, 0
383383

384384
# Combine chunks
385+
# Note: Using the first chunk's down matrix as the basis for all chunks
386+
# is a better approximation than a simple mean if the chunks share the same input space (like QKV).
387+
# Ideally, we would use a more sophisticated joint SVD, but this is a reasonable fallback.
385388
combined_lora_up = torch.cat(all_lora_up, dim=0)
386-
combined_lora_down = torch.stack(all_lora_down, dim=0).mean(dim=0)
389+
combined_lora_down = all_lora_down[0]
387390
rank = combined_lora_down.shape[0]
388391

389392
return combined_lora_up, combined_lora_down, rank
@@ -476,39 +479,40 @@ def extract_lora_from_files(
476479
try:
477480
keys_a = set(handler_a.keys())
478481
keys_b = set(handler_b.keys())
479-
480482
weight_keys = [k for k in keys_a if k.endswith(".weight")]
481-
482483
pbar = comfy.utils.ProgressBar(len(weight_keys))
483-
484484
stats = {"extracted": 0, "full": 0, "skipped": 0, "chunked": 0}
485485

486-
for key in tqdm(weight_keys, desc="Extracting LoRA", unit="layers"):
486+
def _process_layer(key):
487487
lora_name = _format_lora_key(key)
488488

489489
if _matches_any_pattern(key, skip_patterns):
490-
stats["skipped"] += 1
491-
pbar.update(1)
492-
continue
490+
return "skipped", None
493491

494492
# Load tensors with pinned memory for CUDA
495493
use_pinned = device == 'cuda'
496-
494+
497495
if key not in keys_b:
498496
if mismatch_mode == "skip":
499-
stats["skipped"] += 1
500-
pbar.update(1)
501-
continue
497+
return "skipped", None
498+
if mismatch_mode == "error":
499+
raise ValueError(f"Key {key} not found in model B")
500+
502501
cpu_a = handler_a.get_tensor(key)
503502
if use_pinned:
504503
weight_diff = transfer_to_gpu_pinned(cpu_a, device, torch.float32)
505504
else:
506505
weight_diff = cpu_a.to(device=device, dtype=torch.float32)
507506
del cpu_a
507+
508+
if mismatch_mode == "zeros":
509+
# For zeros mode, the missing base is treated as an all-zero tensor of the same shape,
510+
# so the difference is simply model A's tensor, which weight_diff already holds
511+
pass
508512
else:
509513
cpu_a = handler_a.get_tensor(key)
510514
cpu_b = handler_b.get_tensor(key)
511-
515+
512516
if use_pinned:
513517
tensor_a = transfer_to_gpu_pinned(cpu_a, device, torch.float32)
514518
tensor_b = transfer_to_gpu_pinned(cpu_b, device, torch.float32)
@@ -519,10 +523,12 @@ def extract_lora_from_files(
519523

520524
if tensor_a.shape != tensor_b.shape:
521525
if mismatch_mode == "skip":
522-
stats["skipped"] += 1
523-
pbar.update(1)
524526
del tensor_a, tensor_b
525-
continue
527+
return "skipped", None
528+
if mismatch_mode == "error":
529+
raise ValueError(f"Shape mismatch for {key}: {tensor_a.shape} vs {tensor_b.shape}")
530+
531+
# For zeros/fallback, use tensor_a as the difference
526532
weight_diff = tensor_a
527533
del tensor_b
528534
else:
@@ -531,19 +537,16 @@ def extract_lora_from_files(
531537

532538
# Skip small differences
533539
if min_diff > 0 and weight_diff.abs().max() < min_diff:
534-
stats["skipped"] += 1
535540
del weight_diff
536-
pbar.update(1)
537-
continue
541+
return "skipped", None
538542

539543
# Skip 1D tensors
540544
if weight_diff.ndim < 2:
541-
stats["skipped"] += 1
542545
del weight_diff
543-
pbar.update(1)
544-
continue
546+
return "skipped", None
545547

546548
is_conv = weight_diff.ndim == 4
549+
layer_results = {}
547550

548551
try:
549552
if is_conv:
@@ -564,32 +567,36 @@ def extract_lora_from_files(
564567
weight_diff, num_chunks, mode, linear_param, device, linear_max_rank
565568
)
566569
if lora_up is not None:
567-
output_sd[f"{lora_name}.lora_up.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
568-
output_sd[f"{lora_name}.lora_down.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
569-
output_sd[f"{lora_name}.alpha"] = torch.tensor(rank).to(save_torch_dtype)
570-
stats["chunked"] += 1
570+
layer_results[f"{lora_name}.lora_up.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
571+
layer_results[f"{lora_name}.lora_down.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
572+
layer_results[f"{lora_name}.alpha"] = torch.tensor(rank).to(save_torch_dtype)
571573
del weight_diff
572-
pbar.update(1)
573-
continue
574+
return "chunked", layer_results
574575

575576
print(f"[LoRA Extract] Failed: {key}: {e}")
576-
stats["skipped"] += 1
577577
del weight_diff
578-
pbar.update(1)
579-
continue
578+
return "skipped", None
580579

581580
# Store result
582581
if mode_str == "full":
583-
output_sd[f"{lora_name}.diff"] = weight_diff.to(save_torch_dtype).cpu().contiguous()
584-
stats["full"] += 1
582+
layer_results[f"{lora_name}.diff"] = weight_diff.to(save_torch_dtype).cpu().contiguous()
583+
status = "full"
585584
else:
586585
lora_down, lora_up, _ = result
587-
output_sd[f"{lora_name}.lora_down.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
588-
output_sd[f"{lora_name}.lora_up.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
589-
output_sd[f"{lora_name}.alpha"] = torch.tensor(lora_down.shape[0]).to(save_torch_dtype)
590-
stats["extracted"] += 1
586+
layer_results[f"{lora_name}.lora_down.weight"] = lora_down.to(save_torch_dtype).cpu().contiguous()
587+
layer_results[f"{lora_name}.lora_up.weight"] = lora_up.to(save_torch_dtype).cpu().contiguous()
588+
layer_results[f"{lora_name}.alpha"] = torch.tensor(lora_down.shape[0]).to(save_torch_dtype)
589+
status = "extracted"
591590

592591
del weight_diff
592+
return status, layer_results
593+
594+
for key in tqdm(weight_keys, desc="Extracting LoRA", unit="layers"):
595+
status, layer_sd = _process_layer(key)
596+
stats[status] += 1
597+
if layer_sd:
598+
output_sd.update(layer_sd)
599+
593600
if force_clear_cache:
594601
import gc
595602
gc.collect()
@@ -706,7 +713,7 @@ def _format_lora_key(key: str) -> str:
706713
# Handle already correct prefixes
707714
if key.startswith("diffusion_model.") or key.startswith("transformer."):
708715
return key
709-
716+
710717
# Handle known Unet blocks (Diffusers format UNet)
711718
if any(key.startswith(p) for p in ["down_blocks", "up_blocks", "mid_block", "conv_in", "conv_out", "time_embedding", "class_embedding"]):
712719
return f"diffusion_model.{key}"
@@ -746,9 +753,9 @@ def define_schema(cls):
746753
description="Extract LoRA with specified fixed rank for each layer type.",
747754
inputs=[
748755
*_get_model_inputs(),
749-
io.Int.Input("linear_dim", default=64, min=1, max=4096,
756+
io.Int.Input("linear_dim", default=64, min=1, max=16384,
750757
tooltip="Rank for linear/attention layers"),
751-
io.Int.Input("conv_dim", default=32, min=1, max=4096,
758+
io.Int.Input("conv_dim", default=32, min=1, max=16384,
752759
tooltip="Rank for conv layers"),
753760
io.Int.Input("svd_niter", default=2, min=0, max=10,
754761
tooltip="SVD power iterations (higher = more accurate but slower)"),
@@ -768,7 +775,7 @@ def execute(cls, model_a, model_b, linear_dim, conv_dim, svd_niter, chunk_large_
768775

769776
output_sd = extract_lora_from_files(
770777
model_a_path, model_b_path, "fixed", linear_dim, conv_dim,
771-
"lora_unet_", device, save_dtype, linear_dim, conv_dim,
778+
device, save_dtype, linear_dim, conv_dim,
772779
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers, svd_niter,
773780
lazy_load, force_clear_cache
774781
)
@@ -792,9 +799,9 @@ def define_schema(cls):
792799
tooltip="Ratio threshold for linear layers (higher = more SVs kept)"),
793800
io.Float.Input("conv_ratio", default=2.0, min=1.0, max=100.0, step=0.1,
794801
tooltip="Ratio threshold for conv layers (higher = more SVs kept)"),
795-
io.Int.Input("linear_max_rank", default=128, min=1, max=3072,
802+
io.Int.Input("linear_max_rank", default=128, min=1, max=16384,
796803
tooltip="Maximum rank for linear layers"),
797-
io.Int.Input("conv_max_rank", default=128, min=1, max=3072,
804+
io.Int.Input("conv_max_rank", default=128, min=1, max=16384,
798805
tooltip="Maximum rank for conv layers"),
799806
*_get_common_inputs(),
800807
],
@@ -812,7 +819,7 @@ def execute(cls, model_a, model_b, linear_ratio, conv_ratio, linear_max_rank, co
812819

813820
output_sd = extract_lora_from_files(
814821
model_a_path, model_b_path, "ratio", linear_ratio, conv_ratio,
815-
"lora_unet_", device, save_dtype, linear_max_rank, conv_max_rank,
822+
device, save_dtype, linear_max_rank, conv_max_rank,
816823
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
817824
lazy_load=lazy_load, force_clear_cache=force_clear_cache
818825
)
@@ -836,9 +843,9 @@ def define_schema(cls):
836843
tooltip="Target cumulative % for linear layers"),
837844
io.Float.Input("conv_quantile", default=0.9, min=0.0, max=1.0, step=0.01,
838845
tooltip="Target cumulative % for conv layers"),
839-
io.Int.Input("linear_max_rank", default=128, min=1, max=3072,
846+
io.Int.Input("linear_max_rank", default=128, min=1, max=16384,
840847
tooltip="Maximum rank for linear layers"),
841-
io.Int.Input("conv_max_rank", default=128, min=1, max=3072,
848+
io.Int.Input("conv_max_rank", default=128, min=1, max=16384,
842849
tooltip="Maximum rank for conv layers"),
843850
*_get_common_inputs(),
844851
],
@@ -856,7 +863,7 @@ def execute(cls, model_a, model_b, linear_quantile, conv_quantile, linear_max_ra
856863

857864
output_sd = extract_lora_from_files(
858865
model_a_path, model_b_path, "quantile", linear_quantile, conv_quantile,
859-
"lora_unet_", device, save_dtype, linear_max_rank, conv_max_rank,
866+
device, save_dtype, linear_max_rank, conv_max_rank,
860867
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
861868
lazy_load=lazy_load, force_clear_cache=force_clear_cache
862869
)
@@ -878,9 +885,9 @@ def define_schema(cls):
878885
*_get_model_inputs(),
879886
io.Combo.Input("knee_method", options=["sv_knee", "sv_cumulative_knee"],
880887
default="sv_knee", tooltip="Knee detection method"),
881-
io.Int.Input("linear_max_rank", default=128, min=1, max=3072,
888+
io.Int.Input("linear_max_rank", default=128, min=1, max=16384,
882889
tooltip="Maximum rank for linear layers"),
883-
io.Int.Input("conv_max_rank", default=128, min=1, max=3072,
890+
io.Int.Input("conv_max_rank", default=128, min=1, max=16384,
884891
tooltip="Maximum rank for conv layers"),
885892
*_get_common_inputs(),
886893
],
@@ -898,7 +905,7 @@ def execute(cls, model_a, model_b, knee_method, linear_max_rank, conv_max_rank,
898905

899906
output_sd = extract_lora_from_files(
900907
model_a_path, model_b_path, knee_method, 0, 0,
901-
"lora_unet_", device, save_dtype, linear_max_rank, conv_max_rank,
908+
device, save_dtype, linear_max_rank, conv_max_rank,
902909
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
903910
lazy_load=lazy_load, force_clear_cache=force_clear_cache
904911
)
@@ -922,9 +929,9 @@ def define_schema(cls):
922929
tooltip="Target Frobenius norm fraction for linear"),
923930
io.Float.Input("conv_target", default=0.9, min=0.0, max=1.0, step=0.01,
924931
tooltip="Target Frobenius norm fraction for conv"),
925-
io.Int.Input("linear_max_rank", default=128, min=1, max=3072,
932+
io.Int.Input("linear_max_rank", default=128, min=1, max=16384,
926933
tooltip="Maximum rank for linear layers"),
927-
io.Int.Input("conv_max_rank", default=128, min=1, max=3072,
934+
io.Int.Input("conv_max_rank", default=128, min=1, max=16384,
928935
tooltip="Maximum rank for conv layers"),
929936
*_get_common_inputs(),
930937
],
@@ -942,7 +949,7 @@ def execute(cls, model_a, model_b, linear_target, conv_target, linear_max_rank,
942949

943950
output_sd = extract_lora_from_files(
944951
model_a_path, model_b_path, "sv_fro", linear_target, conv_target,
945-
"lora_unet_", device, save_dtype, linear_max_rank, conv_max_rank,
952+
device, save_dtype, linear_max_rank, conv_max_rank,
946953
clamp_quantile, min_diff, skip_patterns, mismatch_mode, chunk_large_layers,
947954
lazy_load=lazy_load, force_clear_cache=force_clear_cache
948955
)

0 commit comments

Comments
 (0)