@@ -71,12 +71,12 @@ def test_classifier_output_types(self):
7171 khc = KhiopsClassifier (n_trees = 0 )
7272 khc .fit (X , y )
7373 y_pred = khc .predict (X )
74+ khc .fit (X_mt , y )
75+ y_mt_pred = khc .predict (X_mt )
76+
7477 y_bin = y .replace ({0 : 0 , 1 : 0 , 2 : 1 })
7578 khc .fit (X , y_bin )
7679 y_bin_pred = khc .predict (X )
77- khc .fit (X_mt , y )
78- khc .export_report_file ("report.khj" )
79- y_mt_pred = khc .predict (X_mt )
8080 khc .fit (X_mt , y_bin )
8181 y_mt_bin_pred = khc .predict (X_mt )
8282
@@ -85,6 +85,8 @@ def test_classifier_output_types(self):
8585 "ys" : {
8686 "int" : y ,
8787 "int binary" : y_bin ,
88+ "float" : y .astype (float ),
89+ "bool" : y .replace ({0 : True , 1 : True , 2 : False }),
8890 "string" : self ._replace (y , {0 : "se" , 1 : "vi" , 2 : "ve" }),
8991 "string binary" : self ._replace (y_bin , {0 : "vi_or_se" , 1 : "ve" }),
9092 "int as string" : self ._replace (y , {0 : "8" , 1 : "9" , 2 : "10" }),
@@ -93,30 +95,42 @@ def test_classifier_output_types(self):
9395 "cat string" : pd .Series (
9496 self ._replace (y , {0 : "se" , 1 : "vi" , 2 : "ve" })
9597 ).astype ("category" ),
98+ "cat float" : y .astype (float ).astype ("category" ),
99+ "cat bool" : y .replace ({0 : True , 1 : True , 2 : False }).astype ("category" ),
96100 },
97101 "y_type_check" : {
98102 "int" : pd .api .types .is_integer_dtype ,
99103 "int binary" : pd .api .types .is_integer_dtype ,
104+ "float" : pd .api .types .is_float_dtype ,
105+ "bool" : pd .api .types .is_bool_dtype ,
100106 "string" : pd .api .types .is_string_dtype ,
101107 "string binary" : pd .api .types .is_string_dtype ,
102108 "int as string" : pd .api .types .is_string_dtype ,
103109 "int as string binary" : pd .api .types .is_string_dtype ,
104110 "cat int" : pd .api .types .is_integer_dtype ,
105111 "cat string" : pd .api .types .is_string_dtype ,
112+ "cat float" : pd .api .types .is_float_dtype ,
113+ "cat bool" : pd .api .types .is_bool_dtype ,
106114 },
107115 "expected_classes" : {
108116 "int" : column_or_1d ([0 , 1 , 2 ]),
109117 "int binary" : column_or_1d ([0 , 1 ]),
118+ "float" : column_or_1d ([0.0 , 1.0 , 2.0 ]),
119+ "bool" : column_or_1d ([False , True ]),
110120 "string" : column_or_1d (["se" , "ve" , "vi" ]),
111121 "string binary" : column_or_1d (["ve" , "vi_or_se" ]),
112122 "int as string" : column_or_1d (["10" , "8" , "9" ]),
113123 "int as string binary" : column_or_1d (["10" , "89" ]),
114124 "cat int" : column_or_1d ([0 , 1 , 2 ]),
115125 "cat string" : column_or_1d (["se" , "ve" , "vi" ]),
126+ "cat float" : column_or_1d ([0.0 , 1.0 , 2.0 ]),
127+ "cat bool" : column_or_1d ([False , True ]),
116128 },
117129 "expected_y_preds" : {
118130 "mono" : {
119131 "int" : y_pred ,
132+ "float" : y_pred .astype (float ),
133+ "bool" : self ._replace (y_bin_pred , {0 : True , 1 : False }),
120134 "int binary" : y_bin_pred ,
121135 "string" : self ._replace (y_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
122136 "string binary" : self ._replace (
@@ -128,9 +142,13 @@ def test_classifier_output_types(self):
128142 ),
129143 "cat int" : y_pred ,
130144 "cat string" : self ._replace (y_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
145+ "cat float" : self ._replace (y_pred , {0 : 0.0 , 1 : 1.0 , 2 : 2.0 }),
146+ "cat bool" : self ._replace (y_bin_pred , {0 : True , 1 : False }),
131147 },
132148 "multi" : {
133149 "int" : y_mt_pred ,
150+ "float" : y_mt_pred .astype (float ),
151+ "bool" : self ._replace (y_mt_bin_pred , {0 : True , 1 : False }),
134152 "int binary" : y_mt_bin_pred ,
135153 "string" : self ._replace (y_mt_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
136154 "string binary" : self ._replace (
@@ -144,6 +162,8 @@ def test_classifier_output_types(self):
144162 ),
145163 "cat int" : y_mt_pred ,
146164 "cat string" : self ._replace (y_mt_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
165+ "cat float" : self ._replace (y_mt_pred , {0 : 0.0 , 1 : 1.0 , 2 : 2.0 }),
166+ "cat bool" : self ._replace (y_mt_bin_pred , {0 : True , 1 : False }),
147167 },
148168 },
149169 "Xs" : {
@@ -171,6 +191,7 @@ def test_classifier_output_types(self):
171191
172192 # Check the return type of predict
173193 y_pred = khc .predict (X )
194+
174195 self .assertTrue (
175196 y_type_check (y_pred ),
176197 f"'{ y_type_check .__name__ } ' was False for "
0 commit comments