Skip to content

Commit 51ae05f

Browse files
committed
Merge branch 'main' into copilot/add-copilot-instructions
2 parents d531ef0 + 2ecdda9 commit 51ae05f

14 files changed

Lines changed: 710 additions & 423 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
33
# Ruff version
4-
rev: v0.15.2
4+
rev: v0.15.7
55
hooks:
66
# Run the linter
77
- id: ruff-check

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ To display ML predictions, the application requires the following:
2525
- **Simulation and experimental data points**: Each data point consists of values for the scalar inputs and outputs defined in the experiment configuration file.
2626
Data points are stored in a [MongoDB](https://www.mongodb.com/) database, where each experiment is represented by a separate collection.
2727
Experimental and simulation data points are stored in the same collection and are distinguished by the `experiment_flag` attribute (`1` for experimental data points, `0` for simulation data points).
28-
- **ML models**: Machine learning models that interpolate between data points and are stored in a separate MongoDB collection named `models`.
28+
- **ML models**: Machine learning models that interpolate between data points and are stored in [MLflow](https://mlflow.org/).
2929
- **Simulation movies** (optional): For certain experiments, users can click on simulation data points to visualize simulation movies.
3030
The corresponding MP4 files are stored in the Perlmutter shared file system at `/global/cfs/cdirs/m558/superfacility/simulation_data`.
3131
This directory is mounted on the container image running on Spin.

dashboard/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@ conda-lock install --name synapse-gui environment-lock.yml
5454

5555
2. Move to the [dashboard/](./) directory.
5656

57-
3. Set up the database settings (read-only):
57+
3. Set up the database settings (read-only) and the AmSC MLflow API key:
5858
```bash
5959
export SF_DB_HOST='127.0.0.1'
6060
export SF_DB_READONLY_PASSWORD='your_password_here' # Use SINGLE quotes around the password!
61+
export AM_SC_API_KEY='your_amsc_api_key_here' # Required when MLflow tracking_uri is AmSC
6162
```
6263

6364
4. Activate the conda environment `synapse-gui`:
@@ -85,11 +86,11 @@ conda-lock install --name synapse-gui environment-lock.yml
8586

8687
4. Run the Docker container:
8788
```bash
88-
docker run --network=host -v /etc/localtime:/etc/localtime -v $PWD/ml:/app/ml -e SF_DB_HOST='127.0.0.1' -e SF_DB_READONLY_PASSWORD='your_password_here' synapse-gui
89+
docker run --network=host -v /etc/localtime:/etc/localtime -v $PWD/ml:/app/ml -e SF_DB_HOST='127.0.0.1' -e SF_DB_READONLY_PASSWORD='your_password_here' -e AM_SC_API_KEY='your_amsc_api_key_here' synapse-gui
8990
```
9091
For debugging, you can enter the container without starting the app:
9192
```bash
92-
docker run --network=host -v /etc/localtime:/etc/localtime -v $PWD/ml:/app/ml -e SF_DB_HOST='127.0.0.1' -e SF_DB_READONLY_PASSWORD='your_password_here' -it synapse-gui bash
93+
docker run --network=host -v /etc/localtime:/etc/localtime -v $PWD/ml:/app/ml -e SF_DB_HOST='127.0.0.1' -e SF_DB_READONLY_PASSWORD='your_password_here' -e AM_SC_API_KEY='your_amsc_api_key_here' -it synapse-gui bash
9394
```
9495
Note that `-v /etc/localtime:/etc/localtime` is necessary to synchronize the time zone in the container with the host machine.
9596

dashboard/app.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from trame.ui.vuetify3 import SinglePageWithDrawerLayout
77
from trame.widgets import plotly, router, vuetify3 as vuetify, html
88

9-
from model_manager import ModelManager
9+
from model_manager import ModelManager, model_type_tag_dict
1010
from outputs_manager import OutputManager
1111
from optimization_manager import OptimizationManager
1212
from parameters_manager import ParametersManager
@@ -16,6 +16,7 @@
1616
from error_manager import error_panel, add_error
1717
from utils import (
1818
data_depth_panel,
19+
load_config_dict,
1920
load_experiments,
2021
load_database,
2122
load_data,
@@ -64,14 +65,18 @@ def update(
6465
state.experiment
6566
)
6667
# load data
67-
db = load_database(state.experiment)
68-
exp_data, sim_data = load_data(db)
68+
config_dict = load_config_dict(state.experiment)
69+
db = load_database(config_dict)
70+
exp_data, sim_data = load_data(db, state.experiment, state.experiment_date_range)
6971
# reset output
7072
if reset_output:
7173
out_manager = OutputManager(output_variables)
7274
# reset model
7375
if reset_model:
74-
mod_manager = ModelManager(db)
76+
mod_manager = ModelManager(
77+
config_dict=config_dict,
78+
model_type_tag=model_type_tag_dict[state.model_type],
79+
)
7580
opt_manager = OptimizationManager(mod_manager)
7681
# reset parameters
7782
if reset_parameters:
@@ -257,7 +262,8 @@ def find_simulation(event, db):
257262

258263

259264
def open_simulation_dialog(event):
260-
db = load_database(state.experiment)
265+
config_dict = load_config_dict(state.experiment)
266+
db = load_database(config_dict)
261267
try:
262268
data_directory, file_path = find_simulation(event, db)
263269
state.simulation_video = file_path.endswith(".mp4")

dashboard/error_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from trame.widgets import vuetify3 as vuetify, html
2-
from state_manager import state
2+
from state_manager import state, server
33

44

55
def add_error(title, msg):
6+
if not server.running:
7+
# Outside of a Trame app (e.g. check_model.py), raise a Python error
8+
# to surface the error to the caller.
9+
raise RuntimeError(f"{title}: {msg}")
10+
# Otherwise: Inside a Trame app, add the error to the state.
611
state.errors.append(
712
{
813
"id": state.error_counter,

dashboard/model_manager.py

Lines changed: 82 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
import os
66
import yaml
77
import re
8+
import mlflow
89
from sfapi_client import AsyncClient
910
from sfapi_client.compute import Machine
10-
from lume_model.models.torch_model import TorchModel
11-
from lume_model.models.ensemble import NNEnsemble
12-
from lume_model.models.gp_model import GPModel
1311
from trame.widgets import vuetify3 as vuetify
14-
from utils import verify_input_variables, timer, load_config_dict, create_date_filter
12+
from utils import timer, load_config_dict, create_date_filter
1513
from error_manager import add_error
1614
from sfapi_manager import monitor_sfapi_job
1715
from state_manager import state
@@ -23,107 +21,96 @@
2321
}
2422

2523

24+
def enable_amsc_x_api_key(config_dict):
    """
    MLflow authentication helper for the AmSC MLflow server.
    Standard MLflow does not automatically inject custom headers like 'X-Api-Key'.
    This patches the http_request function to ensure every request to the server
    includes the AmSC API key.

    The name of the environment variable holding the key is read from
    ``config_dict["mlflow"]["api_key_env"]``. If that entry or the environment
    variable itself is missing, the error is reported through ``add_error``
    and the function returns without patching anything.

    Calling this function more than once is safe: the wrapper always wraps
    the unpatched ``http_request``, so repeated calls replace the previous
    wrapper (picking up the current key) instead of nesting wrappers.

    See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details.

    Args:
        config_dict: Experiment configuration dictionary; only its optional
            "mlflow" section is read here.

    Returns:
        None.
    """
    import mlflow.utils.rest_utils as rest_utils

    title = "Unable to enable AmSC X-Api-Key authentication"

    # Tolerate a missing or empty "mlflow" section
    mlflow_cfg = config_dict.get("mlflow") or {}
    api_key_env = mlflow_cfg.get("api_key_env")
    if not api_key_env:
        msg = "MLFlow configuration is missing 'mlflow.api_key_env'"
        add_error(title, msg)
        print(msg)
        return

    api_key = os.environ.get(api_key_env)
    if not api_key:
        msg = f"Environment variable '{api_key_env}' in 'mlflow.api_key_env' is not set"
        add_error(title, msg)
        print(msg)
        return

    # If http_request was already patched by an earlier call, recover the
    # function it wrapped so that wrappers never stack on top of each other.
    current = rest_utils.http_request
    _orig = getattr(current, "_amsc_original", current)

    def patched(host_creds, endpoint, method, *args, **kwargs):
        # Inject the key into explicit "headers" when the caller passed them,
        # otherwise into "extra_headers". Copy first so that the caller's
        # dictionary is never mutated.
        if kwargs.get("headers") is not None:
            target = "headers"
        else:
            target = "extra_headers"
        h = dict(kwargs.get(target) or {})
        h["X-Api-Key"] = api_key
        kwargs[target] = h
        return _orig(host_creds, endpoint, method, *args, **kwargs)

    # Remember the unpatched function for subsequent calls (see above)
    patched._amsc_original = _orig

    # NOTE(review): modules that bound `http_request` by name before this
    # assignment keep the unpatched reference - confirm that all relevant
    # MLflow call paths look the function up through `rest_utils`.
    rest_utils.http_request = patched
66+
2667
class ModelManager:
27-
def __init__(self, db):
68+
def __init__(self, config_dict, model_type_tag):
2869
print("Initializing model manager...")
29-
# Set initial default values
3070
self.__model = None
3171
self.__is_neural_network = False
3272
self.__is_gaussian_process = False
3373
self.__is_neural_network_ensemble = False
74+
self.__model_type_tag = model_type_tag
3475

35-
# Download model information from the database
36-
collection = db["models"]
37-
model_type_tag = model_type_tag_dict[state.model_type]
38-
query = {"experiment": state.experiment, "model_type": model_type_tag}
39-
count = collection.count_documents(query)
40-
41-
if count == 0:
42-
print(
43-
f"No model found for experiment: {state.experiment} and model type: {model_type_tag}"
44-
)
45-
return
46-
elif count > 1:
76+
if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"):
4777
print(
48-
f"Multiple models found ({count}) for experiment: {state.experiment} and model type: {model_type_tag}!"
78+
f"No mlflow.tracking_uri in configuration file for {config_dict['experiment']}; cannot load model from MLflow."
4979
)
5080
return
5181

52-
# Load model information from the database
53-
document = collection.find_one(query)
54-
# Save model files in a temporary directory,
55-
# so that it can then be loaded with lume_model
56-
with tempfile.TemporaryDirectory() as temp_dir:
57-
# Open content of the top-level YAML file
58-
yaml_file_content = document["yaml_file_content"]
59-
model_filename = f"{state.experiment}.yml"
60-
with open(os.path.join(temp_dir, model_filename), "w") as f:
61-
f.write(yaml_file_content)
82+
mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"])
83+
# When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server
84+
# (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch)
85+
if (
86+
config_dict["mlflow"]["tracking_uri"]
87+
== "https://mlflow.american-science-cloud.org"
88+
):
89+
enable_amsc_x_api_key(config_dict)
6290

63-
# Extract list of files to download
64-
files_to_download = []
65-
if state.model_type == "Neural Network (ensemble)":
66-
models_info = yaml.safe_load(yaml_file_content)
67-
# Download yaml file for each model within the ensemble
68-
for model in models_info["models"]:
69-
yaml_file_name = model.replace("_model.jit", ".yml")
70-
with open(os.path.join(temp_dir, yaml_file_name), "wb") as f:
71-
f.write(document[yaml_file_name])
72-
model_info = yaml.safe_load(document[yaml_file_name])
73-
# Extract files to download
74-
files_to_download += (
75-
[model_info["model"]]
76-
+ model_info["input_transformers"]
77-
+ model_info["output_transformers"]
78-
)
79-
else:
80-
# Extract files to download
81-
model_info = yaml.safe_load(yaml_file_content)
82-
files_to_download = (
83-
[model_info["model"]]
84-
+ model_info["input_transformers"]
85-
+ model_info["output_transformers"]
86-
)
87-
88-
# Download all the files that define the model(s)
89-
for filename in files_to_download:
90-
with open(os.path.join(temp_dir, filename), "wb") as f:
91-
f.write(document[filename])
91+
experiment = config_dict["experiment"]
92+
model_name = f"{experiment}_{model_type_tag}"
9293

93-
# Check consistency of the model file
94-
print("Reading model file...")
95-
model_file = os.path.join(temp_dir, f"{state.experiment}.yml")
96-
if not os.path.isfile(model_file):
97-
title = f"Model file {model_file} not found"
98-
msg = f"Unable to find the model file for {state.experiment}"
99-
add_error(title, msg)
100-
print(msg)
101-
return
102-
elif not verify_input_variables(model_file, state.experiment):
103-
title = "Model file input variable mismatch"
104-
msg = f"Model file {model_file} has different input variables than the configuration file for {state.experiment}"
105-
add_error(title, msg)
106-
print(msg)
107-
return
108-
109-
# Load model with lume_model
110-
try:
111-
if state.model_type == "Neural Network (single)":
112-
self.__is_neural_network = True
113-
self.__model = TorchModel(model_file)
114-
elif state.model_type == "Neural Network (ensemble)":
115-
self.__is_neural_network_ensemble = True
116-
self.__model = NNEnsemble(model_file)
117-
elif state.model_type == "Gaussian Process":
118-
self.__is_gaussian_process = True
119-
self.__model = GPModel.from_yaml(model_file)
120-
else:
121-
raise ValueError(f"Unsupported model type: {state.model_type}")
122-
except Exception as e:
123-
title = f"Unable to load model {state.model_type}"
124-
msg = f"Error occurred when loading model: {e}"
125-
add_error(title, msg)
126-
print(msg)
94+
try:
95+
# Download model from MLflow server
96+
self.__model = (
97+
mlflow.pyfunc.load_model(f"models:/{model_name}/latest")
98+
.unwrap_python_model()
99+
.model
100+
)
101+
if model_type_tag == "NN":
102+
self.__is_neural_network = True
103+
elif model_type_tag == "ensemble_NN":
104+
self.__is_neural_network_ensemble = True
105+
elif model_type_tag == "GP":
106+
self.__is_gaussian_process = True
107+
else:
108+
raise ValueError(f"Unsupported model type: {model_type_tag}")
109+
except Exception as e:
110+
title = f"Unable to load model {model_type_tag}"
111+
msg = f"Error occurred when loading model from MLflow: {e}"
112+
add_error(title, msg)
113+
print(msg)
127114

128115
def avail(self):
129116
print("Checking model availability...")
@@ -153,22 +140,13 @@ def evaluate(self, parameters, output):
153140
mean = output_dict[output]
154141
mean_error = 0.0 # trick to collapse error range when lower/upper bounds are not predicted
155142
elif self.__is_gaussian_process or self.__is_neural_network_ensemble:
156-
if self.__is_gaussian_process:
157-
# TODO use "exp" only once experimental data is available for all experiments
158-
task_tag = "exp" if state.experiment == "bella-ip2" else "sim"
159-
output_key = [key for key in output_dict.keys() if task_tag in key][
160-
0
161-
]
162-
elif self.__is_neural_network_ensemble:
163-
output_key = list(output_dict.keys())[0]
164-
165143
# compute mean, standard deviation and mean error
166144
# (call detach method to detach gradients from tensors)
167-
mean = output_dict[output_key].mean.detach()
168-
std_dev = output_dict[output_key].variance.sqrt().detach()
145+
mean = output_dict[output].mean.detach()
146+
std_dev = output_dict[output].variance.sqrt().detach()
169147
mean_error = 2.0 * std_dev
170148
else:
171-
raise ValueError(f"Unsupported model type: {state.model_type}")
149+
raise ValueError(f"Unsupported model type: {self.__model_type_tag}")
172150
# compute lower/upper bounds for error range
173151
lower = mean - mean_error
174152
upper = mean + mean_error

dashboard/utils.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,13 @@ def create_date_filter(experiment_date_range):
103103

104104

105105
@timer
106-
def load_data(db):
106+
def load_data(db, experiment, date_range=None):
107107
print("Loading data from database...")
108108
# create date filter if date range is set
109-
date_filter = create_date_filter(state.experiment_date_range)
109+
date_filter = create_date_filter(date_range)
110110
# load experiment and simulation data points in dataframes
111-
exp_data = pd.DataFrame(
112-
db[state.experiment].find({"experiment_flag": 1, **date_filter})
113-
)
114-
sim_data = pd.DataFrame(db[state.experiment].find({"experiment_flag": 0}))
111+
exp_data = pd.DataFrame(db[experiment].find({"experiment_flag": 1, **date_filter}))
112+
sim_data = pd.DataFrame(db[experiment].find({"experiment_flag": 0}))
115113
# Store '_id', 'date' as string
116114
for key in ["_id", "date"]:
117115
if key in exp_data.columns:
@@ -121,32 +119,9 @@ def load_data(db):
121119
return (exp_data, sim_data)
122120

123121

124-
def verify_input_variables(model_file, experiment):
125-
print("Checking model consistency...")
126-
# read configuration file
127-
input_vars, _, _ = load_variables(experiment)
128-
config_vars = [input_var["name"] for input_var in input_vars.values()]
129-
config_vars.sort()
130-
# read model file
131-
with open(model_file) as f:
132-
model_str = f.read()
133-
# load model dictionary
134-
model_dict = yaml.safe_load(model_str)
135-
# load model input variables list
136-
model_vars = list(model_dict["input_variables"].keys())
137-
model_vars.sort()
138-
# check if configuration list and model list match
139-
match = config_vars == model_vars
140-
if not match:
141-
print("Input variables in configuration file and model file do not match")
142-
return match
143-
144-
145122
@timer
146-
def load_database(experiment):
123+
def load_database(config_dict):
147124
print("Loading database...")
148-
# load configuration dictionary
149-
config_dict = load_config_dict(experiment)
150125
# read database information from configuration dictionary
151126
db_host = config_dict["database"]["host"]
152127
db_port = config_dict["database"]["port"]

0 commit comments

Comments
 (0)