This repository was archived by the owner on Apr 8, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathscore.py
More file actions
152 lines (122 loc) · 5.04 KB
/
score.py
File metadata and controls
152 lines (122 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
TreeLite/Python inferencing script
"""
import os
import sys
import argparse
import logging
import time
import numpy as np
from distutils.util import strtobool
import treelite, treelite_runtime
# Add the right path to PYTHONPATH
# so that you can import from common.*
COMMON_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if COMMON_ROOT not in sys.path:
logging.info(f"Adding {COMMON_ROOT} to PYTHONPATH")
sys.path.append(str(COMMON_ROOT))
# useful imports from common
from common.components import RunnableScript
from common.io import input_file_path, CustomLightGBMDataBatchIterator
class TreeLightInferencingScript(RunnableScript):
def __init__(self):
super().__init__(
task = "score",
framework = 'treelite_python',
framework_version = "PYTHON_API."+str(treelite.__version__)
)
@classmethod
def get_arg_parser(cls, parser=None):
"""Adds component/module arguments to a given argument parser.
Args:
parser (argparse.ArgumentParser): an argument parser instance
Returns:
ArgumentParser: the argument parser instance
Notes:
if parser is None, creates a new parser instance
"""
# add generic arguments
parser = RunnableScript.get_arg_parser(parser)
group_i = parser.add_argument_group("Input Data")
group_i.add_argument("--data",
required=True, type=input_file_path, help="Inferencing data location (file path)")
group_i.add_argument("--so_path",
required=False, default = "./mymodel.so" , help="full path to model so")
group_i.add_argument("--output",
required=False, default=None, type=str, help="Inferencing output location (file path)")
group_params = parser.add_argument_group("Scoring parameters")
group_params.add_argument("--num_threads",
required=False, default=1, type=int, help="number of threads")
group_params.add_argument("--batch_size",
required=False, default=0, type=int, help="size of batches for predict call")
return parser
def run(self, args, logger, metrics_logger, unknown_args):
"""Run script with arguments (the core of the component)
Args:
args (argparse.namespace): command line arguments provided to script
logger (logging.getLogger() for this script)
metrics_logger (common.metrics.MetricLogger)
unknown_args (list[str]): list of arguments not recognized during argparse
"""
# record relevant parameters
metrics_logger.log_parameters(
num_threads=args.num_threads,
batch_size=args.batch_size,
)
# make sure the output argument exists
if args.output:
os.makedirs(args.output, exist_ok=True)
# and create your own file inside the output
args.output = os.path.join(args.output, "predictions.txt")
logger.info(f"Loading model from {args.so_path}")
predictor = treelite_runtime.Predictor(
args.so_path,
verbose=True,
nthread=args.num_threads
)
# accumulate predictions and latencies
predictions = []
time_inferencing_per_batch = []
batch_lengths = []
# loop through batches
for batch in CustomLightGBMDataBatchIterator(args.data, batch_size=args.batch_size, file_format="csv").iter():
if len(batch) == 0:
break
batch_lengths.append(len(batch))
# transform into dense matrix for treelite
batch_data = np.array(batch)
batch_dmat = treelite_runtime.DMatrix(batch_data)
# run prediction on batch
batch_start_time = time.monotonic()
predictions.extend(predictor.predict(batch_dmat))
time_inferencing_per_batch.append((time.monotonic() - batch_start_time)) # usecs
# log overall time
metrics_logger.log_metric("time_inferencing", sum(time_inferencing_per_batch))
# use helper to log latency with the right metric names
metrics_logger.log_inferencing_latencies(
time_inferencing_per_batch,
batch_length=batch_lengths,
factor_to_usecs=1000000.0 # values are in seconds
)
if args.output:
np.savetxt(
args.output,
predictions,
fmt='%f',
delimiter=',',
newline='\n',
header='',
footer='',
comments='# ',
encoding=None
)
def get_arg_parser(parser=None):
""" To ensure compatibility with shrike unit tests """
return TreeLightInferencingScript.get_arg_parser(parser)
def main(cli_args=None):
""" To ensure compatibility with shrike unit tests """
TreeLightInferencingScript.main(cli_args)
if __name__ == "__main__":
main()