|
| 1 | +from pathlib import Path |
| 2 | +from typing import Any |
| 3 | + |
| 4 | +import requests |
| 5 | + |
| 6 | +MODEL_URLS = { |
| 7 | + "recon3d": ("http://bigg.ucsd.edu/api/v2/models/Recon3D/download", "json"), |
| 8 | + "ratgem": ( |
| 9 | + "https://github.com/SysBioChalmers/Rat-GEM/raw/refs/heads/main/model/Rat-GEM.mat", |
| 10 | + "mat", |
| 11 | + ), |
| 12 | +} |
| 13 | +SUPPORTED_MODELS = MODEL_URLS.keys() |
| 14 | + |
| 15 | + |
| 16 | +def ratgem_processing(response: requests.Response) -> Any: |
| 17 | + """ |
| 18 | + Parse request containing Rat-GEM .mat file. |
| 19 | + :param response: Response containing Rat-GEM model .mat binary |
| 20 | + :type response: requests.Response |
| 21 | + :return: Binary (.mat) representation of Rat-GEM |
| 22 | + :rtype: Any | Bytes |
| 23 | + """ |
| 24 | + return response.content |
| 25 | + |
| 26 | + |
| 27 | +def recon_processing(response: requests.Response) -> str: |
| 28 | + """ |
| 29 | + Parse request containing Recon3D JSON |
| 30 | + :param response: Response containing Recon3D JSON |
| 31 | + :type response: requests.Response |
| 32 | + :return: String representation of Recon3D JSON |
| 33 | + :rtype: str |
| 34 | + """ |
| 35 | + return response.text.replace("_AT", ".") |
| 36 | + |
| 37 | + |
| 38 | +processing = { |
| 39 | + "recon3d": recon_processing, |
| 40 | + "ratgem": ratgem_processing, |
| 41 | +} |
| 42 | + |
| 43 | + |
| 44 | +# TODO: make singleton |
| 45 | +class ModelManager: |
| 46 | + """ |
| 47 | + ModelManager class in charge of management of auto-downloading, storing and retrieving model files |
| 48 | + """ |
| 49 | + |
| 50 | + def __init__(self): |
| 51 | + """ |
| 52 | + Initialize ModelManager |
| 53 | + """ |
| 54 | + self.model_files = { |
| 55 | + name: f"{name}.{MODEL_URLS[name][1]}" for name in MODEL_URLS.keys() |
| 56 | + } |
| 57 | + self.model_path = Path("./models") |
| 58 | + self.model_file_paths = { |
| 59 | + name: self.model_path / model_file |
| 60 | + for name, model_file in self.model_files.items() |
| 61 | + } |
| 62 | + self.model_path.mkdir(exist_ok=True) |
| 63 | + self.managed_models_str = ", ".join(SUPPORTED_MODELS) |
| 64 | + |
| 65 | + def get_model(self, model: str) -> Path: |
| 66 | + """ |
| 67 | + Retrieve a given model by its name |
| 68 | + :param model: Name given to the model (see allowed models above) |
| 69 | + :type model: str |
| 70 | + :return: Path to model file |
| 71 | + :rtype: Path |
| 72 | + """ |
| 73 | + if not self.model_file_paths[model].exists(): |
| 74 | + self.download_model(model) |
| 75 | + return self.model_file_paths[model] |
| 76 | + |
| 77 | + def download_model(self, model: str): |
| 78 | + """ |
| 79 | + Download a given model by its name from a hardcoded URL |
| 80 | + :param model: Name given to the model (see allowed models above) |
| 81 | + :type model: str |
| 82 | + :raises ValueError: Raised if the model name is unknown |
| 83 | + :raises requests.HTTPError: Raised if the download fails |
| 84 | + """ |
| 85 | + if not model in MODEL_URLS.keys(): |
| 86 | + raise ValueError( |
| 87 | + f"Illegal model: {model} . Supported models are: {self.managed_models_str}" |
| 88 | + ) |
| 89 | + try: |
| 90 | + response = requests.get(MODEL_URLS[model][0]) |
| 91 | + response.raise_for_status() |
| 92 | + except requests.HTTPError as error: |
| 93 | + raise requests.HTTPError(f"Failed to download {model}") from error |
| 94 | + |
| 95 | + content = processing[model](response) |
| 96 | + save_file = self.model_file_paths[model] |
| 97 | + mode = "wb" if save_file.suffix == ".mat" else "w" |
| 98 | + with open(save_file, mode) as f: |
| 99 | + f.write(content) |
| 100 | + |
| 101 | + def get_managed_models(self) -> list[str]: |
| 102 | + """ |
| 103 | + Return a list of the names of supported models |
| 104 | + :return: List of supported models' names |
| 105 | + :rtype: list[str] |
| 106 | + """ |
| 107 | + return SUPPORTED_MODELS |
| 108 | + |
| 109 | + def get_managed_models_str(self) -> str: |
| 110 | + """ |
| 111 | + Return all supported models in a string representation |
| 112 | + :return: String of supported models |
| 113 | + :rtype: str |
| 114 | + """ |
| 115 | + return self.managed_models_str |
| 116 | + |
| 117 | + def wipe(self): |
| 118 | + """ |
| 119 | + Wipes all previously downloaded model files and deletes the models folder |
| 120 | + :raises OSError: Raised if any file or the model folder cannot be deleted |
| 121 | + """ |
| 122 | + try: |
| 123 | + for model in self.model_file_paths.values(): |
| 124 | + if not model.exists(): |
| 125 | + continue |
| 126 | + model.unlink() |
| 127 | + self.model_path.rmdir() |
| 128 | + except Exception as error: |
| 129 | + raise OSError( |
| 130 | + "Could not delete model files, please delete manually in 'models' folder." |
| 131 | + ) from error |
0 commit comments