@@ -44,23 +44,43 @@ def _index_sv_ratio(S: torch.Tensor, ratio: float) -> int:
4444 return max (1 , min (rank , len (S ) - 1 ))
4545
4646
47- def _index_sv_cumulative (S : torch .Tensor , target : float ) -> int :
48- """Cumulative mode - keep enough SVs to reach target % of total."""
49- total = torch .sum (S )
47+ def _index_sv_cumulative (S : torch .Tensor , target : float , max_rank : int = None ) -> int :
48+ """Cumulative mode - keep enough SVs to reach target % of total.
49+
50+ Calculates relative to max_rank if provided, otherwise relative to full.
51+ """
52+ if max_rank is not None and max_rank < len (S ):
53+ total = torch .sum (S [:max_rank ])
54+ else :
55+ total = torch .sum (S )
56+
5057 if total < 1e-8 :
5158 return 1
5259 cumsum = torch .cumsum (S , dim = 0 ) / total
5360 rank = int (torch .searchsorted (cumsum , target ).item ()) + 1
5461 return max (1 , min (rank , len (S ) - 1 ))
5562
5663
57- def _index_sv_fro (S : torch .Tensor , target : float ) -> int :
58- """Frobenius norm mode - preserve target fraction of Frobenius norm."""
59- S_sq = S .pow (2 )
60- total_sq = torch .sum (S_sq )
64+ def _index_sv_fro (S : torch .Tensor , target : float , max_rank : int = None ) -> int :
65+ """Frobenius norm mode - preserve target fraction of Frobenius norm.
66+
67+ Calculates relative to max_rank if provided, otherwise relative to full.
68+ This means "retain target% of what's achievable within max_rank".
69+ """
70+ if max_rank is not None and max_rank < len (S ):
71+ # Calculate relative to what's achievable within max_rank
72+ S_capped = S [:max_rank ]
73+ S_sq = S_capped .pow (2 )
74+ total_sq = torch .sum (S_sq )
75+ else :
76+ S_sq = S .pow (2 )
77+ total_sq = torch .sum (S_sq )
78+
6179 if total_sq < 1e-8 :
6280 return 1
63- cumsum = torch .cumsum (S_sq , dim = 0 ) / total_sq
81+
82+ # Cumsum of all S (not capped) to find where we reach target
83+ cumsum = torch .cumsum (S .pow (2 ), dim = 0 ) / total_sq
6484 rank = int (torch .searchsorted (cumsum , target ** 2 ).item ()) + 1
6585 return max (1 , min (rank , len (S ) - 1 ))
6686
@@ -122,9 +142,9 @@ def _compute_rank(S: torch.Tensor, mode: str, mode_param: float,
122142 elif mode == "ratio" :
123143 rank = _index_sv_ratio (S , mode_param )
124144 elif mode == "quantile" or mode == "sv_cumulative" :
125- rank = _index_sv_cumulative (S , mode_param )
145+ rank = _index_sv_cumulative (S , mode_param , max_rank )
126146 elif mode == "sv_fro" :
127- rank = _index_sv_fro (S , mode_param )
147+ rank = _index_sv_fro (S , mode_param , max_rank )
128148 elif mode == "sv_knee" :
129149 rank = _index_sv_knee (S )
130150 elif mode == "sv_cumulative_knee" :
0 commit comments