|
2 | 2 |
|
3 | 3 | """ |
4 | 4 | Script to handle the API code here. |
| 5 | +Be able to view the interactive API documentation, powered by Swagger UI, |
| 6 | +at http://localhost:8000/docs |
| 7 | +
|
| 8 | +Read-in of Person class instance is optional regarding original categorical features. |
| 9 | +
|
| 10 | +For production code, set debug on False. |
| 11 | +For general FastAPI information, see: |
| 12 | +https://fastapi.tiangolo.com/tutorial/ |
| 13 | +For application setup, see: |
| 14 | +https://fastapi.tiangolo.com/advanced/events/ |
| 15 | +For FastAPI beginner tutorial, start with: |
| 16 | +https://fastapi.tiangolo.com/tutorial/first-steps/ |
| 17 | +For advanced FastAPI example, see: |
| 18 | +https://github.com/microsoft/cookiecutter-spacy-fastapi/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/app/api.py |
| 19 | +For testing see: |
| 20 | +https://fastapi.tiangolo.com/tutorial/testing/ |
| 21 | +
|
| 22 | +future toDo: |
| 23 | +add a custom exception handler with @app.exception_handler() |
| 24 | +see: https://fastapi.tiangolo.com/tutorial/handling-errors/ |
| 25 | +
|
| 26 | +
|
5 | 27 | author: Ilona Brinkmeier |
6 | 28 | date: 2023-09 |
7 | 29 | """ |
|
12 | 34 |
|
13 | 35 | import logging |
14 | 36 | import uvicorn |
| 37 | +import signal |
| 38 | +import os |
| 39 | +import sys |
| 40 | +import yaml |
| 41 | +import numpy as np |
| 42 | +import pandas as pd |
| 43 | + |
| 44 | +# needed to run this script alone |
| 45 | +MAIN_DIR = os.path.join(os.getcwd(), 'src/') |
| 46 | +APP_DIR = os.path.join(MAIN_DIR, 'app/') |
| 47 | +sys.path.append(MAIN_DIR) |
| 48 | +sys.path.append(os.getcwd()) |
| 49 | +print(f'sys.path : {sys.path}') |
15 | 50 |
|
| 51 | +from typing import Optional, Any |
| 52 | +from contextlib import asynccontextmanager |
| 53 | +from fastapi import FastAPI, Body, HTTPException, Response, status |
| 54 | +from app.schemas import FeatureLabels, Person |
| 55 | +from training.ml.data import clean_data |
| 56 | +from training.ml.model import inference |
| 57 | +from config import get_config |
| 58 | +from slice_performance import load_transformer_artifact, load_final_model_artifact |
16 | 59 |
|
17 | 60 | ################### |
18 | 61 | # Coding |
|
22 | 65 | # info see: https://realpython.com/python-logging-source-code/ |
23 | 66 | logger = logging.getLogger(__name__) |
24 | 67 |
|
| 68 | +# variable to store artifacts names |
| 69 | +ml_components = {} |
| 70 | + |
| 71 | +# read in examples |
| 72 | +examples_file = os.path.join(APP_DIR, 'examples_request.yml') |
| 73 | +with open(examples_file) as f: |
| 74 | + examples_request = yaml.safe_load(f) |
| 75 | + |
| 76 | + |
| 77 | +# customised exception |
| 78 | +class InferenceNotPossible(HTTPException): |
| 79 | + ''' Raised if inference workflow went wrong ''' |
| 80 | + def __init__(self) -> None: |
| 81 | + super().__init__(status_code=404, detail="Client error: Inference not possible") |
| 82 | + |
| 83 | + |
| 84 | +# Define the signal handler function |
| 85 | +def graceful_shutdown(signum, frame) -> None: |
| 86 | + # Perform cleanup tasks here (closing db connections, saving state, ...); |
| 87 | + # e.g. has to be filled, if Person items are stored in a database |
| 88 | + |
| 89 | + # Finally, exit the application |
| 90 | + logger.warning("Shutting down the FastAPI US Census app") |
| 91 | + sys.exit(0) |
| 92 | + |
| 93 | + |
| 94 | +# Register the signal handler for SIGTERM |
| 95 | +signal.signal(signal.SIGTERM, graceful_shutdown) |
| 96 | + |
| 97 | + |
| 98 | +@asynccontextmanager |
| 99 | +async def lifespan(app: FastAPI) -> None: |
| 100 | + ''' Handles transformer and model artifacts for startup and shutdown. |
| 101 | + |
| 102 | + The coding before the yield will be executed before the application starts taking |
| 103 | + requests, during the startup. |
| 104 | + The coding after the yield will be executed after the application finishes handling requests, |
| 105 | + right before the shutdown. |
| 106 | + ''' |
| 107 | + try: |
| 108 | + logging.debug('Read in post-market transformer and model artifacts') |
| 109 | + # load ml components: feature transformer and classifier artifacts |
| 110 | + transformer_artifact = load_transformer_artifact() |
| 111 | + ml_components['transformer_artifact'] = transformer_artifact |
| 112 | + model_artifact = load_final_model_artifact() |
| 113 | + ml_components['model_artifact'] = model_artifact |
| 114 | + |
| 115 | + yield |
| 116 | + |
| 117 | + # clean up the ML components and release the resources |
| 118 | + logging.debug('Resource cleaning of transformer and model artifacts') |
| 119 | + ml_components.clear() |
| 120 | + except Exception as e: |
| 121 | + logger.exception("Exit: exception of type %s occurred. Details: %s", type(e).__name__, str(e)) |
| 122 | + else: |
| 123 | + txt = 'Handling of transformer and model artifacts was successful during lifespan of FastAPI app.' |
| 124 | + logger.debug(txt) |
| 125 | + |
| 126 | + |
| 127 | +app = FastAPI( |
| 128 | + title = "Udacity MLOps, Project 3 - Prediction Model for Public US Census Bureau Data", |
| 129 | + description = "Deploying a Binary Classification ML Model on Render with FastAPI; \ |
| 130 | + its inference is about having a salary <=50K or >50K", |
| 131 | + version = "0.1", |
| 132 | + lifespan=lifespan, |
| 133 | + debug = True |
| 134 | +) |
| 135 | + |
| 136 | + |
| 137 | +@app.get("/") |
| 138 | +async def root() -> Response: |
| 139 | + ''' Returns welcome message at root level ''' |
| 140 | + response = Response( |
| 141 | + status_code=status.HTTP_200_OK, |
| 142 | + content="Welcome to the Udacity MLOps project 3 and its salary prediction application!" |
| 143 | + ) |
| 144 | + return response |
| 145 | + |
| 146 | + |
| 147 | +@app.get("/feature_labels/{feature_name}") |
| 148 | +async def feature_labels(feature_name: FeatureLabels) -> Any: |
| 149 | + ''' Read-in feature values with original label from US census dataset ''' |
| 150 | + logging.info("Read-in of feature values from examples_request file started") |
| 151 | + feat_value = examples_request['features_labels'][feature_name] |
| 152 | + return feat_value |
| 153 | + |
| 154 | + |
| 155 | +@app.post("/predict/") |
| 156 | +async def predict(person: Person = Body(..., examples=examples_request['test_examples'])): |
| 157 | + ''' |
| 158 | + Returns prediction of test examples about income class, being <=50k or >50k, |
| 159 | + so having a proper response status number 200 in such cases. |
| 160 | + |
| 161 | + If only a few features are having a wrong value type, the model shall be able to handle |
| 162 | + this properly having an inference result of being an <=50k or >50k item as well. |
| 163 | + |
| 164 | + If most of the features are missing, a value error shall be thrown with response status number 422. |
| 165 | + ''' |
| 166 | + logging.info("Model classification inference started") |
| 167 | + person = person.dict() |
| 168 | + features = np.array( |
| 169 | + [person[f] for f in examples_request['features_labels'].keys()] |
| 170 | + ).reshape(1, -1) |
| 171 | + |
| 172 | + df = pd.DataFrame(features, columns=examples_request['features_labels'].keys()) |
| 173 | + df_cleaned = clean_data(df, get_config()) |
| 174 | + logger.info('Census cleaned new adult person data with %s features', |
| 175 | + df_cleaned.shape[1]) |
| 176 | + logger.info('Its columns are: %s', df_cleaned.columns) |
25 | 177 |
|
| 178 | + # cleaning inference case for person dataframe (X = df_cleaned), not training |
| 179 | + X_processed = ml_components['transformer_artifact'].transform(df_cleaned) |
| 180 | + # predict income class |
| 181 | + model = ml_components['model_artifact'] |
| 182 | + y_pred = inference(model, X_processed) |
| 183 | + logger.info('Predict post y_pred: %s', y_pred) |
| 184 | + if y_pred not in [0, 1]: |
| 185 | + raise InferenceNotPossible(HTTPException('US census prediction workflow error')) |
26 | 186 |
|
| 187 | + pred_class = '>50k' if y_pred == 1 else '<=50k' |
| 188 | + logger.info('income prediction label: %s, salary class: %s', y_pred[0], pred_class) |
27 | 189 |
|
| 190 | + content_txt = ''.join( |
| 191 | + ['income prediction label: ', str(y_pred[0]), |
| 192 | + ', ', |
| 193 | + 'salary class: ', pred_class] |
| 194 | + ) |
| 195 | + response = Response( |
| 196 | + status_code = status.HTTP_200_OK, |
| 197 | + content = content_txt, |
| 198 | + ) |
28 | 199 |
|
| 200 | + return response |
29 | 201 |
|
30 | 202 |
|
31 | 203 | if __name__ == "__main__": |
|
0 commit comments