|
69 | 69 | _RMSLE_METRIC_ID = "rmsle" |
70 | 70 | _MSE_METRIC_ID = "mse" |
71 | 71 |
|
72 | | -try: # Only used by local tuning loop |
73 | | - import sklearn.metrics |
74 | | - from sklearn.model_selection import train_test_split |
75 | | - |
76 | | - _SUPPORTED_METRIC_FUNCTIONS = { |
77 | | - _ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score, |
78 | | - _F1_METRIC_ID: sklearn.metrics.f1_score, |
79 | | - _PRECISION_METRIC_ID: sklearn.metrics.precision_score, |
80 | | - _RECALL_METRIC_ID: sklearn.metrics.recall_score, |
81 | | - _ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score, |
82 | | - _MAE_METRIC_ID: sklearn.metrics.mean_absolute_error, |
83 | | - _MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error, |
84 | | - _R2_METRIC_ID: sklearn.metrics.r2_score, |
85 | | - _RMSE_METRIC_ID: functools.partial( |
86 | | - sklearn.metrics.mean_squared_error, squared=False |
87 | | - ), |
88 | | - _RMSLE_METRIC_ID: functools.partial( |
89 | | - sklearn.metrics.mean_squared_log_error, squared=False |
90 | | - ), |
91 | | - _MSE_METRIC_ID: sklearn.metrics.mean_squared_error, |
92 | | - } |
93 | | - _SUPPORTED_METRIC_IDS = frozenset(_SUPPORTED_METRIC_FUNCTIONS.keys()).union( |
94 | | - frozenset([_CUSTOM_METRIC_ID]) |
95 | | - ) |
96 | | - _SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset( |
97 | | - [ |
98 | | - _ROC_AUC_METRIC_ID, |
99 | | - _F1_METRIC_ID, |
100 | | - _PRECISION_METRIC_ID, |
101 | | - _RECALL_METRIC_ID, |
102 | | - _ACCURACY_METRIC_ID, |
103 | | - ] |
104 | | - ) |
| 72 | +_SUPPORTED_METRIC_IDS = frozenset( |
| 73 | + [ |
| 74 | + _CUSTOM_METRIC_ID, |
| 75 | + _ROC_AUC_METRIC_ID, |
| 76 | + _F1_METRIC_ID, |
| 77 | + _PRECISION_METRIC_ID, |
| 78 | + _RECALL_METRIC_ID, |
| 79 | + _ACCURACY_METRIC_ID, |
| 80 | + _MAE_METRIC_ID, |
| 81 | + _MAPE_METRIC_ID, |
| 82 | + _R2_METRIC_ID, |
| 83 | + _RMSE_METRIC_ID, |
| 84 | + _RMSLE_METRIC_ID, |
| 85 | + _MSE_METRIC_ID, |
| 86 | + ] |
| 87 | +) |
| 88 | +_SUPPORTED_CLASSIFICATION_METRIC_IDS = frozenset( |
| 89 | + [ |
| 90 | + _ROC_AUC_METRIC_ID, |
| 91 | + _F1_METRIC_ID, |
| 92 | + _PRECISION_METRIC_ID, |
| 93 | + _RECALL_METRIC_ID, |
| 94 | + _ACCURACY_METRIC_ID, |
| 95 | + ] |
| 96 | +) |
105 | 97 |
|
106 | | -except ImportError: |
107 | | - pass |
108 | 98 |
|
109 | 99 | # Vizier client constnats |
110 | 100 | _STUDY_NAME_PREFIX = "vizier_hyperparameter_tuner_study" |
@@ -366,6 +356,13 @@ def _create_train_and_test_splits( |
366 | 356 | "test_fraction must be greater than 0 and less than 1 but was " |
367 | 357 | f"{test_fraction}." |
368 | 358 | ) |
| 359 | + try: |
| 360 | + from sklearn.model_selection import train_test_split |
| 361 | + except ImportError: |
| 362 | + raise ImportError( |
| 363 | + "scikit-learn must be installed to create train and test splits. " |
| 364 | + "Please call `pip install scikit-learn>=0.24`" |
| 365 | + ) from None |
369 | 366 |
|
370 | 367 | if isinstance(y, str): |
371 | 368 | try: |
@@ -414,6 +411,32 @@ def score(x_test, y_test): |
414 | 411 | Returns: |
415 | 412 | A tuple containing the model and the corresponding metric value. |
416 | 413 | """ |
| 414 | + try: # Only used by local tuning loop |
| 415 | + import sklearn.metrics |
| 416 | + |
| 417 | + _SUPPORTED_METRIC_FUNCTIONS = { |
| 418 | + _ROC_AUC_METRIC_ID: sklearn.metrics.roc_auc_score, |
| 419 | + _F1_METRIC_ID: sklearn.metrics.f1_score, |
| 420 | + _PRECISION_METRIC_ID: sklearn.metrics.precision_score, |
| 421 | + _RECALL_METRIC_ID: sklearn.metrics.recall_score, |
| 422 | + _ACCURACY_METRIC_ID: sklearn.metrics.accuracy_score, |
| 423 | + _MAE_METRIC_ID: sklearn.metrics.mean_absolute_error, |
| 424 | + _MAPE_METRIC_ID: sklearn.metrics.mean_absolute_percentage_error, |
| 425 | + _R2_METRIC_ID: sklearn.metrics.r2_score, |
| 426 | + _RMSE_METRIC_ID: functools.partial( |
| 427 | + sklearn.metrics.mean_squared_error, squared=False |
| 428 | + ), |
| 429 | + _RMSLE_METRIC_ID: functools.partial( |
| 430 | + sklearn.metrics.mean_squared_log_error, squared=False |
| 431 | + ), |
| 432 | + _MSE_METRIC_ID: sklearn.metrics.mean_squared_error, |
| 433 | + } |
| 434 | + except Exception as e: |
| 435 | + raise ImportError( |
| 436 | + "scikit-learn must be installed to evaluate models. " |
| 437 | + "Please call `pip install scikit-learn>=0.24`" |
| 438 | + ) from e |
| 439 | + |
417 | 440 | if self.metric_id == _CUSTOM_METRIC_ID: |
418 | 441 | metric_value = model.score(x_test, y_test) |
419 | 442 | else: |
|
0 commit comments