Skip to content

Commit 46ed3de

Browse files
committed
Add test_pref_vector_to_weighting_check
1 parent 7c944dd commit 46ed3de

1 file changed

Lines changed: 26 additions & 0 deletions

File tree

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from contextlib import nullcontext as does_not_raise
2+
3+
import torch
4+
from pytest import mark, raises
5+
from torch import Tensor
6+
from unit._utils import ExceptionContext
7+
8+
from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
9+
from torchjd.aggregation.mean import _MeanWeighting
10+
11+
12+
@mark.parametrize(
13+
["pref_vector", "expectation"],
14+
[
15+
(None, does_not_raise()),
16+
(torch.ones([]), raises(ValueError)),
17+
(torch.ones([0]), does_not_raise()),
18+
(torch.ones([1]), does_not_raise()),
19+
(torch.ones([5]), does_not_raise()),
20+
(torch.ones([1, 1]), raises(ValueError)),
21+
(torch.ones([1, 1, 1]), raises(ValueError)),
22+
],
23+
)
24+
def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
25+
with expectation:
26+
_ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())

0 commit comments

Comments
 (0)