|
1 | 1 | from abc import ABC |
2 | | -from streamline.modeling.submodels import BinaryClassificationModel |
| 2 | +from streamline.modeling.submodels import MulticlassClassificationModel |
3 | 3 | from sklearn.ensemble import GradientBoostingClassifier as GB |
4 | 4 | from xgboost import XGBClassifier as XGB |
5 | 5 | from lightgbm import LGBMClassifier as LGB |
6 | 6 | from catboost import CatBoostClassifier as CGB |
7 | 7 |
|
8 | 8 |
|
9 | | -class GBClassifier(BinaryClassificationModel, ABC): |
| 9 | +class GBClassifier(MulticlassClassificationModel, ABC): |
10 | 10 | model_name = "Gradient Boosting" |
11 | 11 | small_name = "GB" |
12 | 12 | color = "cornflowerblue" |
@@ -40,7 +40,7 @@ def objective(self, trial, params=None): |
40 | 40 | return mean_cv_score |
41 | 41 |
|
42 | 42 |
|
43 | | -class XGBClassifier(BinaryClassificationModel, ABC): |
| 43 | +class XGBClassifier(MulticlassClassificationModel, ABC): |
44 | 44 | model_name = "Extreme Gradient Boosting" |
45 | 45 | small_name = "XGB" |
46 | 46 | color = "cyan" |
@@ -93,7 +93,7 @@ def objective(self, trial, params=None): |
93 | 93 | return mean_cv_score |
94 | 94 |
|
95 | 95 |
|
96 | | -class LGBClassifier(BinaryClassificationModel, ABC): |
| 96 | +class LGBClassifier(MulticlassClassificationModel, ABC): |
97 | 97 | model_name = "Light Gradient Boosting" |
98 | 98 | small_name = "LGB" |
99 | 99 | color = "pink" |
@@ -143,7 +143,7 @@ def objective(self, trial, params=None): |
143 | 143 | return mean_cv_score |
144 | 144 |
|
145 | 145 |
|
146 | | -class CGBClassifier(BinaryClassificationModel, ABC): |
| 146 | +class CGBClassifier(MulticlassClassificationModel, ABC): |
147 | 147 | model_name = "Category Gradient Boosting" |
148 | 148 | small_name = "CGB" |
149 | 149 | color = "magenta" |
|
0 commit comments