@@ -71,52 +71,66 @@ def test_classifier_output_types(self):
7171 khc = KhiopsClassifier (n_trees = 0 )
7272 khc .fit (X , y )
7373 y_pred = khc .predict (X )
74+ khc .fit (X_mt , y )
75+ y_mt_pred = khc .predict (X_mt )
76+
7477 y_bin = y .replace ({0 : 0 , 1 : 0 , 2 : 1 })
7578 khc .fit (X , y_bin )
7679 y_bin_pred = khc .predict (X )
77- khc .fit (X_mt , y )
78- khc .export_report_file ("report.khj" )
79- y_mt_pred = khc .predict (X_mt )
8080 khc .fit (X_mt , y_bin )
8181 y_mt_bin_pred = khc .predict (X_mt )
8282
8383 # Create the fixtures
8484 fixtures = {
8585 "ys" : {
86- "int" : y ,
87- "int binary" : y_bin ,
88- "string" : self ._replace (y , {0 : "se" , 1 : "vi" , 2 : "ve" }),
89- "string binary" : self ._replace (y_bin , {0 : "vi_or_se" , 1 : "ve" }),
90- "int as string" : self ._replace (y , {0 : "8" , 1 : "9" , 2 : "10" }),
91- "int as string binary" : self ._replace (y_bin , {0 : "89" , 1 : "10" }),
92- "cat int" : pd .Series (y ).astype ("category" ),
93- "cat string" : pd .Series (
94- self ._replace (y , {0 : "se" , 1 : "vi" , 2 : "ve" })
95- ).astype ("category" ),
86+ #"int": y,
87+ #"int binary": y_bin,
88+ #"float": y.astype(float),
89+ #"bool": y.replace({0: True, 1: True, 2: False}),
90+ #"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
91+ #"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
92+ #"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
93+ #"int as string binary": self._replace(y_bin, {0: "89", 1: "10"}),
94+ #"cat int": pd.Series(y).astype("category"),
95+ #"cat string": pd.Series(
96+ #self._replace(y, {0: "se", 1: "vi", 2: "ve"})
97+ #).astype("category"),
98+ #"cat float": y.astype(float).astype("category"),
99+ "cat bool" : y .replace ({0 : True , 1 : True , 2 : False }).astype ("category" ),
96100 },
97101 "y_type_check" : {
98102 "int" : pd .api .types .is_integer_dtype ,
99103 "int binary" : pd .api .types .is_integer_dtype ,
104+ "float" : pd .api .types .is_float_dtype ,
105+ "bool" : pd .api .types .is_bool_dtype ,
100106 "string" : pd .api .types .is_string_dtype ,
101107 "string binary" : pd .api .types .is_string_dtype ,
102108 "int as string" : pd .api .types .is_string_dtype ,
103109 "int as string binary" : pd .api .types .is_string_dtype ,
104110 "cat int" : pd .api .types .is_integer_dtype ,
105111 "cat string" : pd .api .types .is_string_dtype ,
112+ "cat float" : pd .api .types .is_float_dtype ,
113+ "cat bool" : pd .api .types .is_bool_dtype ,
106114 },
107115 "expected_classes" : {
108116 "int" : column_or_1d ([0 , 1 , 2 ]),
109117 "int binary" : column_or_1d ([0 , 1 ]),
118+ "float" : column_or_1d ([0.0 , 1.0 , 2.0 ]),
119+ "bool" : column_or_1d ([False , True ]),
110120 "string" : column_or_1d (["se" , "ve" , "vi" ]),
111121 "string binary" : column_or_1d (["ve" , "vi_or_se" ]),
112122 "int as string" : column_or_1d (["10" , "8" , "9" ]),
113123 "int as string binary" : column_or_1d (["10" , "89" ]),
114124 "cat int" : column_or_1d ([0 , 1 , 2 ]),
115125 "cat string" : column_or_1d (["se" , "ve" , "vi" ]),
126+ "cat float" : column_or_1d ([0.0 , 1.0 , 2.0 ]),
127+ "cat bool" : column_or_1d ([False , True ]),
116128 },
117129 "expected_y_preds" : {
118130 "mono" : {
119131 "int" : y_pred ,
132+ "float" : y_pred .astype (float ),
133+ "bool" : self ._replace (y_bin_pred , {0 : True , 1 : False }),
120134 "int binary" : y_bin_pred ,
121135 "string" : self ._replace (y_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
122136 "string binary" : self ._replace (
@@ -128,9 +142,14 @@ def test_classifier_output_types(self):
128142 ),
129143 "cat int" : y_pred ,
130144 "cat string" : self ._replace (y_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
145+ "cat float" : self ._replace (y_pred , {0 : 0.0 , 1 : 1.0 , 2 : 2.0 }),
146+ "cat bool" : self ._replace (y_bin_pred , {0 : True , 1 : False }),
147+
131148 },
132149 "multi" : {
133150 "int" : y_mt_pred ,
151+ "float" : y_mt_pred .astype (float ),
152+ "bool" : self ._replace (y_mt_bin_pred , {0 : True , 1 : False }),
134153 "int binary" : y_mt_bin_pred ,
135154 "string" : self ._replace (y_mt_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
136155 "string binary" : self ._replace (
@@ -144,6 +163,9 @@ def test_classifier_output_types(self):
144163 ),
145164 "cat int" : y_mt_pred ,
146165 "cat string" : self ._replace (y_mt_pred , {0 : "se" , 1 : "vi" , 2 : "ve" }),
166+ "cat float" : self ._replace (y_mt_pred , {0 : 0.0 , 1 : 1.0 , 2 : 2.0 }),
167+ "cat bool" : self ._replace (y_mt_bin_pred , {0 : True , 1 : False }),
168+
147169 },
148170 },
149171 "Xs" : {
@@ -154,6 +176,8 @@ def test_classifier_output_types(self):
154176
155177 # Test for each fixture configuration
156178 for y_type , y in fixtures ["ys" ].items ():
179+ print ()
180+ print (y_type )
157181 y_type_check = fixtures ["y_type_check" ][y_type ]
158182 expected_classes = fixtures ["expected_classes" ][y_type ]
159183 for dataset_type , X in fixtures ["Xs" ].items ():
@@ -165,20 +189,32 @@ def test_classifier_output_types(self):
165189 # Train the classifier
166190 khc = KhiopsClassifier (n_trees = 0 )
167191 khc .fit (X , y )
192+ print ('unique y values:' , np .unique (y ))
168193
169194 # Check the expected classes
195+ print ()
196+ print ('model classes:' , khc .classes_ )
197+ print ('expected_classes:' , expected_classes )
170198 assert_array_equal (khc .classes_ , expected_classes )
171199
172200 # Check the return type of predict
173201 y_pred = khc .predict (X )
202+
174203 self .assertTrue (
175204 y_type_check (y_pred ),
176205 f"'{ y_type_check .__name__ } ' was False for "
177206 f"dtype '{ y_pred .dtype } '." ,
178207 )
179208
180209 # Check the predictions match
210+ print (dataset_type , y_type )
181211 expected_y_pred = fixtures ["expected_y_preds" ][dataset_type ][y_type ]
212+ diff_indices = np .where (y_pred != expected_y_pred )
213+ print ("Indices où les valeurs diffèrent :" , diff_indices )
214+
215+ for idx in zip (* diff_indices ):
216+ print (f"À l'indice { idx } , y_pred={ y_pred [idx ]} , expected_y_pred={ expected_y_pred [idx ]} " )
217+
182218 assert_array_equal (y_pred , expected_y_pred )
183219
184220 # Check the dimensions of predict_proba
0 commit comments