|
10 | 10 | import random |
11 | 11 | import sys |
12 | 12 | import time |
| 13 | +from concurrent.futures import as_completed |
13 | 14 | from functools import partial |
14 | 15 | from typing import Callable, List, Optional, Union |
15 | 16 |
|
@@ -187,7 +188,8 @@ def custom_metric( |
187 | 188 | mem_thres: A float of the memory size constraint in bytes. |
188 | 189 | pred_time_limit: A float of the prediction latency constraint in seconds. |
189 | 190 | It refers to the average prediction time per row in validation data. |
190 | | - train_time_limit: A float of the training time constraint in seconds. |
| 191 | + train_time_limit: None or a float of the training time constraint in seconds for each trial. |
| 192 | + Only valid for sequential search. |
191 | 193 | verbose: int, default=3 | Controls the verbosity, higher means more |
192 | 194 | messages. |
193 | 195 | retrain_full: bool or str, default=True | whether to retrain the |
@@ -1334,7 +1336,8 @@ def custom_metric( |
1334 | 1336 | mem_thres: A float of the memory size constraint in bytes. |
1335 | 1337 | pred_time_limit: A float of the prediction latency constraint in seconds. |
1336 | 1338 | It refers to the average prediction time per row in validation data. |
1337 | | - train_time_limit: None or a float of the training time constraint in seconds. |
| 1339 | + train_time_limit: None or a float of the training time constraint in seconds for each trial. |
| 1340 | + Only valid for sequential search. |
1338 | 1341 | X_val: None or a numpy array or a pandas dataframe of validation data. |
1339 | 1342 | y_val: None or a numpy array or a pandas series of validation labels. |
1340 | 1343 | sample_weight_val: None or a numpy array of the sample weight of |
@@ -1625,6 +1628,13 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds): |
1625 | 1628 | _ch.setFormatter(logger_formatter) |
1626 | 1629 | logger.addHandler(_ch) |
1627 | 1630 |
|
| 1631 | + if model_history: |
| 1632 | + logger.warning( |
| 1633 | + "With `model_history` set to `True` by default, all intermediate models are retained in memory, " |
| 1634 | + "which may significantly increase memory usage and slow down training. " |
| 1635 | + "Consider setting `model_history=False` to optimize memory and accelerate the training process." |
| 1636 | + ) |
| 1637 | + |
1628 | 1638 | if not use_ray and not use_spark and n_concurrent_trials > 1: |
1629 | 1639 | if ray_available: |
1630 | 1640 | logger.warning( |
@@ -2717,16 +2727,42 @@ def _search(self): |
2717 | 2727 | ): |
2718 | 2728 | if mlflow.active_run() is None: |
2719 | 2729 | mlflow.start_run(run_id=self.mlflow_integration.parent_run_id) |
2720 | | - self.mlflow_integration.log_model( |
2721 | | - self._trained_estimator.model, |
2722 | | - self.best_estimator, |
2723 | | - signature=self.estimator_signature, |
2724 | | - ) |
2725 | | - self.mlflow_integration.pickle_and_log_automl_artifacts( |
2726 | | - self, self.model, self.best_estimator, signature=self.pipeline_signature |
2727 | | - ) |
| 2730 | + if self.best_estimator.endswith("_spark"): |
| 2731 | + self.mlflow_integration.log_model( |
| 2732 | + self._trained_estimator.model, |
| 2733 | + self.best_estimator, |
| 2734 | + signature=self.estimator_signature, |
| 2735 | + run_id=self.mlflow_integration.parent_run_id, |
| 2736 | + ) |
| 2737 | + else: |
| 2738 | + self.mlflow_integration.pickle_and_log_automl_artifacts( |
| 2739 | + self, |
| 2740 | + self.model, |
| 2741 | + self.best_estimator, |
| 2742 | + signature=self.pipeline_signature, |
| 2743 | + run_id=self.mlflow_integration.parent_run_id, |
| 2744 | + ) |
2728 | 2745 | else: |
2729 | | - logger.info("not retraining because the time budget is too small.") |
| 2746 | + logger.warning("not retraining because the time budget is too small.") |
| 2747 | + if self.mlflow_integration is not None: |
| 2748 | + logger.debug("Collecting results from submitted record_state tasks") |
| 2749 | + t1 = time.perf_counter() |
| 2750 | + for future in as_completed(self.mlflow_integration.futures): |
| 2751 | + _task = self.mlflow_integration.futures[future] |
| 2752 | + try: |
| 2753 | + result = future.result() |
| 2754 | + logger.debug(f"Result for record_state task {_task}: {result}") |
| 2755 | + except Exception as e: |
| 2756 | + logger.warning(f"Exception for record_state task {_task}: {e}") |
| 2757 | + for future in as_completed(self.mlflow_integration.futures_log_model): |
| 2758 | + _task = self.mlflow_integration.futures_log_model[future] |
| 2759 | + try: |
| 2760 | + result = future.result() |
| 2761 | + logger.debug(f"Result for log_model task {_task}: {result}") |
| 2762 | + except Exception as e: |
| 2763 | + logger.warning(f"Exception for log_model task {_task}: {e}") |
| 2764 | + t2 = time.perf_counter() |
| 2765 | + logger.debug(f"Collecting results from tasks submitted to executors costs {t2-t1} seconds.") |
2730 | 2766 |
|
2731 | 2767 | def __del__(self): |
2732 | 2768 | if ( |
|
0 commit comments