From 25e1b347d3408ee4f806dacf51dda570117165aa Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Mon, 2 Mar 2026 14:53:00 -0800 Subject: [PATCH 1/7] Move enable_amsc_x_api_key to shared utils file --- dashboard/model_manager.py | 57 ++++++++++---------------------------- ml/train_model.py | 54 ++++-------------------------------- mlflow_utils.py | 57 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 91 deletions(-) create mode 100644 mlflow_utils.py diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 9843a4b4..8887e02e 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -3,6 +3,7 @@ from pathlib import Path import tempfile import os +import sys import yaml import re import urllib3 @@ -11,10 +12,15 @@ from sfapi_client import AsyncClient from sfapi_client.compute import Machine from trame.widgets import vuetify3 as vuetify + +# Add parent directory to path so we can import mlflow_utils from root +sys.path.insert(0, str(Path(__file__).parent.parent)) + from utils import timer, load_config_dict, create_date_filter from error_manager import add_error from sfapi_manager import monitor_sfapi_job from state_manager import state +from mlflow_utils import enable_amsc_x_api_key model_type_tag_dict = { "Gaussian Process": "GP", @@ -23,47 +29,6 @@ } -def enable_amsc_x_api_key(config_dict): - """ - MLflow authentication helper for the AmSC MLflow server. - Standard MLflow does not automatically inject custom headers like 'X-Api-Key'. - This patches the http_request function to ensure every request to the server - includes the AmSC API key. - - See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details. - """ - import mlflow.utils.rest_utils as rest_utils - - mlflow_cfg = config_dict.get("mlflow") or {} - api_key_env = mlflow_cfg.get("api_key_env") - if not api_key_env: - add_error( - "MLFlow configuration is missing 'mlflow.api_key_env'; cannot enable AmSC X-Api-Key authentication." - ) - return - - api_key = os.environ.get(api_key_env) - if not api_key: - add_error( - f"Environment variable '{api_key_env}' (configured in 'mlflow.api_key_env') is not set; cannot enable AmSC X-Api-Key authentication." - ) - return - _orig = rest_utils.http_request - - def patched(host_creds, endpoint, method, *args, **kwargs): - if "headers" in kwargs and kwargs["headers"] is not None: - h = dict(kwargs["headers"]) - h["X-Api-Key"] = api_key - kwargs["headers"] = h - else: - h = dict(kwargs.get("extra_headers") or {}) - h["X-Api-Key"] = api_key - kwargs["extra_headers"] = h - return _orig(host_creds, endpoint, method, *args, **kwargs) - - rest_utils.http_request = patched - - class ModelManager: def __init__(self): print("Initializing model manager...") @@ -93,7 +58,7 @@ def __init__(self): mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"]) # When using the AmSC MLflow: - # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) + # (see https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) if ( config_dict["mlflow"]["tracking_uri"] == "https://mlflow.american-science-cloud.org" @@ -102,7 +67,13 @@ def __init__(self): os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true" urllib3.disable_warnings(InsecureRequestWarning) # - inject the X-Api-Key into the requests. - enable_amsc_x_api_key(config_dict) + try: + enable_amsc_x_api_key(config_dict) + except KeyError as e: + title = "AmSC MLflow authentication setup failed" + msg = f"Error occurred when setting up AmSC MLflow authentication: {e}" + add_error(title, msg) + print(msg) model_name = f"{state.experiment}_{model_type_tag}" try: diff --git a/ml/train_model.py b/ml/train_model.py index e9a4a343..82e45ea1 100644 --- a/ml/train_model.py +++ b/ml/train_model.py @@ -28,9 +28,13 @@ import sys import pandas as pd from gpytorch.mlls import ExactMarginalLogLikelihood +from pathlib import Path -sys.path.append(".") +# Add parent directories to path for module imports +sys.path.insert(0, str(Path(__file__).parent.parent)) # For mlflow_utils in root +sys.path.insert(0, str(Path(__file__).parent)) # For Neural_Net_Classes in ml/ from Neural_Net_Classes import CombinedNN +from mlflow_utils import enable_amsc_x_api_key # measure the time it took to import everything import_end_time = time.time() @@ -374,52 +378,6 @@ def train_gp( ) -def enable_amsc_x_api_key(config_dict): - """ - MLflow authentication helper for the AmSC MLflow server. - Standard MLflow does not automatically inject custom headers like 'X-Api-Key'. - This patches the http_request function to ensure every request to the server - includes the AmSC API key. - - See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details. - """ - import mlflow.utils.rest_utils as rest_utils - - mlflow_cfg = config_dict.get("mlflow") if config_dict is not None else None - if not isinstance(mlflow_cfg, dict): - raise KeyError( - "Missing 'mlflow' configuration section required for AmSC MLFlow authentication." - ) - - api_key_env = mlflow_cfg.get("api_key_env") - if not api_key_env: - raise KeyError( - "Missing 'api_key_env' in 'mlflow' configuration. " - "Please specify the name of the environment variable containing the AmSC API key." - ) - - api_key = os.getenv(api_key_env) - if api_key is None: - raise KeyError( - f"The environment variable '{api_key_env}' specified in 'mlflow.api_key_env' " - "is not set. Please export it with the AmSC MLFlow API key." - ) - _orig = rest_utils.http_request - - def patched(host_creds, endpoint, method, *args, **kwargs): - if "headers" in kwargs and kwargs["headers"] is not None: - h = dict(kwargs["headers"]) - h["X-Api-Key"] = api_key - kwargs["headers"] = h - else: - h = dict(kwargs.get("extra_headers") or {}) - h["X-Api-Key"] = api_key - kwargs["extra_headers"] = h - return _orig(host_creds, endpoint, method, *args, **kwargs) - - rest_utils.http_request = patched - - def register_model_to_mlflow(model, model_type, experiment, config_dict): """Register the trained model to MLflow (tracking URI from config).""" tracking_uri = config_dict["mlflow"]["tracking_uri"] @@ -471,7 +429,7 @@ def register_model_to_mlflow(model, model_type, experiment, config_dict): df_sim = pd.DataFrame(db[experiment].find({"experiment_flag": 0})) # When using the AmSC MLflow: - # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) + # (see https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) if ( "mlflow" in config_dict and config_dict["mlflow"].get("tracking_uri") diff --git a/mlflow_utils.py b/mlflow_utils.py new file mode 100644 index 00000000..2c43ef30 --- /dev/null +++ b/mlflow_utils.py @@ -0,0 +1,57 @@ +"""MLflow utility functions for AmSC authentication and configuration.""" + +import os + + +def enable_amsc_x_api_key(config_dict): + """ + MLflow authentication helper for the AmSC MLflow server. + + Standard MLflow does not automatically inject custom headers like 'X-Api-Key'. + This patches the http_request function to ensure every request to the server + includes the AmSC API key. + + Args: + config_dict: Configuration dictionary containing mlflow settings + + Raises: + KeyError: If required mlflow configuration is missing or invalid + + See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details. + """ + import mlflow.utils.rest_utils as rest_utils + + mlflow_cfg = config_dict.get("mlflow") if config_dict is not None else None + if not isinstance(mlflow_cfg, dict): + raise KeyError( + "Missing 'mlflow' configuration section required for AmSC MLFlow authentication." + ) + + api_key_env = mlflow_cfg.get("api_key_env") + if not api_key_env: + raise KeyError( + "Missing 'api_key_env' in 'mlflow' configuration. " + "Please specify the name of the environment variable containing the AmSC API key." + ) + + api_key = os.getenv(api_key_env) + if api_key is None: + raise KeyError( + f"The environment variable '{api_key_env}' specified in 'mlflow.api_key_env' " + "is not set. Please export it with the AmSC MLFlow API key." + ) + + _orig = rest_utils.http_request + + def patched(host_creds, endpoint, method, *args, **kwargs): + if "headers" in kwargs and kwargs["headers"] is not None: + h = dict(kwargs["headers"]) + h["X-Api-Key"] = api_key + kwargs["headers"] = h + else: + h = dict(kwargs.get("extra_headers") or {}) + h["X-Api-Key"] = api_key + kwargs["extra_headers"] = h + return _orig(host_creds, endpoint, method, *args, **kwargs) + + rest_utils.http_request = patched From 77583977f1bbbb30fb9191f4b46e1ecad10063e7 Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Mon, 2 Mar 2026 14:56:48 -0800 Subject: [PATCH 2/7] Update Dockerfiles --- dashboard.Dockerfile | 1 + ml.Dockerfile | 1 + 2 files changed, 2 insertions(+) diff --git a/dashboard.Dockerfile b/dashboard.Dockerfile index 52bcb1b3..28094ae2 100644 --- a/dashboard.Dockerfile +++ b/dashboard.Dockerfile @@ -24,6 +24,7 @@ ENTRYPOINT ["/entrypoint.sh"] COPY dashboard /app/dashboard COPY experiments /app/experiments COPY ml/training_pm.sbatch /app/ml/training_pm.sbatch +COPY mlflow_utils.py /app/mlflow_utils.py # Make port 8080 available to the world outside this container EXPOSE 8080 diff --git a/ml.Dockerfile b/ml.Dockerfile index 09392d7d..7520fe04 100644 --- a/ml.Dockerfile +++ b/ml.Dockerfile @@ -28,6 +28,7 @@ ENTRYPOINT ["/app/ml/entrypoint.sh"] COPY ml/train_model.py /app/ml/train_model.py COPY ml/Neural_Net_Classes.py /app/ml/Neural_Net_Classes.py COPY experiments /app/experiments +COPY mlflow_utils.py /app/mlflow_utils.py # Run train_model.py when the container launches CMD ["python", "-u", "train_model.py"] From f3272f0289afc14a698cc4f65bfad76ddc7897fb Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Mon, 2 Mar 2026 15:13:12 -0800 Subject: [PATCH 3/7] Raise ValueError instead of KeyError where appropriate --- dashboard/model_manager.py | 2 +- mlflow_utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 8887e02e..143a6f18 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -69,7 +69,7 @@ def __init__(self): # - inject the X-Api-Key into the requests. try: enable_amsc_x_api_key(config_dict) - except KeyError as e: + except ValueError as e: title = "AmSC MLflow authentication setup failed" msg = f"Error occurred when setting up AmSC MLflow authentication: {e}" add_error(title, msg) diff --git a/mlflow_utils.py b/mlflow_utils.py index 2c43ef30..cae417dd 100644 --- a/mlflow_utils.py +++ b/mlflow_utils.py @@ -15,7 +15,7 @@ def enable_amsc_x_api_key(config_dict): config_dict: Configuration dictionary containing mlflow settings Raises: - KeyError: If required mlflow configuration is missing or invalid + ValueError: If required mlflow configuration is missing or invalid See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details. """ @@ -23,20 +23,20 @@ def enable_amsc_x_api_key(config_dict): mlflow_cfg = config_dict.get("mlflow") if config_dict is not None else None if not isinstance(mlflow_cfg, dict): - raise KeyError( + raise ValueError( "Missing 'mlflow' configuration section required for AmSC MLFlow authentication." ) api_key_env = mlflow_cfg.get("api_key_env") if not api_key_env: - raise KeyError( + raise ValueError( "Missing 'api_key_env' in 'mlflow' configuration. " "Please specify the name of the environment variable containing the AmSC API key." ) api_key = os.getenv(api_key_env) if api_key is None: - raise KeyError( + raise ValueError( f"The environment variable '{api_key_env}' specified in 'mlflow.api_key_env' " "is not set. Please export it with the AmSC MLFlow API key." ) From 51d3082041614478afe52436d61f748a21ede20c Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Thu, 26 Mar 2026 09:11:22 -0700 Subject: [PATCH 4/7] Reorder import statements --- dashboard/model_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 98446a71..7b89b574 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -1,12 +1,12 @@ import asyncio from datetime import datetime from pathlib import Path -import tempfile import os -import sys -import yaml import re +import sys +import tempfile import urllib3 +import yaml import mlflow from sfapi_client import AsyncClient from sfapi_client.compute import Machine From 722869c9435f4b25948109bc4f63c08336c6cfcd Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Thu, 26 Mar 2026 09:15:55 -0700 Subject: [PATCH 5/7] Clean up inline comments --- dashboard/model_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 7b89b574..68cf2939 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -50,10 +50,10 @@ def __init__(self, config_dict, model_type_tag): config_dict["mlflow"]["tracking_uri"] == "https://mlflow.american-science-cloud.org" ): - # - tell MLflow to ignore SSL certificate errors (common with self-signed internal servers) + # Tell MLflow to ignore SSL certificate errors (common with self-signed internal servers) os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true" urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - # - inject the X-Api-Key into the requests. + # Inject the X-Api-Key into the requests. try: enable_amsc_x_api_key(config_dict) except ValueError as e: From dfd21898ff5b897e23f38ee396cd7644279c93fe Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Thu, 26 Mar 2026 09:20:09 -0700 Subject: [PATCH 6/7] Reorder import statements --- dashboard/model_manager.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 68cf2939..efe91e96 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -1,25 +1,27 @@ import asyncio -from datetime import datetime -from pathlib import Path import os import re import sys import tempfile +from datetime import datetime +from pathlib import Path + +import mlflow import urllib3 import yaml -import mlflow from sfapi_client import AsyncClient from sfapi_client.compute import Machine from trame.widgets import vuetify3 as vuetify +from urllib3.exceptions import InsecureRequestWarning # Add parent directory to path so we can import mlflow_utils from root sys.path.insert(0, str(Path(__file__).parent.parent)) -from utils import timer, load_config_dict, create_date_filter from error_manager import add_error +from mlflow_utils import enable_amsc_x_api_key from sfapi_manager import monitor_sfapi_job from state_manager import state -from mlflow_utils import enable_amsc_x_api_key +from utils import timer, load_config_dict, create_date_filter model_type_tag_dict = { "Gaussian Process": "GP", @@ -52,7 +54,7 @@ def __init__(self, config_dict, model_type_tag): ): # Tell MLflow to ignore SSL certificate errors (common with self-signed internal servers) os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true" - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + urllib3.disable_warnings(InsecureRequestWarning) # Inject the X-Api-Key into the requests. try: enable_amsc_x_api_key(config_dict) From 688ef0565341da04a8b0e52f80976d2b319a720d Mon Sep 17 00:00:00 2001 From: Edoardo Zoni <59625522+EZoni@users.noreply.github.com> Date: Thu, 26 Mar 2026 09:28:16 -0700 Subject: [PATCH 7/7] Apply suggestions from code review Co-authored-by: Edoardo Zoni <59625522+EZoni@users.noreply.github.com> --- mlflow_utils.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/mlflow_utils.py b/mlflow_utils.py index cae417dd..8cfb2a45 100644 --- a/mlflow_utils.py +++ b/mlflow_utils.py @@ -34,7 +34,7 @@ def enable_amsc_x_api_key(config_dict): "Please specify the name of the environment variable containing the AmSC API key." ) - api_key = os.getenv(api_key_env) + api_key = os.environ.get(api_key_env) if api_key is None: raise ValueError( f"The environment variable '{api_key_env}' specified in 'mlflow.api_key_env' " @@ -44,14 +44,9 @@ def enable_amsc_x_api_key(config_dict): _orig = rest_utils.http_request def patched(host_creds, endpoint, method, *args, **kwargs): - if "headers" in kwargs and kwargs["headers"] is not None: - h = dict(kwargs["headers"]) - h["X-Api-Key"] = api_key - kwargs["headers"] = h - else: - h = dict(kwargs.get("extra_headers") or {}) - h["X-Api-Key"] = api_key - kwargs["extra_headers"] = h + h = dict(kwargs.get("headers") or kwargs.get("extra_headers") or {}) + h["X-Api-Key"] = api_key + kwargs["headers" if "headers" in kwargs else "extra_headers"] = h return _orig(host_creds, endpoint, method, *args, **kwargs) rest_utils.http_request = patched