diff --git a/tutorials/e2e-ds-experience/e2e-ml-workflow.ipynb b/tutorials/e2e-ds-experience/e2e-ml-workflow.ipynb index f82604832a..f83dc6452e 100644 --- a/tutorials/e2e-ds-experience/e2e-ml-workflow.ipynb +++ b/tutorials/e2e-ds-experience/e2e-ml-workflow.ipynb @@ -344,8 +344,8 @@ " description=\"Custom environment for Credit Card Defaults pipeline\",\n", " tags={\"scikit-learn\": \"0.24.2\"},\n", " conda_file=os.path.join(dependencies_dir, \"conda.yaml\"),\n", - " image=\"mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest\",\n", - " version=\"0.1.1\",\n", + " image=\"mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu22.04:latest\",\n", + " version=\"0.1.2\",\n", ")\n", "pipeline_job_env = ml_client.environments.create_or_update(pipeline_job_env)\n", "\n", @@ -610,6 +610,7 @@ "import os\n", "import pandas as pd\n", "import mlflow\n", + "from mlflow.exceptions import MlflowException\n", "\n", "\n", "def select_first_file(path):\n", @@ -627,7 +628,7 @@ "mlflow.start_run()\n", "\n", "# enable autologging\n", - "mlflow.sklearn.autolog()\n", + "mlflow.sklearn.autolog(log_models=False)\n", "\n", "os.makedirs(\"./outputs\", exist_ok=True)\n", "\n", @@ -674,26 +675,31 @@ "\n", " print(classification_report(y_test, y_pred))\n", "\n", - " # Registering the model to the workspace\n", - " print(\"Registering the model via MLFlow\")\n", - " mlflow.sklearn.log_model(\n", - " sk_model=clf,\n", - " registered_model_name=args.registered_model_name,\n", - " artifact_path=args.registered_model_name,\n", - " )\n", - "\n", " # Saving the model to a file\n", + " model_path = os.path.join(args.model, \"trained_model\")\n", " mlflow.sklearn.save_model(\n", " sk_model=clf,\n", - " path=os.path.join(args.model, \"trained_model\"),\n", + " path=model_path,\n", " )\n", "\n", + " # Registering the model to the workspace\n", + " print(\"Registering the model via MLFlow\")\n", + " mlflow.log_artifacts(model_path, artifact_path=args.registered_model_name)\n", + " run_id = mlflow.active_run().info.run_id\n", + " model_uri = f\"runs:/{run_id}/{args.registered_model_name}\"\n", + " client = mlflow.tracking.MlflowClient()\n", + " try:\n", + " client.create_registered_model(args.registered_model_name)\n", + " except mlflow.exceptions.MlflowException:\n", + " pass # model already exists\n", + " client.create_model_version(args.registered_model_name, model_uri, run_id)\n", + "\n", " # Stop Logging\n", " mlflow.end_run()\n", "\n", "\n", "if __name__ == \"__main__\":\n", - " main()" + " main()\n" ] }, {