Skip to content

Commit d3ed5c1

Browse files
author
Matthew Shipton
committed
Working draft of adding AzureML support
1 parent e1e898a commit d3ed5c1

4 files changed

Lines changed: 30 additions & 10 deletions

File tree

mlflow_export_import/common/mlflow_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def set_experiment(mlflow_client, dbx_client, exp_name, tags=None):
4343
:return: Experiment ID
4444
"""
4545
from mlflow_export_import.common import utils
46-
if utils.importing_into_databricks():
46+
if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS:
4747
create_workspace_dir(dbx_client, os.path.dirname(exp_name))
4848
try:
4949
if not tags: tags = {}

mlflow_export_import/common/utils.py

Lines changed: 26 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,12 @@
11
import pandas as pd
22
from tabulate import tabulate
33
import mlflow
4+
from enum import Enum, auto
45

6+
class MLFlowImplementation(Enum):
    """Kind of MLflow tracking backend that an import is targeting.

    Values are assigned with ``auto()``; compare members by identity or
    equality rather than relying on the underlying integers.
    """

    # Tracking URI starts with "databricks".
    DATABRICKS = auto()
    # Tracking URI starts with "azureml".
    AZURE_ML = auto()
    # Any other tracking URI (treated as open-source MLflow).
    OSS = auto()
510

611
# Databricks tags that cannot or should not be set
712
_DATABRICKS_SKIP_TAGS = set([
@@ -11,15 +16,26 @@
1116
"mlflow.experiment.sourceType", "mlflow.experiment.sourceId"
1217
])
1318

19+
_AZURE_ML_SKIP_TAGS = set([
20+
"mlflow.user",
21+
"mlflow.source.git.commit"
22+
])
23+
1424

1525
def create_mlflow_tags_for_databricks_import(tags):
    """Drop tags that must not be set on the current import target.

    Despite the name, this handles every supported target, not only
    Databricks: the skip-set is chosen from the value returned by
    ``get_import_target_implementation()``.

    :param tags: Mapping of tag key to tag value.
    :return: For Databricks and Azure ML targets, a new dict without the
        keys in the corresponding skip-set. For OSS targets, the ``tags``
        object itself, unmodified and not copied.
    :raises Exception: If the detected target is none of the known
        ``MLFlowImplementation`` members.
    """
    environment = get_import_target_implementation()
    if environment == MLFlowImplementation.DATABRICKS:
        skip_tags = _DATABRICKS_SKIP_TAGS
    elif environment == MLFlowImplementation.AZURE_ML:
        skip_tags = _AZURE_ML_SKIP_TAGS
    elif environment == MLFlowImplementation.OSS:
        # Nothing to filter; hand back the caller's object as-is.
        return tags
    else:
        raise Exception("Unsupported environment")
    return {k: v for k, v in tags.items() if k not in skip_tags}
1934

2035

2136
def set_dst_user_id(tags, user_id, use_src_user_id):
22-
if importing_into_databricks():
37+
if get_import_target_implementation() in (MLFlowImplementation.DATABRICKS,
38+
MLFlowImplementation.AZURE_ML):
2339
return
2440
from mlflow.entities import RunTag
2541
from mlflow.utils.mlflow_tags import MLFLOW_USER
@@ -59,8 +75,12 @@ def nested_tags(dst_client, run_ids_mapping):
5975
dst_client.set_tag(dst_run_id, "mlflow.parentRunId", dst_parent_run_id)
6076

6177

62-
def get_import_target_implementation() -> MLFlowImplementation:
    """Classify the import target from the active MLflow tracking URI.

    :return: ``MLFlowImplementation.DATABRICKS`` if the tracking URI starts
        with ``"databricks"``, ``MLFlowImplementation.AZURE_ML`` if it starts
        with ``"azureml"``, otherwise ``MLFlowImplementation.OSS``.
    """
    # Read the URI once so both prefix checks see the same value.
    tracking_uri = mlflow.tracking.get_tracking_uri()
    if tracking_uri.startswith("databricks"):
        return MLFlowImplementation.DATABRICKS
    if tracking_uri.startswith("azureml"):
        return MLFlowImplementation.AZURE_ML
    return MLFlowImplementation.OSS


def importing_into_databricks():
    """Return True when the import target is Databricks.

    Kept as a thin wrapper so callers that still use the original helper
    keep working. NOTE(review): callers outside the files in this change
    are not visible here -- confirm before removing this function.
    """
    return get_import_target_implementation() == MLFlowImplementation.DATABRICKS
6484

6585

6686
def show_table(title, lst, columns):

mlflow_export_import/run/import_run.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def __init__(self,
5454
self.dbx_client = DatabricksHttpClient()
5555
self.import_source_tags = import_source_tags
5656
print(f"in_databricks: {self.in_databricks}")
57-
print(f"importing_into_databricks: {utils.importing_into_databricks()}")
57+
print(f"importing_into_environment: {utils.get_import_target_implementation().name}")
5858

5959

6060
def import_run(self, exp_name, input_dir, dst_notebook_dir=None):
@@ -93,7 +93,7 @@ def _import_run(self, dst_exp_name, input_dir, dst_notebook_dir):
9393
import traceback
9494
traceback.print_exc()
9595
raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed")
96-
if utils.importing_into_databricks() and dst_notebook_dir:
96+
if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS and dst_notebook_dir:
9797
ndir = os.path.join(dst_notebook_dir, run_id) if self.dst_notebook_dir_add_run_id else dst_notebook_dir
9898
self._upload_databricks_notebook(input_dir, src_run_dct, ndir)
9999

tests/compare_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ def compare_versions(mlflow_client_src, mlflow_client_dst, vr_src, vr_dst, outpu
9393
assert vr_src.status_message == vr_dst.status_message
9494
if mlflow_client_src != mlflow_client_src:
9595
assert vr_src.name == vr_dst.name
96-
if not utils.importing_into_databricks():
96+
if utils.get_import_target_implementation() != utils.MLFlowImplementation.DATABRICKS:
9797
assert vr_src.user_id == vr_dst.user_id
9898

9999
tags_dst = { k:v for k,v in vr_dst.tags.items() if not k.startswith(ExportTags.PREFIX_ROOT) }

0 commit comments

Comments
 (0)