2020from tests .test_utils import TEST_NDARRAYS , assert_allclose
2121
2222TESTS = []
23+ TESTS_ET_LABEL_3 = []
24+
25+ # Tests for default et_label = 4
2326for p in TEST_NDARRAYS :
2427 TESTS .extend (
2528 [
4649 ]
4750 )
4851
52+ # Tests for et_label = 3
53+ for p in TEST_NDARRAYS :
54+ TESTS_ET_LABEL_3 .extend (
55+ [
56+ [
57+ p ([[0 , 1 , 2 ], [1 , 2 , 3 ], [0 , 1 , 3 ]]),
58+ p (
59+ [
60+ [[0 , 1 , 0 ], [1 , 0 , 1 ], [0 , 1 , 1 ]],
61+ [[0 , 1 , 1 ], [1 , 1 , 1 ], [0 , 1 , 1 ]],
62+ [[0 , 0 , 0 ], [0 , 0 , 1 ], [0 , 0 , 1 ]],
63+ ]
64+ ),
65+ ]
66+ ]
67+ )
68+
4969
5070class TestConvertToMultiChannel (unittest .TestCase ):
5171 @parameterized .expand (TESTS )
@@ -54,6 +74,18 @@ def test_type_shape(self, data, expected_result):
5474 assert_allclose (result , expected_result )
5575 self .assertTrue (result .dtype in (bool , torch .bool ))
5676
77+ @parameterized .expand (TESTS_ET_LABEL_3 )
78+ def test_type_shape_et_label_3 (self , data , expected_result ):
79+ result = ConvertToMultiChannelBasedOnBratsClasses (et_label = 3 )(data )
80+ assert_allclose (result , expected_result )
81+ self .assertTrue (result .dtype in (bool , torch .bool ))
82+
83+ def test_invalid_et_label (self ):
84+ with self .assertRaises (ValueError ):
85+ ConvertToMultiChannelBasedOnBratsClasses (et_label = 1 )
86+ with self .assertRaises (ValueError ):
87+ ConvertToMultiChannelBasedOnBratsClasses (et_label = 2 )
88+
5789
5890if __name__ == "__main__" :
5991 unittest .main ()
0 commit comments