|
1 | 1 | import pandas as pd |
2 | 2 | from tabulate import tabulate |
3 | 3 | import mlflow |
| 4 | +from enum import Enum, auto |
4 | 5 |
|
class MLFlowImplementation(Enum):
    """Kinds of MLflow tracking backend an import can target.

    The member is selected by get_import_target_implementation() from the
    prefix of the current tracking URI.
    """

    # Tracking URI starts with "databricks".
    DATABRICKS = auto()
    # Tracking URI starts with "azureml".
    AZURE_ML = auto()
    # Any other tracking URI.
    OSS = auto()
5 | 10 |
|
6 | 11 | # Databricks tags that cannot or should not be set |
7 | 12 | _DATABRICKS_SKIP_TAGS = set([ |
|
11 | 16 | "mlflow.experiment.sourceType", "mlflow.experiment.sourceId" |
12 | 17 | ]) |
13 | 18 |
|
# Tag keys filtered out of the tag dict when importing into Azure ML.
_AZURE_ML_SKIP_TAGS = {
    "mlflow.user",
    "mlflow.source.git.commit",
}
| 23 | + |
14 | 24 |
|
def create_mlflow_tags_for_databricks_import(tags):
    """Return the tags that may be set on the current import target.

    Despite the name, this handles every known target, not only Databricks.

    Args:
        tags: Mapping of tag key to tag value (anything exposing ``items()``).

    Returns:
        For Databricks and Azure ML targets, a new dict without the keys
        listed in the matching skip set. For OSS targets, the ``tags``
        object itself, unmodified.

    Raises:
        Exception: If the detected target is none of the known
            MLFlowImplementation members.
    """
    target = get_import_target_implementation()

    # OSS accepts every tag, so the input is handed back untouched.
    if target == MLFlowImplementation.OSS:
        return tags

    if target == MLFlowImplementation.DATABRICKS:
        excluded = _DATABRICKS_SKIP_TAGS
    elif target == MLFlowImplementation.AZURE_ML:
        excluded = _AZURE_ML_SKIP_TAGS
    else:
        raise Exception("Unsupported environment")

    return {name: value for name, value in tags.items() if name not in excluded}
19 | 34 |
|
20 | 35 |
|
21 | 36 | def set_dst_user_id(tags, user_id, use_src_user_id): |
22 | | - if importing_into_databricks(): |
| 37 | + if get_import_target_implementation() in (MLFlowImplementation.DATABRICKS, |
| 38 | + MLFlowImplementation.AZURE_ML): |
23 | 39 | return |
24 | 40 | from mlflow.entities import RunTag |
25 | 41 | from mlflow.utils.mlflow_tags import MLFLOW_USER |
@@ -59,8 +75,12 @@ def nested_tags(dst_client, run_ids_mapping): |
59 | 75 | dst_client.set_tag(dst_run_id, "mlflow.parentRunId", dst_parent_run_id) |
60 | 76 |
|
61 | 77 |
|
62 | | -def importing_into_databricks(): |
63 | | - return mlflow.tracking.get_tracking_uri().startswith("databricks") |
def get_import_target_implementation() -> MLFlowImplementation:
    """Classify the MLflow backend addressed by the current tracking URI.

    Returns:
        MLFlowImplementation.DATABRICKS if the tracking URI starts with
        "databricks", MLFlowImplementation.AZURE_ML if it starts with
        "azureml", and MLFlowImplementation.OSS for any other URI.
    """
    # Read the URI once so both prefix checks are made against the same value.
    tracking_uri = mlflow.tracking.get_tracking_uri()
    if tracking_uri.startswith("databricks"):
        return MLFlowImplementation.DATABRICKS
    if tracking_uri.startswith("azureml"):
        return MLFlowImplementation.AZURE_ML
    return MLFlowImplementation.OSS
64 | 84 |
|
65 | 85 |
|
66 | 86 | def show_table(title, lst, columns): |
|
0 commit comments