Skip to content
Open
1 change: 1 addition & 0 deletions dashboard.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ENTRYPOINT ["/entrypoint.sh"]
COPY dashboard /app/dashboard
COPY experiments /app/experiments
COPY ml/training_pm.sbatch /app/ml/training_pm.sbatch
COPY mlflow_utils.py /app/mlflow_utils.py

# Make port 8080 available to the world outside this container
EXPOSE 8080
Expand Down
77 changes: 26 additions & 51 deletions dashboard/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
import asyncio
from datetime import datetime
from pathlib import Path
import tempfile
import os
import yaml
import re
import sys
import tempfile
from datetime import datetime
from pathlib import Path

import mlflow
import urllib3
import yaml
from sfapi_client import AsyncClient
from sfapi_client.compute import Machine
from trame.widgets import vuetify3 as vuetify
from utils import timer, load_config_dict, create_date_filter
from urllib3.exceptions import InsecureRequestWarning

# Add parent directory to path so we can import mlflow_utils from root
sys.path.insert(0, str(Path(__file__).parent.parent))

from calibration_manager import build_inferred_calibration
from error_manager import add_error
from mlflow_utils import enable_amsc_x_api_key
from sfapi_manager import monitor_sfapi_job
from state_manager import state
from utils import timer, load_config_dict, create_date_filter
Comment on lines 1 to +25

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reordered following the standard Python convention (enforced by tools like isort and outlined in PEP 8), organizing imports into three groups, separated by blank lines, with each group sorted alphabetically:

  • Standard library imports
  • Third-party imports
  • Local/project imports

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied in dfd2189.


model_type_dict = {
"Gaussian Process": "GP",
Expand All @@ -22,49 +31,6 @@
}


def enable_amsc_x_api_key(config_dict):
    """
    Attach the AmSC API key to every MLflow REST request.

    MLflow has no built-in way to send a custom 'X-Api-Key' header, so this
    replaces ``mlflow.utils.rest_utils.http_request`` with a wrapper that adds
    the header before delegating to the function it replaced.

    The key is read from the environment variable named by
    ``config_dict["mlflow"]["api_key_env"]``. If that setting or the variable
    is missing or empty, the problem is reported through ``add_error`` and
    ``print`` and the function returns without patching anything.

    See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details.
    """
    import mlflow.utils.rest_utils as rest_utils

    def report(msg):
        # Every failure path shares the same title and is both surfaced in
        # the error panel and echoed to stdout.
        add_error("Unable to enable AmSC X-Api-Key authentication", msg)
        print(msg)

    env_var_name = (config_dict.get("mlflow") or {}).get("api_key_env")
    if not env_var_name:
        report("MLFlow configuration is missing 'mlflow.api_key_env'")
        return

    key_value = os.environ.get(env_var_name)
    if not key_value:
        report(
            f"Environment variable '{env_var_name}' in 'mlflow.api_key_env' is not set"
        )
        return

    unpatched_request = rest_utils.http_request

    def patched(host_creds, endpoint, method, *args, **kwargs):
        # Prefer an explicit, non-None 'headers' kwarg; otherwise the key
        # travels in 'extra_headers'.
        existing = kwargs.get("headers")
        if existing is None:
            slot = "extra_headers"
            existing = kwargs.get("extra_headers") or {}
        else:
            slot = "headers"
        merged = dict(existing)  # copy so the caller's mapping is not mutated
        merged["X-Api-Key"] = key_value
        kwargs[slot] = merged
        return unpatched_request(host_creds, endpoint, method, *args, **kwargs)

    rest_utils.http_request = patched


class ModelManager:
def __init__(self, config_dict, model_type):
print("Initializing model manager...")
Expand All @@ -79,13 +45,22 @@ def __init__(self, config_dict, model_type):

mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"])
# When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server
# (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch)
# (see https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch)
if (
config_dict["mlflow"]["tracking_uri"]
== "https://mlflow.american-science-cloud.org"
):
enable_amsc_x_api_key(config_dict)

# Tell MLflow to ignore SSL certificate errors (common with self-signed internal servers)
os.environ["MLFLOW_TRACKING_INSECURE_TLS"] = "true"
urllib3.disable_warnings(InsecureRequestWarning)
Comment on lines +53 to +55

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this suppresses the warning we used to see:

