-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_pref_vector.py
More file actions
26 lines (22 loc) · 885 Bytes
/
test_pref_vector.py
File metadata and controls
26 lines (22 loc) · 885 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from contextlib import nullcontext as does_not_raise
from pytest import mark, raises
from torch import Tensor
from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import ones_
from torchjd.aggregation._mean import MeanWeighting
from torchjd.aggregation._utils.pref_vector import pref_vector_to_weighting
# NOTE(review): `ones_` presumably builds an all-ones tensor from the given
# shape list -- confirm against tests.utils.tensors before relying on the
# per-case shape reading below.
@mark.parametrize(
    "pref_vector, expectation",
    [
        # No preference vector supplied: must be accepted.
        (None, does_not_raise()),
        # Empty shape list: must be rejected.
        (ones_([]), raises(ValueError)),
        # Single-entry shape lists (including 0): must be accepted.
        (ones_([0]), does_not_raise()),
        (ones_([1]), does_not_raise()),
        (ones_([5]), does_not_raise()),
        # Shape lists with more than one entry: must be rejected.
        (ones_([1, 1]), raises(ValueError)),
        (ones_([1, 1, 1]), raises(ValueError)),
    ],
)
def test_pref_vector_to_weighting_check(
    pref_vector: Tensor | None,
    expectation: ExceptionContext,
):
    """Check which preference vectors `pref_vector_to_weighting` accepts.

    Each case pairs an input with a context manager: `does_not_raise()` when
    the call must succeed, `raises(ValueError)` when it must fail with a
    `ValueError`. The returned weighting itself is not inspected.
    """
    with expectation:
        pref_vector_to_weighting(pref_vector, default=MeanWeighting())