|
5 | 5 | import os |
6 | 6 | import yaml |
7 | 7 | import re |
| 8 | +import mlflow |
8 | 9 | from sfapi_client import AsyncClient |
9 | 10 | from sfapi_client.compute import Machine |
10 | | -from lume_model.models.torch_model import TorchModel |
11 | | -from lume_model.models.ensemble import NNEnsemble |
12 | | -from lume_model.models.gp_model import GPModel |
13 | 11 | from trame.widgets import vuetify3 as vuetify |
14 | | -from utils import verify_input_variables, timer, load_config_dict, create_date_filter |
| 12 | +from utils import timer, load_config_dict, create_date_filter |
15 | 13 | from error_manager import add_error |
16 | 14 | from sfapi_manager import monitor_sfapi_job |
17 | 15 | from state_manager import state |
|
23 | 21 | } |
24 | 22 |
|
25 | 23 |
|
| 24 | +def enable_amsc_x_api_key(config_dict): |
| 25 | + """ |
| 26 | + MLflow authentication helper for the AmSC MLflow server. |
| 27 | + Standard MLflow does not automatically inject custom headers like 'X-Api-Key'. |
| 28 | + This patches the http_request function to ensure every request to the server |
| 29 | + includes the AmSC API key. |
| 30 | +
|
| 31 | + See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch for more details. |
| 32 | + """ |
| 33 | + import mlflow.utils.rest_utils as rest_utils |
| 34 | + |
| 35 | + mlflow_cfg = config_dict.get("mlflow") or {} |
| 36 | + api_key_env = mlflow_cfg.get("api_key_env") |
| 37 | + if not api_key_env: |
| 38 | + title = "Unable to enable AmSC X-Api-Key authentication" |
| 39 | + msg = "MLFlow configuration is missing 'mlflow.api_key_env'" |
| 40 | + add_error(title, msg) |
| 41 | + print(msg) |
| 42 | + return |
| 43 | + |
| 44 | + api_key = os.environ.get(api_key_env) |
| 45 | + if not api_key: |
| 46 | + title = "Unable to enable AmSC X-Api-Key authentication" |
| 47 | + msg = f"Environment variable '{api_key_env}' in 'mlflow.api_key_env' is not set" |
| 48 | + add_error(title, msg) |
| 49 | + print(msg) |
| 50 | + return |
| 51 | + _orig = rest_utils.http_request |
| 52 | + |
| 53 | + def patched(host_creds, endpoint, method, *args, **kwargs): |
| 54 | + if "headers" in kwargs and kwargs["headers"] is not None: |
| 55 | + h = dict(kwargs["headers"]) |
| 56 | + h["X-Api-Key"] = api_key |
| 57 | + kwargs["headers"] = h |
| 58 | + else: |
| 59 | + h = dict(kwargs.get("extra_headers") or {}) |
| 60 | + h["X-Api-Key"] = api_key |
| 61 | + kwargs["extra_headers"] = h |
| 62 | + return _orig(host_creds, endpoint, method, *args, **kwargs) |
| 63 | + |
| 64 | + rest_utils.http_request = patched |
| 65 | + |
| 66 | + |
26 | 67 | class ModelManager: |
27 | | - def __init__(self, db): |
| 68 | + def __init__(self, config_dict, model_type_tag): |
28 | 69 | print("Initializing model manager...") |
29 | | - # Set initial default values |
30 | 70 | self.__model = None |
31 | 71 | self.__is_neural_network = False |
32 | 72 | self.__is_gaussian_process = False |
33 | 73 | self.__is_neural_network_ensemble = False |
| 74 | + self.__model_type_tag = model_type_tag |
34 | 75 |
|
35 | | - # Download model information from the database |
36 | | - collection = db["models"] |
37 | | - model_type_tag = model_type_tag_dict[state.model_type] |
38 | | - query = {"experiment": state.experiment, "model_type": model_type_tag} |
39 | | - count = collection.count_documents(query) |
40 | | - |
41 | | - if count == 0: |
42 | | - print( |
43 | | - f"No model found for experiment: {state.experiment} and model type: {model_type_tag}" |
44 | | - ) |
45 | | - return |
46 | | - elif count > 1: |
| 76 | + if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"): |
47 | 77 | print( |
48 | | - f"Multiple models found ({count}) for experiment: {state.experiment} and model type: {model_type_tag}!" |
| 78 | + f"No mlflow.tracking_uri in configuration file for {config_dict['experiment']}; cannot load model from MLflow." |
49 | 79 | ) |
50 | 80 | return |
51 | 81 |
|
52 | | - # Load model information from the database |
53 | | - document = collection.find_one(query) |
54 | | - # Save model files in a temporary directory, |
55 | | - # so that it can then be loaded with lume_model |
56 | | - with tempfile.TemporaryDirectory() as temp_dir: |
57 | | - # Open content of the top-level YAML file |
58 | | - yaml_file_content = document["yaml_file_content"] |
59 | | - model_filename = f"{state.experiment}.yml" |
60 | | - with open(os.path.join(temp_dir, model_filename), "w") as f: |
61 | | - f.write(yaml_file_content) |
| 82 | + mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"]) |
| 83 | + # When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server |
| 84 | + # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) |
| 85 | + if ( |
| 86 | + config_dict["mlflow"]["tracking_uri"] |
| 87 | + == "https://mlflow.american-science-cloud.org" |
| 88 | + ): |
| 89 | + enable_amsc_x_api_key(config_dict) |
62 | 90 |
|
63 | | - # Extract list of files to download |
64 | | - files_to_download = [] |
65 | | - if state.model_type == "Neural Network (ensemble)": |
66 | | - models_info = yaml.safe_load(yaml_file_content) |
67 | | - # Download yaml file for each model within the ensemble |
68 | | - for model in models_info["models"]: |
69 | | - yaml_file_name = model.replace("_model.jit", ".yml") |
70 | | - with open(os.path.join(temp_dir, yaml_file_name), "wb") as f: |
71 | | - f.write(document[yaml_file_name]) |
72 | | - model_info = yaml.safe_load(document[yaml_file_name]) |
73 | | - # Extract files to download |
74 | | - files_to_download += ( |
75 | | - [model_info["model"]] |
76 | | - + model_info["input_transformers"] |
77 | | - + model_info["output_transformers"] |
78 | | - ) |
79 | | - else: |
80 | | - # Extract files to download |
81 | | - model_info = yaml.safe_load(yaml_file_content) |
82 | | - files_to_download = ( |
83 | | - [model_info["model"]] |
84 | | - + model_info["input_transformers"] |
85 | | - + model_info["output_transformers"] |
86 | | - ) |
87 | | - |
88 | | - # Download all the files that define the model(s) |
89 | | - for filename in files_to_download: |
90 | | - with open(os.path.join(temp_dir, filename), "wb") as f: |
91 | | - f.write(document[filename]) |
| 91 | + experiment = config_dict["experiment"] |
| 92 | + model_name = f"{experiment}_{model_type_tag}" |
92 | 93 |
|
93 | | - # Check consistency of the model file |
94 | | - print("Reading model file...") |
95 | | - model_file = os.path.join(temp_dir, f"{state.experiment}.yml") |
96 | | - if not os.path.isfile(model_file): |
97 | | - title = f"Model file {model_file} not found" |
98 | | - msg = f"Unable to find the model file for {state.experiment}" |
99 | | - add_error(title, msg) |
100 | | - print(msg) |
101 | | - return |
102 | | - elif not verify_input_variables(model_file, state.experiment): |
103 | | - title = "Model file input variable mismatch" |
104 | | - msg = f"Model file {model_file} has different input variables than the configuration file for {state.experiment}" |
105 | | - add_error(title, msg) |
106 | | - print(msg) |
107 | | - return |
108 | | - |
109 | | - # Load model with lume_model |
110 | | - try: |
111 | | - if state.model_type == "Neural Network (single)": |
112 | | - self.__is_neural_network = True |
113 | | - self.__model = TorchModel(model_file) |
114 | | - elif state.model_type == "Neural Network (ensemble)": |
115 | | - self.__is_neural_network_ensemble = True |
116 | | - self.__model = NNEnsemble(model_file) |
117 | | - elif state.model_type == "Gaussian Process": |
118 | | - self.__is_gaussian_process = True |
119 | | - self.__model = GPModel.from_yaml(model_file) |
120 | | - else: |
121 | | - raise ValueError(f"Unsupported model type: {state.model_type}") |
122 | | - except Exception as e: |
123 | | - title = f"Unable to load model {state.model_type}" |
124 | | - msg = f"Error occurred when loading model: {e}" |
125 | | - add_error(title, msg) |
126 | | - print(msg) |
| 94 | + try: |
| 95 | + # Download model from MLflow server |
| 96 | + self.__model = ( |
| 97 | + mlflow.pyfunc.load_model(f"models:/{model_name}/latest") |
| 98 | + .unwrap_python_model() |
| 99 | + .model |
| 100 | + ) |
| 101 | + if model_type_tag == "NN": |
| 102 | + self.__is_neural_network = True |
| 103 | + elif model_type_tag == "ensemble_NN": |
| 104 | + self.__is_neural_network_ensemble = True |
| 105 | + elif model_type_tag == "GP": |
| 106 | + self.__is_gaussian_process = True |
| 107 | + else: |
| 108 | + raise ValueError(f"Unsupported model type: {model_type_tag}") |
| 109 | + except Exception as e: |
| 110 | + title = f"Unable to load model {model_type_tag}" |
| 111 | + msg = f"Error occurred when loading model from MLflow: {e}" |
| 112 | + add_error(title, msg) |
| 113 | + print(msg) |
127 | 114 |
|
128 | 115 | def avail(self): |
129 | 116 | print("Checking model availability...") |
@@ -153,22 +140,13 @@ def evaluate(self, parameters, output): |
153 | 140 | mean = output_dict[output] |
154 | 141 | mean_error = 0.0 # trick to collapse error range when lower/upper bounds are not predicted |
155 | 142 | elif self.__is_gaussian_process or self.__is_neural_network_ensemble: |
156 | | - if self.__is_gaussian_process: |
157 | | - # TODO use "exp" only once experimental data is available for all experiments |
158 | | - task_tag = "exp" if state.experiment == "bella-ip2" else "sim" |
159 | | - output_key = [key for key in output_dict.keys() if task_tag in key][ |
160 | | - 0 |
161 | | - ] |
162 | | - elif self.__is_neural_network_ensemble: |
163 | | - output_key = list(output_dict.keys())[0] |
164 | | - |
165 | 143 | # compute mean, standard deviation and mean error |
166 | 144 | # (call detach method to detach gradients from tensors) |
167 | | - mean = output_dict[output_key].mean.detach() |
168 | | - std_dev = output_dict[output_key].variance.sqrt().detach() |
| 145 | + mean = output_dict[output].mean.detach() |
| 146 | + std_dev = output_dict[output].variance.sqrt().detach() |
169 | 147 | mean_error = 2.0 * std_dev |
170 | 148 | else: |
171 | | - raise ValueError(f"Unsupported model type: {state.model_type}") |
| 149 | + raise ValueError(f"Unsupported model type: {self.__model_type_tag}") |
172 | 150 | # compute lower/upper bounds for error range |
173 | 151 | lower = mean - mean_error |
174 | 152 | upper = mean + mean_error |
|
0 commit comments