-
Notifications
You must be signed in to change notification settings - Fork 7k
Expand file tree
/
Copy path inference.py
More file actions
120 lines (94 loc) · 3.64 KB
/
inference.py
File metadata and controls
120 lines (94 loc) · 3.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from io import StringIO
import os
import json
import flask
import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from flask import Flask, Response, Request
import csv
# Flask application; the /ping and /invocations routes below are registered on it.
app = Flask(__name__)
# Module-level model slot, initialised to None.
model = None
# Directory holding the "xgboost-model" artifact that load_model() opens
# (presumably the SageMaker model directory — confirm for other hosts).
MODEL_PATH = "/opt/ml/model"
def load_model():
    """
    Load the XGBoost model from MODEL_PATH, caching it for later calls.

    The artifact at ``<MODEL_PATH>/xgboost-model`` is read once and stored in
    the module-level ``model`` global. Subsequent calls (from every /ping and
    /invocations request) return the cached object instead of re-opening and
    unpickling the file each time.

    Returns:
        The object deserialized by ``joblib.load``. It is presumably an
        ``xgb.Booster`` or another object exposing ``predict``; confirm this
        against the training job that produced the artifact.

    Raises:
        OSError: If the artifact file cannot be opened (e.g. FileNotFoundError).
        Exception: Whatever ``joblib.load`` raises if deserialization fails.
    """
    global model
    if model is None:
        xgb_model_path = os.path.join(MODEL_PATH, "xgboost-model")
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only point MODEL_PATH at artifacts from a trusted source.
        with open(xgb_model_path, "rb") as f:
            model = joblib.load(f)
    return model
def preprocess(input_data, content_type):
    """
    Parse a CSV request payload into an XGBoost DMatrix.

    Args:
        input_data (str): Decoded request body. It holds CSV rows with no
            header line; every row is treated as data.
        content_type (str): Request content type. Any ``text/csv`` media type
            is accepted, e.g. "text/csv" or "text/csv; charset=utf-8". Parameters
            such as ``charset`` are ignored because the payload is already a
            decoded ``str``.

    Returns:
        xgb.DMatrix: The parsed feature matrix.

    Raises:
        ValueError: If ``content_type`` is not a CSV media type, or if pandas
            cannot parse ``input_data``. pandas' EmptyDataError and ParserError
            both subclass ValueError.
    """
    # Compare only the media type, case-insensitively. The old exact-string
    # check let other spellings fall through to an undefined DataFrame or an
    # implicit None return.
    media_type = (content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "text/csv":
        raise ValueError(
            f"Unsupported content type: {content_type!r}; expected text/csv"
        )
    df = pd.read_csv(StringIO(input_data), header=None)
    return xgb.DMatrix(data=df)
def predict(input_data):
    """
    Score a prepared feature matrix with the model returned by load_model().

    Args:
        input_data (xgb.DMatrix): Features produced by preprocess().

    Returns:
        list: The model's predictions converted to a plain Python list.
        Returns an empty list if anything fails. This covers loading the model,
        predicting, and the list conversion; the failure is printed before
        returning.
    """
    try:
        # Both the model call and the list conversion stay inside the try, so
        # any failure yields [] instead of propagating.
        booster = load_model()
        return booster.predict(input_data).tolist()
    except Exception as err:
        print(f"Error while making predictions: {err}", flush=True)
        return []
@app.route("/ping", methods=["GET"])
def ping():
    """
    Health check: report whether the model can be loaded.

    Calls load_model(). The response is 200 if it returns a non-None model,
    and 500 if it returns None or raises.

    Returns:
        flask.Response: A newline body with the status code described above.
    """
    try:
        loaded = load_model()
    except Exception as exc:
        # The docstring promises a 500 status on error, so report the failure
        # as a status code rather than letting the exception escape.
        print(f"Health check failed to load model: {exc}", flush=True)
        loaded = None
    status = 200 if loaded is not None else 500
    return flask.Response(response="\n", status=status, mimetype="application/json")
@app.route("/invocations", methods=["POST"])
def invocations():
    """
    Handle a prediction request: parse the CSV body, predict, and return JSON.

    Accepts media type ``text/csv`` with no charset or a UTF-8 charset, compared
    case-insensitively. Any other content type is answered with 415. If the
    body cannot be decoded as UTF-8 or parsed as CSV, the response is 400.

    Returns:
        flask.Response: On success, ``{"result": [...]}`` with status 200 and
        mimetype application/json. Otherwise a text/plain error response with
        status 415 or 400.
    """
    req = flask.request
    print(f"Predictor: received content type: {req.content_type}", flush=True)

    # Compare the parsed media type instead of the raw header, so that
    # "text/csv" and "text/csv; charset=UTF-8" are not rejected.
    charset = req.mimetype_params.get("charset", "utf-8").lower()
    if req.mimetype != "text/csv" or charset not in ("utf-8", "utf8"):
        print(f"Received: {req.content_type}", flush=True)
        return flask.Response(
            response=f"XGBPredictor: This predictor only supports CSV data; Received: {req.content_type}",
            status=415,
            mimetype="text/plain",
        )

    try:
        payload = req.data.decode("utf-8")
        # Pass the canonical CSV content type, since the raw header has already
        # been validated above and may be spelled differently.
        transformed_data = preprocess(payload, "text/csv; charset=utf-8")
    except ValueError as exc:
        # UnicodeDecodeError and pandas parse errors are ValueError subclasses,
        # so a malformed body is a client error, not a server crash.
        print(f"Invalid CSV payload: {exc}", flush=True)
        return flask.Response(
            response=f"XGBPredictor: invalid CSV payload: {exc}",
            status=400,
            mimetype="text/plain",
        )

    predictions = predict(transformed_data)
    return flask.Response(
        response=json.dumps({"result": predictions}),
        status=200,
        mimetype="application/json",
    )