@@ -153,6 +153,7 @@ def _check_categorical_target_type(ds):
153153 or pd .api .types .is_string_dtype (ds .target_column .dtype )
154154 or pd .api .types .is_integer_dtype (ds .target_column .dtype )
155155 or pd .api .types .is_float_dtype (ds .target_column .dtype )
156+ or pd .api .types .is_bool_dtype (ds .target_column .dtype )
156157 ):
157158 raise ValueError (
158159 f"'y' has invalid type '{ ds .target_column_type } '. "
@@ -2093,6 +2094,24 @@ def _is_real_target_dtype_integer(self):
20932094 )
20942095 )
20952096
2097+ def _is_real_target_dtype_float (self ):
2098+ return self ._original_target_dtype is not None and (
2099+ pd .api .types .is_float_dtype (self ._original_target_dtype )
2100+ or (
2101+ isinstance (self ._original_target_dtype , pd .CategoricalDtype )
2102+ and pd .api .types .is_float_dtype (self ._original_target_dtype .categories )
2103+ )
2104+ )
2105+
2106+ def _is_real_target_dtype_bool (self ):
2107+ return self ._original_target_dtype is not None and (
2108+ pd .api .types .is_bool_dtype (self ._original_target_dtype )
2109+ or (
2110+ isinstance (self ._original_target_dtype , pd .CategoricalDtype )
2111+ and pd .api .types .is_bool_dtype (self ._original_target_dtype .categories )
2112+ )
2113+ )
2114+
20962115 def _sorted_prob_variable_names (self ):
20972116 """Returns the model probability variable names in the order of self.classes_"""
20982117 assert self .is_fitted_ , "Model not fit yet"
@@ -2195,11 +2214,15 @@ def _fit_training_post_process(self, ds):
21952214 for key in variable .meta_data .keys :
21962215 if key .startswith ("TargetProb" ):
21972216 self .classes_ .append (variable .meta_data .get_value (key ))
2198- if ds .is_in_memory and self ._is_real_target_dtype_integer ():
2199- self .classes_ = [int (class_value ) for class_value in self .classes_ ]
2217+ if ds .is_in_memory :
2218+ if self ._is_real_target_dtype_integer ():
2219+ self .classes_ = [int (class_value ) for class_value in self .classes_ ]
2220+ elif self ._is_real_target_dtype_float ():
2221+ self .classes_ = [float (class_value ) for class_value in self .classes_ ]
2222+ elif self ._is_real_target_dtype_bool ():
2223+ self .classes_ = [class_value == "True" for class_value in self .classes_ ]
22002224 self .classes_ .sort ()
22012225 self .classes_ = column_or_1d (self .classes_ )
2202-
22032226 # Count number of classes
22042227 self .n_classes_ = len (self .classes_ )
22052228
@@ -2259,13 +2282,11 @@ def predict(self, X):
22592282 """
22602283 # Call the parent's method
22612284 y_pred = super ().predict (X )
2262-
22632285 # Adjust the data type according to the original target type
22642286 # Note: String is coerced explicitly because astype does not work as expected
22652287 if isinstance (y_pred , pd .DataFrame ):
22662288 # Transform to numpy.ndarray
22672289 y_pred = y_pred .to_numpy (copy = False ).ravel ()
2268-
22692290 # If integer and string just transform
22702291 if pd .api .types .is_integer_dtype (self ._original_target_dtype ):
22712292 y_pred = y_pred .astype (self ._original_target_dtype )
@@ -2275,6 +2296,10 @@ def predict(self, X):
22752296 self ._original_target_dtype
22762297 ):
22772298 y_pred = y_pred .astype (str , copy = False )
2299+ elif pd .api .types .is_float_dtype (self ._original_target_dtype ):
2300+ y_pred = y_pred .astype (float , copy = False )
2301+ elif pd .api .types .is_bool_dtype (self ._original_target_dtype ):
2302+ y_pred = y_pred .astype (bool , copy = False )
22782303 # If category first coerce the type to the categories' type
22792304 else :
22802305 assert isinstance (self ._original_target_dtype , pd .CategoricalDtype ), (
0 commit comments