diff --git a/flaml/fabric/mlflow.py b/flaml/fabric/mlflow.py index 4cfb0e2cd2..442c3709b8 100644 --- a/flaml/fabric/mlflow.py +++ b/flaml/fabric/mlflow.py @@ -492,7 +492,11 @@ def log_automl(self, automl): mlflow.log_metric("best_validation_loss", automl._state.best_loss) mlflow.log_metric("best_iteration", automl._best_iteration) mlflow.log_metric("num_child_runs", len(self.infos)) - if automl._trained_estimator is not None and not self.has_model: + if ( + automl._trained_estimator is not None + and not self.has_model + and automl._trained_estimator._model is not None + ): self.log_model( automl._trained_estimator._model, automl.best_estimator, signature=automl.estimator_signature ) @@ -521,7 +525,11 @@ def log_automl(self, automl): logger.info(f"logging best model {automl.best_estimator}") self.copy_mlflow_run(best_mlflow_run_id, self.parent_run_id) self.has_summary = True - if automl._trained_estimator is not None and not self.has_model: + if ( + automl._trained_estimator is not None + and not self.has_model + and automl._trained_estimator._model is not None + ): self.log_model( automl._trained_estimator._model, automl.best_estimator, diff --git a/setup.py b/setup.py index 5783e99de0..e30db68a20 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ "psutil==5.8.0", "dataclasses", "transformers[torch]==4.26", - "datasets", + "datasets<=3.5.0", "nltk<=3.8.1", # 3.8.2 doesn't work with mlflow "rouge_score", "hcrystalball==0.1.10",