-
Notifications
You must be signed in to change notification settings - Fork 4
Move MLflow authentication helper into shared module #393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
25e1b34
7758397
f3272f0
ab1dfce
51d3082
722869c
dfd2189
688ef05
2b3afc8
a2e30f9
f9af282
bd3dc1b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,19 +1,28 @@ | ||
| import asyncio | ||
| from datetime import datetime | ||
| from pathlib import Path | ||
| import tempfile | ||
| import os | ||
| import yaml | ||
| import re | ||
| import sys | ||
| import tempfile | ||
| from datetime import datetime | ||
| from pathlib import Path | ||
|
|
||
| import mlflow | ||
| import urllib3 | ||
| import yaml | ||
| from sfapi_client import AsyncClient | ||
| from sfapi_client.compute import Machine | ||
| from trame.widgets import vuetify3 as vuetify | ||
| from utils import timer, load_config_dict, create_date_filter | ||
| from urllib3.exceptions import InsecureRequestWarning | ||
|
|
||
| # Add parent directory to path so we can import mlflow_utils from root | ||
| sys.path.insert(0, str(Path(__file__).parent.parent)) | ||
|
|
||
| from calibration_manager import build_inferred_calibration | ||
| from error_manager import add_error | ||
| from mlflow_utils import enable_amsc_x_api_key | ||
| from sfapi_manager import monitor_sfapi_job | ||
| from state_manager import state | ||
| from utils import timer, load_config_dict, create_date_filter | ||
|
|
||
| model_type_dict = { | ||
| "Gaussian Process": "GP", | ||
|
|
@@ -22,49 +31,6 @@ | |
| } | ||
|
|
||
|
|
||
| def enable_amsc_x_api_key(config_dict): | ||
| """ | ||
| MLflow authentication helper for the AmSC MLflow server. | ||
| Standard MLflow does not automatically inject custom headers like 'X-Api-Key'. | ||
| This patches the http_request function to ensure every request to the server | ||
| includes the AmSC API key. | ||
|
|
||
| See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details. | ||
| """ | ||
| import mlflow.utils.rest_utils as rest_utils | ||
|
|
||
| mlflow_cfg = config_dict.get("mlflow") or {} | ||
| api_key_env = mlflow_cfg.get("api_key_env") | ||
| if not api_key_env: | ||
| title = "Unable to enable AmSC X-Api-Key authentication" | ||
| msg = "MLFlow configuration is missing 'mlflow.api_key_env'" | ||
| add_error(title, msg) | ||
| print(msg) | ||
| return | ||
|
|
||
| api_key = os.environ.get(api_key_env) | ||
| if not api_key: | ||
| title = "Unable to enable AmSC X-Api-Key authentication" | ||
| msg = f"Environment variable '{api_key_env}' in 'mlflow.api_key_env' is not set" | ||
| add_error(title, msg) | ||
| print(msg) | ||
| return | ||
| _orig = rest_utils.http_request | ||
|
|
||
| def patched(host_creds, endpoint, method, *args, **kwargs): | ||
| if "headers" in kwargs and kwargs["headers"] is not None: | ||
| h = dict(kwargs["headers"]) | ||
| h["X-Api-Key"] = api_key | ||
| kwargs["headers"] = h | ||
| else: | ||
| h = dict(kwargs.get("extra_headers") or {}) | ||
| h["X-Api-Key"] = api_key | ||
| kwargs["extra_headers"] = h | ||
| return _orig(host_creds, endpoint, method, *args, **kwargs) | ||
|
|
||
| rest_utils.http_request = patched | ||
|
|
||
|
|
||
| class ModelManager: | ||
| def __init__(self, config_dict, model_type): | ||
| print("Initializing model manager...") | ||
|
|
@@ -79,13 +45,22 @@ def __init__(self, config_dict, model_type): | |
|
|
||
| mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"]) | ||
| # When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server | ||
| # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) | ||
| # (see https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) | ||
| if ( | ||
| config_dict["mlflow"]["tracking_uri"] | ||
| == "https://mlflow.american-science-cloud.org" | ||
| ): | ||
| enable_amsc_x_api_key(config_dict) | ||
|
|
||
| # Tell MLflow to ignore SSL certificate errors (common with self-signed internal servers) | ||
| os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true" | ||
| urllib3.disable_warnings(InsecureRequestWarning) | ||
|
Comment on lines
+53
to
+55
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this ignores the warning we used to see: |
||
| # Inject the X-Api-Key into the requests. | ||
| try: | ||
| enable_amsc_x_api_key(config_dict) | ||
| except ValueError as e: | ||
| title = "AmSC MLflow authentication setup failed" | ||
| msg = f"Error occurred when setting up AmSC MLflow authentication: {e}" | ||
| add_error(title, msg) | ||
| print(msg) | ||
|
Comment on lines
+53
to
+63
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need the same |
||
| experiment = config_dict["experiment"] | ||
| model_name = f"synapse-{experiment}_{model_type}" | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,52 @@ | ||||||||||||||||||||||||||||||||
| """MLflow utility functions for AmSC authentication and configuration.""" | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def enable_amsc_x_api_key(config_dict): | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| MLflow authentication helper for the AmSC MLflow server. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Standard MLflow does not automatically inject custom headers like 'X-Api-Key'. | ||||||||||||||||||||||||||||||||
| This patches the http_request function to ensure every request to the server | ||||||||||||||||||||||||||||||||
| includes the AmSC API key. | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||
| config_dict: Configuration dictionary containing mlflow settings | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||
| ValueError: If required mlflow configuration is missing or invalid | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| import mlflow.utils.rest_utils as rest_utils | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| mlflow_cfg = config_dict.get("mlflow") if config_dict is not None else None | ||||||||||||||||||||||||||||||||
| if not isinstance(mlflow_cfg, dict): | ||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||
| "Missing 'mlflow' configuration section required for AmSC MLFlow authentication." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| api_key_env = mlflow_cfg.get("api_key_env") | ||||||||||||||||||||||||||||||||
| if not api_key_env: | ||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||
| "Missing 'api_key_env' in 'mlflow' configuration. " | ||||||||||||||||||||||||||||||||
| "Please specify the name of the environment variable containing the AmSC API key." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| api_key = os.environ.get(api_key_env) | ||||||||||||||||||||||||||||||||
| if api_key is None: | ||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||
| f"The environment variable '{api_key_env}' specified in 'mlflow.api_key_env' " | ||||||||||||||||||||||||||||||||
| "is not set. Please export it with the AmSC MLFlow API key." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| _orig = rest_utils.http_request | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def patched(host_creds, endpoint, method, *args, **kwargs): | ||||||||||||||||||||||||||||||||
| h = dict(kwargs.get("headers") or kwargs.get("extra_headers") or {}) | ||||||||||||||||||||||||||||||||
| h["X-Api-Key"] = api_key | ||||||||||||||||||||||||||||||||
| kwargs["headers" if "headers" in kwargs else "extra_headers"] = h | ||||||||||||||||||||||||||||||||
| return _orig(host_creds, endpoint, method, *args, **kwargs) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+46
to
+50
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copying from https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch:
Suggested change
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Applied in 688ef05. |
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| rest_utils.http_request = patched | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reordered following the standard Python convention (enforced by tools like
isortand outlined in PEP 8), organizing imports into three groups, separated by blank lines, with each group sorted alphabetically:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Applied in dfd2189.