-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_pref_vector_utils.py
More file actions
31 lines (24 loc) · 993 Bytes
/
_pref_vector_utils.py
File metadata and controls
31 lines (24 loc) · 993 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch import Tensor
from ._str_utils import _vector_to_str
from .bases import _Weighting
from .constant import _ConstantWeighting
def pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
    """
    Resolve an optional preference vector into a weighting.

    :param pref_vector: Optional tensor of preferences. When ``None``, no custom weighting is
        requested and ``default`` is used instead.
    :param default: Weighting to fall back on when ``pref_vector`` is ``None``.
    :returns: ``default`` when ``pref_vector`` is ``None``; otherwise a ``_ConstantWeighting``
        constructed from ``pref_vector``.
    :raises ValueError: If ``pref_vector`` is given but is not one-dimensional.
    """
    if pref_vector is None:
        return default

    num_dims = pref_vector.ndim
    if num_dims == 1:
        return _ConstantWeighting(pref_vector)

    # Only a flat vector is accepted; any other rank is rejected with its ndim reported.
    raise ValueError(
        "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
        f"{num_dims}`."
    )
def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
"""Returns a suffix string containing the representation of the optional preference vector."""
if pref_vector is None:
return ""
else:
return f"([{_vector_to_str(pref_vector)}])"