33import torch
44from torch import Tensor
55
6- from ._dual_cone_utils import _project_weights
6+ from ._dual_cone_utils import project_weights
77from ._gramian_utils import _compute_regularized_normalized_gramian
8- from ._pref_vector_utils import _pref_vector_to_str_suffix , _pref_vector_to_weighting
8+ from ._pref_vector_utils import pref_vector_to_str_suffix , pref_vector_to_weighting
99from .bases import _WeightedAggregator , _Weighting
1010from .mean import _MeanWeighting
1111
@@ -47,7 +47,7 @@ def __init__(
4747 reg_eps : float = 0.0001 ,
4848 solver : Literal ["quadprog" ] = "quadprog" ,
4949 ):
50- weighting = _pref_vector_to_weighting (pref_vector , default = _MeanWeighting ())
50+ weighting = pref_vector_to_weighting (pref_vector , default = _MeanWeighting ())
5151 self ._pref_vector = pref_vector
5252
5353 super ().__init__ (
@@ -64,7 +64,7 @@ def __repr__(self) -> str:
6464 )
6565
6666 def __str__ (self ) -> str :
67- return f"UPGrad{ _pref_vector_to_str_suffix (self ._pref_vector )} "
67+ return f"UPGrad{ pref_vector_to_str_suffix (self ._pref_vector )} "
6868
6969
7070class _UPGradWrapper (_Weighting ):
@@ -97,5 +97,5 @@ def __init__(
9797 def forward (self , matrix : Tensor ) -> Tensor :
9898 U = torch .diag (self .weighting (matrix ))
9999 G = _compute_regularized_normalized_gramian (matrix , self .norm_eps , self .reg_eps )
100- W = _project_weights (U , G , self .solver )
100+ W = project_weights (U , G , self .solver )
101101 return torch .sum (W , dim = 0 )
0 commit comments