Skip to content

Commit c6b65d6

Browse files
committed
fix: frobenius/cumulative norm calculation now relative to max_rank
1 parent 32281ad commit c6b65d6

3 files changed

Lines changed: 34 additions & 15 deletions

File tree

nodes/lora_extract_svd.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,43 @@ def _index_sv_ratio(S: torch.Tensor, ratio: float) -> int:
4444
return max(1, min(rank, len(S) - 1))
4545

4646

47-
def _index_sv_cumulative(S: torch.Tensor, target: float) -> int:
48-
"""Cumulative mode - keep enough SVs to reach target % of total."""
49-
total = torch.sum(S)
47+
def _index_sv_cumulative(S: torch.Tensor, target: float, max_rank: int = None) -> int:
48+
"""Cumulative mode - keep enough SVs to reach target % of total.
49+
50+
Calculates relative to max_rank if provided, otherwise relative to full.
51+
"""
52+
if max_rank is not None and max_rank < len(S):
53+
total = torch.sum(S[:max_rank])
54+
else:
55+
total = torch.sum(S)
56+
5057
if total < 1e-8:
5158
return 1
5259
cumsum = torch.cumsum(S, dim=0) / total
5360
rank = int(torch.searchsorted(cumsum, target).item()) + 1
5461
return max(1, min(rank, len(S) - 1))
5562

5663

57-
def _index_sv_fro(S: torch.Tensor, target: float) -> int:
58-
"""Frobenius norm mode - preserve target fraction of Frobenius norm."""
59-
S_sq = S.pow(2)
60-
total_sq = torch.sum(S_sq)
64+
def _index_sv_fro(S: torch.Tensor, target: float, max_rank: int = None) -> int:
65+
"""Frobenius norm mode - preserve target fraction of Frobenius norm.
66+
67+
Calculates relative to max_rank if provided, otherwise relative to full.
68+
This means "retain target% of what's achievable within max_rank".
69+
"""
70+
if max_rank is not None and max_rank < len(S):
71+
# Calculate relative to what's achievable within max_rank
72+
S_capped = S[:max_rank]
73+
S_sq = S_capped.pow(2)
74+
total_sq = torch.sum(S_sq)
75+
else:
76+
S_sq = S.pow(2)
77+
total_sq = torch.sum(S_sq)
78+
6179
if total_sq < 1e-8:
6280
return 1
63-
cumsum = torch.cumsum(S_sq, dim=0) / total_sq
81+
82+
# Cumsum of all S (not capped) to find where we reach target
83+
cumsum = torch.cumsum(S.pow(2), dim=0) / total_sq
6484
rank = int(torch.searchsorted(cumsum, target ** 2).item()) + 1
6585
return max(1, min(rank, len(S) - 1))
6686

@@ -122,9 +142,9 @@ def _compute_rank(S: torch.Tensor, mode: str, mode_param: float,
122142
elif mode == "ratio":
123143
rank = _index_sv_ratio(S, mode_param)
124144
elif mode == "quantile" or mode == "sv_cumulative":
125-
rank = _index_sv_cumulative(S, mode_param)
145+
rank = _index_sv_cumulative(S, mode_param, max_rank)
126146
elif mode == "sv_fro":
127-
rank = _index_sv_fro(S, mode_param)
147+
rank = _index_sv_fro(S, mode_param, max_rank)
128148
elif mode == "sv_knee":
129149
rank = _index_sv_knee(S)
130150
elif mode == "sv_cumulative_knee":

nodes/lora_resize.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,13 @@ def _compute_resize(
284284
"""Compute new rank and alpha based on resize method."""
285285

286286
if dynamic_method == "sv_ratio" and dynamic_param is not None:
287-
# Note: _index_sv_ratio uses S[0]*ratio, kohya uses S[0]/ratio
288-
# Use kohya convention here for consistency with their tool
287+
# Use kohya convention: S[0]/ratio
289288
min_sv = S[0] / dynamic_param
290289
new_rank = max(1, int(torch.sum(S > min_sv).item()))
291290
elif dynamic_method == "sv_cumulative" and dynamic_param is not None:
292-
new_rank = _index_sv_cumulative(S, dynamic_param)
291+
new_rank = _index_sv_cumulative(S, dynamic_param, max_rank)
293292
elif dynamic_method == "sv_fro" and dynamic_param is not None:
294-
new_rank = _index_sv_fro(S, dynamic_param)
293+
new_rank = _index_sv_fro(S, dynamic_param, max_rank)
295294
else:
296295
new_rank = max_rank
297296

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "ComfyUI-ModelUtils"
33
description = "[WIP]Custom nodes for handling, inspecting, modifying and creating various model files."
4-
version = "0.2.1"
4+
version = "0.2.2"
55
license = { file = "LICENSE" }
66

77
[project.urls]

0 commit comments

Comments
 (0)