-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathpref_vector.py
More file actions
35 lines (27 loc) · 1.06 KB
/
pref_vector.py
File metadata and controls
35 lines (27 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from torch import Tensor
from torchjd._linalg import Matrix
from torchjd.aggregation._constant import ConstantWeighting
from torchjd.aggregation._weighting_bases import Weighting
from .str import vector_to_str
def pref_vector_to_weighting(
    pref_vector: Tensor | None, default: Weighting[Matrix]
) -> Weighting[Matrix]:
    """
    Maps an optional preference vector to the weighting it describes.

    :param pref_vector: Either ``None`` or a 1D tensor used as constant weights.
    :param default: The weighting to fall back on when ``pref_vector`` is ``None``.
    :returns: ``default`` when ``pref_vector`` is ``None``; otherwise a ``ConstantWeighting``
        built from ``pref_vector``.
    :raises ValueError: If ``pref_vector`` is provided but its ``ndim`` is not 1.
    """
    if pref_vector is not None:
        n_dims = pref_vector.ndim
        # Only a flat vector can serve as a set of constant weights.
        if n_dims != 1:
            raise ValueError(
                "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
                f"{n_dims}`."
            )
        return ConstantWeighting(pref_vector)
    return default
def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
    """
    Builds the string suffix describing an optional preference vector.

    :param pref_vector: The preference vector, or ``None``.
    :returns: An empty string when ``pref_vector`` is ``None``; otherwise the text produced by
        ``vector_to_str`` for the vector, wrapped as ``([...])``.
    """
    return "" if pref_vector is None else f"([{vector_to_str(pref_vector)}])"