/home/edoardo/miniconda3/envs/synapse-gui/lib/python3.11/site-packages/urllib3/connectionpool.py:1097: InsecureRequestWarning: Unverified HTTPS request is being made to host 'mlflow.american-science-cloud.org'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings
  warnings.warn(

# Inject the X-Api-Key into the requests.
try:
enable_amsc_x_api_key(config_dict)
except ValueError as e:
title = "AmSC MLflow authentication setup failed"
msg = f"Error occurred when setting up AmSC MLflow authentication: {e}"
add_error(title, msg)
print(msg)
Comment on lines +53 to +63

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the same try/except logic when we call enable_amsc_x_api_key in train_model.py?

experiment = config_dict["experiment"]
model_name = f"synapse-{experiment}_{model_type}"

Expand Down
1 change: 1 addition & 0 deletions ml.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ ENTRYPOINT ["/app/ml/entrypoint.sh"]
COPY ml/train_model.py /app/ml/train_model.py
COPY ml/Neural_Net_Classes.py /app/ml/Neural_Net_Classes.py
COPY experiments /app/experiments
COPY mlflow_utils.py /app/mlflow_utils.py

# Run train_model.py when the container launches
CMD ["python", "-u", "train_model.py"]
54 changes: 6 additions & 48 deletions ml/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
import sys
import pandas as pd
from gpytorch.mlls import ExactMarginalLogLikelihood
from pathlib import Path

sys.path.append(".")
# Add parent directories to path for module imports
sys.path.insert(0, str(Path(__file__).parent.parent)) # For mlflow_utils in root
sys.path.insert(0, str(Path(__file__).parent)) # For Neural_Net_Classes in ml/
from Neural_Net_Classes import CombinedNN, train_calibration
from mlflow_utils import enable_amsc_x_api_key

# measure the time it took to import everything
import_end_time = time.time()
Expand Down Expand Up @@ -454,52 +458,6 @@ def train_gp(norm_df_train, input_names, output_names, device):
return combined_gp


def enable_amsc_x_api_key(config_dict):
    """
    Attach the AmSC API key to every MLflow REST request.

    MLflow has no built-in way to send a custom 'X-Api-Key' header, so this
    replaces ``mlflow.utils.rest_utils.http_request`` with a wrapper that adds
    the header before delegating to the function it replaced.

    Raises:
        KeyError: if the 'mlflow' section is absent or not a dict, if it has
            no 'api_key_env' entry, or if the environment variable named by
            that entry is not set.

    See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details.
    """
    import mlflow.utils.rest_utils as rest_utils

    section = None if config_dict is None else config_dict.get("mlflow")
    if not isinstance(section, dict):
        raise KeyError(
            "Missing 'mlflow' configuration section required for AmSC MLFlow authentication."
        )

    env_var_name = section.get("api_key_env")
    if not env_var_name:
        raise KeyError(
            "Missing 'api_key_env' in 'mlflow' configuration. "
            "Please specify the name of the environment variable containing the AmSC API key."
        )

    key_value = os.getenv(env_var_name)
    if key_value is None:
        raise KeyError(
            f"The environment variable '{env_var_name}' specified in 'mlflow.api_key_env' "
            "is not set. Please export it with the AmSC MLFlow API key."
        )

    unpatched_request = rest_utils.http_request

    def patched(host_creds, endpoint, method, *args, **kwargs):
        # Prefer an explicit, non-None 'headers' kwarg; otherwise the key
        # travels in 'extra_headers'.
        existing = kwargs.get("headers")
        if existing is None:
            slot = "extra_headers"
            existing = kwargs.get("extra_headers") or {}
        else:
            slot = "headers"
        merged = dict(existing)  # copy so the caller's mapping is not mutated
        merged["X-Api-Key"] = key_value
        kwargs[slot] = merged
        return unpatched_request(host_creds, endpoint, method, *args, **kwargs)

    rest_utils.http_request = patched


def register_model_to_mlflow(model, model_type, experiment, config_dict):
"""Register the trained model to MLflow (tracking URI from config)."""
tracking_uri = config_dict["mlflow"]["tracking_uri"]
Expand Down Expand Up @@ -550,7 +508,7 @@ def register_model_to_mlflow(model, model_type, experiment, config_dict):
df_sim = pd.DataFrame(db[experiment].find({"experiment_flag": 0}))

# When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server
# (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch)
# (see https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch)
if (
"mlflow" in config_dict
and config_dict["mlflow"].get("tracking_uri")
Expand Down
52 changes: 52 additions & 0 deletions mlflow_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""MLflow utility functions for AmSC authentication and configuration."""

import os


def enable_amsc_x_api_key(config_dict):
    """
    MLflow authentication helper for the AmSC MLflow server.

    Standard MLflow does not automatically inject custom headers like 'X-Api-Key'.
    This builds a wrapper around ``mlflow.utils.rest_utils.http_request`` that
    ensures every request to the server includes the AmSC API key.

    The configuration is validated before MLflow is imported, so a bad
    configuration is reported as a ``ValueError`` regardless of the state of
    the MLflow installation.

    Calling this function more than once does not stack wrappers: the wrapper
    remembers the unpatched ``http_request`` and a later call wraps that same
    function again (picking up the current API key).

    Args:
        config_dict: Configuration dictionary containing mlflow settings.
            ``config_dict["mlflow"]["api_key_env"]`` must name the environment
            variable that holds the API key.

    Raises:
        ValueError: If the 'mlflow' section or its 'api_key_env' entry is
            missing, or if the named environment variable is unset or empty.

    See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details.
    """
    mlflow_cfg = config_dict.get("mlflow") if config_dict is not None else None
    if not isinstance(mlflow_cfg, dict):
        raise ValueError(
            "Missing 'mlflow' configuration section required for AmSC MLFlow authentication."
        )

    api_key_env = mlflow_cfg.get("api_key_env")
    if not api_key_env:
        raise ValueError(
            "Missing 'api_key_env' in 'mlflow' configuration. "
            "Please specify the name of the environment variable containing the AmSC API key."
        )

    # An empty value would be sent as a blank X-Api-Key header and rejected by
    # the server later with a far less helpful error, so treat it as unset.
    api_key = os.environ.get(api_key_env)
    if not api_key:
        raise ValueError(
            f"The environment variable '{api_key_env}' specified in 'mlflow.api_key_env' "
            "is not set or is empty. Please export it with the AmSC MLFlow API key."
        )

    # Imported only after validation so configuration errors do not depend on
    # MLflow being importable.
    import mlflow.utils.rest_utils as rest_utils

    # If http_request is already one of our wrappers, unwrap it first so that
    # repeated calls replace the wrapper instead of nesting one per call.
    current = rest_utils.http_request
    _orig = getattr(current, "_amsc_unpatched_http_request", current)

    def patched(host_creds, endpoint, method, *args, **kwargs):
        # Copy the caller's headers (or extra_headers) so their mapping is not
        # mutated, add the key, and hand it back under the same keyword.
        h = dict(kwargs.get("headers") or kwargs.get("extra_headers") or {})
        h["X-Api-Key"] = api_key
        kwargs["headers" if "headers" in kwargs else "extra_headers"] = h
        return _orig(host_creds, endpoint, method, *args, **kwargs)

    # Record the genuinely unpatched function on the wrapper so a later call
    # can recover it once the wrapper is installed as rest_utils.http_request.
    patched._amsc_unpatched_http_request = _orig
Comment on lines +46 to +50

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copying from https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch:

Suggested change
def patched(host_creds, endpoint, method, *args, **kwargs):
if "headers" in kwargs and kwargs["headers"] is not None:
h = dict(kwargs["headers"])
h["X-Api-Key"] = api_key
kwargs["headers"] = h
else:
h = dict(kwargs.get("extra_headers") or {})
h["X-Api-Key"] = api_key
kwargs["extra_headers"] = h
return _orig(host_creds, endpoint, method, *args, **kwargs)
def patched(host_creds, endpoint, method, *args, **kwargs):
h = dict(kwargs.get("headers") or kwargs.get("extra_headers") or {})
h["X-Api-Key"] = api_key
kwargs["headers" if "headers" in kwargs else "extra_headers"] = h
return _orig(host_creds, endpoint, method, *args, **kwargs)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied in 688ef05.


rest_utils.http_request = patched
Loading