Skip to content

Commit e1a3f0b

Browse files
authored
Change base class from Binary to Multiclass Classification for Grad boost algos
1 parent 2a6f56d commit e1a3f0b

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

streamline/models/multiclass_classification/gradient_boosting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from abc import ABC
2-
from streamline.modeling.submodels import BinaryClassificationModel
2+
from streamline.modeling.submodels import MulticlassClassificationModel
33
from sklearn.ensemble import GradientBoostingClassifier as GB
44
from xgboost import XGBClassifier as XGB
55
from lightgbm import LGBMClassifier as LGB
66
from catboost import CatBoostClassifier as CGB
77

88

9-
class GBClassifier(BinaryClassificationModel, ABC):
9+
class GBClassifier(MulticlassClassificationModel, ABC):
1010
model_name = "Gradient Boosting"
1111
small_name = "GB"
1212
color = "cornflowerblue"
@@ -40,7 +40,7 @@ def objective(self, trial, params=None):
4040
return mean_cv_score
4141

4242

43-
class XGBClassifier(BinaryClassificationModel, ABC):
43+
class XGBClassifier(MulticlassClassificationModel, ABC):
4444
model_name = "Extreme Gradient Boosting"
4545
small_name = "XGB"
4646
color = "cyan"
@@ -93,7 +93,7 @@ def objective(self, trial, params=None):
9393
return mean_cv_score
9494

9595

96-
class LGBClassifier(BinaryClassificationModel, ABC):
96+
class LGBClassifier(MulticlassClassificationModel, ABC):
9797
model_name = "Light Gradient Boosting"
9898
small_name = "LGB"
9999
color = "pink"
@@ -143,7 +143,7 @@ def objective(self, trial, params=None):
143143
return mean_cv_score
144144

145145

146-
class CGBClassifier(BinaryClassificationModel, ABC):
146+
class CGBClassifier(MulticlassClassificationModel, ABC):
147147
model_name = "Category Gradient Boosting"
148148
small_name = "CGB"
149149
color = "magenta"

0 commit comments

Comments
 (0)