-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathtest_onnx.py
More file actions
216 lines (190 loc) · 11.7 KB
/
Copy pathtest_onnx.py
File metadata and controls
216 lines (190 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#################################################################################
# Copyright (c) 2023-2026, Texas Instruments
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#################################################################################
import datetime
import os
import platform
from argparse import ArgumentParser
from logging import getLogger
import onnxruntime as ort
import pandas as pd
import torch
import torcheval
from tabulate import tabulate
from tinyml_tinyverse.common.datasets import GenericImageDataset
# Tiny ML TinyVerse Modules
from tinyml_tinyverse.common.utils import misc_utils, utils, mdcl_utils
from tinyml_tinyverse.common.utils.mdcl_utils import Logger
from tinyml_tinyverse.common.utils.utils import get_confusion_matrix
from ..common.train_base import shutdown_data_loaders
# Registry of dataset classes handed to utils.load_data (presumably keyed by the
# --dataset-loader name; confirm in utils.load_data).
dataset_loader_dict = dict(GenericImageDataset=GenericImageDataset)
def get_args_parser():
    """Build the command-line parser for testing an ONNX model with ONNX Runtime.

    Returns:
        argparse.ArgumentParser: parser whose parsed namespace is consumed by ``main``.
    """
    DESCRIPTION = "This script loads a dataset and tests it against an ONNX model using ONNX Runtime"
    parser = ArgumentParser(description=DESCRIPTION)
    parser.add_argument('--dataset', default='folder', help='dataset')
    parser.add_argument('--dataset-loader', default='SimpleTSDataset', help='dataset loader')
    parser.add_argument("--loader-type", default="regression", type=str,
                        help="Dataset Loader Type: classification/regression")
    parser.add_argument('--annotation-prefix', default='instances', help='annotation-prefix')
    parser.add_argument('--data-path', default=os.path.join('.', 'data', 'datasets'), help='dataset')
    parser.add_argument('--output-dir', default=None, help='path where to save')
    parser.add_argument('--model-path', default=None, help='ONNX model Path')
    parser.add_argument('--gpus', default=1, type=int, help='number of gpus')
    # Worker processes are disabled on Windows, so the help text states both defaults.
    parser.add_argument('-j', '--workers', default=0 if platform.system() in ['Windows'] else 16, type=int,
                        metavar='N', help='number of data loading workers (default: 16, or 0 on Windows)')
    parser.add_argument('--date', default=datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), help='current date')
    parser.add_argument('--seed', default=42, help="Seed for all randomness", type=int)
    parser.add_argument('--lis', help='Log File', type=str)
    parser.add_argument('--DEBUG', action='store_true', help='Log mode set to DEBUG')

    # Runtime / distribution parameters
    parser.add_argument("--distributed", default=None, type=misc_utils.str2bool_or_none,
                        help="run in distributed mode even if this script is not launched using "
                             "torch.distributed.launch or run")
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('-b', '--batch-size', default=1024, type=int)

    # Feature extraction params
    parser.add_argument('--data-proc-transforms', help="Data Preprocessing transforms ", default=[])
    parser.add_argument('--feat-ext-transform', help="Feature Extraction transforms ", default=[])

    # Vision related params
    parser.add_argument('--variables', help="1- if Univariate, 2/3/.. if multivariate")
    parser.add_argument('--image-height', help="Image dimension(Height)")
    parser.add_argument('--image-width', help="Image dimension(Width)")
    parser.add_argument('--image-mean', help="Average pixel intensity of dataset computed per channel")
    parser.add_argument('--image-scale', help="Standard deviation of pixel intensities per channel")
    parser.add_argument('--image-num-channel', help="Number of channels( RGB=3, Greyscale=1) present in the image")
    parser.add_argument('--generic-model', help="Open Source models", type=misc_utils.str_or_bool, default=False)
    parser.add_argument("--nn-for-feature-extraction", default=False, type=misc_utils.str2bool,
                        help="Use an AI model for preprocessing")
    parser.add_argument("--output-int", default=None, type=misc_utils.str_or_bool,
                        help="Get quantized int8 output from model (False for dequantized float output). "
                             "If not specified, determined automatically based on task type and quantization level.")
    return parser
def _create_onnx_session(model_path, generic_model):
    """Create an ONNX Runtime inference session for ``model_path``.

    For non-generic models the file is decrypted in place before the session is created and
    re-encrypted afterwards. Re-encryption runs in ``finally`` so a load failure does not
    leave the decrypted model on disk.

    Args:
        model_path: Path to the ONNX model file.
        generic_model: Truthy for open-source models, which are loaded as-is.

    Returns:
        onnxruntime.InferenceSession: the loaded session.
    """
    if generic_model:
        return ort.InferenceSession(model_path)
    utils.decrypt(model_path, utils.get_crypt_key())
    try:
        return ort.InferenceSession(model_path)
    finally:
        utils.encrypt(model_path, utils.get_crypt_key())


def _concat_as_float(chunks, device):
    """Concatenate tensors on ``device``, promoting to at least the default float dtype.

    This reproduces seeding ``torch.cat`` with an empty ``torch.tensor([])``. Integer results
    (class-index targets, int8 model outputs) become floating point, and wider float dtypes
    are kept. When ``chunks`` is empty, an empty default-dtype tensor is returned.
    """
    if not chunks:
        return torch.tensor([], device=device)
    merged = torch.cat(chunks)
    return merged.to(device=device, dtype=torch.promote_types(merged.dtype, torch.get_default_dtype()))


def _run_onnx_inference(ort_sess, data_loader, device, use_raw_data, transform=None):
    """Run ``ort_sess`` on every sample yielded by ``data_loader``.

    Args:
        ort_sess: ONNX Runtime session; its first input and first output are used.
        data_loader: Iterable of ``(raw_data, data, target)`` batches.
        device: torch device the inputs and the returned tensors are placed on.
        use_raw_data: If truthy, feed ``raw_data`` (cast to float32) instead of ``data``.
        transform: Optional callable applied to each ``data`` batch.

    Returns:
        tuple: ``(predicted, ground_truth)``, concatenated over all samples, on ``device``,
        and promoted to at least the default float dtype.
    """
    input_name = ort_sess.get_inputs()[0].name
    output_name = ort_sess.get_outputs()[0].name
    outputs, targets = [], []
    for batched_raw_data, batched_data, batched_target in data_loader:
        batched_raw_data = batched_raw_data.to(device, non_blocking=True).long()
        batched_data = batched_data.to(device, non_blocking=True).float()
        batched_target = batched_target.to(device, non_blocking=True).long()
        if transform:
            batched_data = transform(batched_data)
        # Raw samples are integer after .long(); the model is fed float32 input.
        # (Replaces `.numpy().astype(np.float32)`, which failed because numpy was never imported.)
        samples = batched_raw_data.float() if use_raw_data else batched_data
        for sample in samples:
            # NOTE(review): one sample per run with a leading batch dim of 1, presumably because
            # the exported model has a fixed batch size of 1. Confirm before batching.
            model_input = sample.unsqueeze(0).cpu().numpy()
            outputs.append(torch.tensor(ort_sess.run([output_name], {input_name: model_input})[0]))
        targets.append(batched_target)
    # Collect first, concatenate once: torch.cat inside the loop re-copies everything (O(n^2)).
    return _concat_as_float(outputs, device), _concat_as_float(targets, device)


def _log_test_metrics(predicted, ground_truth, num_classes, label_map):
    """Log accuracy, AUC ROC and a confusion matrix for the collected test predictions.

    Args:
        predicted: Model outputs, one row per sample (presumably per-class scores; confirm).
        ground_truth: Target class indices, one per sample.
        num_classes: Number of classes in the dataset.
        label_map: Mapping whose values label the confusion-matrix rows and columns.
    """
    # Explicit submodule import: a bare `import torcheval` is not relied on to expose `metrics`.
    from torcheval.metrics import MulticlassAccuracy

    metric = MulticlassAccuracy()
    metric.update(predicted, ground_truth)
    logger = getLogger("root.main.test_data")
    logger.info(f"Test Data Evaluation Accuracy: {metric.compute() * 100:.2f}%")
    try:
        logger.info(
            f"Test Data Evaluation AUC ROC Score: {utils.get_au_roc(predicted, ground_truth, num_classes):.3f}")
    except ValueError as e:
        logger.warning("Not able to compute AUC ROC. Error: " + str(e))
    if len(torch.unique(ground_truth)) == 1:
        logger.warning("Confusion Matrix can not be printed because only items of 1 class was present in test data")
        return
    try:
        confusion_matrix = get_confusion_matrix(predicted, ground_truth.type(torch.int64),
                                                num_classes).cpu().numpy()
        class_names = list(label_map.values())
        logger.info('Confusion Matrix:\n {}'.format(tabulate(pd.DataFrame(
            confusion_matrix, columns=[f"Predicted as: {x}" for x in class_names],
            index=[f"Ground Truth: {x}" for x in class_names]), headers="keys", tablefmt='grid')))
    except ValueError as e:
        logger.warning("Not able to compute Confusion Matrix. Error: " + str(e))


def main(gpu, args):
    """Evaluate an ONNX model with ONNX Runtime on the dataset returned by ``utils.load_data``.

    The model is run one sample at a time. The function writes post-training analysis plots
    under ``<output_dir>/post_training_analysis`` and logs accuracy, AUC ROC and a confusion
    matrix.

    Args:
        gpu: Process index passed by ``torch.multiprocessing.spawn``; not used here.
        args: Namespace from ``get_args_parser``. ``args.output_dir`` is filled in when empty.
            ``args.transforms`` is set when ``args.data_proc_transforms`` is a list.
    """
    transform = None  # optional extra transform for the `data` batches; not currently set
    if not args.output_dir:
        output_folder = os.path.basename(os.path.split(args.data_path)[0])
        # get_args_parser() defines no --model option, so reading args.model directly raised
        # AttributeError. Fall back to the ONNX file name.
        model_name = (getattr(args, 'model', None)
                      or os.path.splitext(os.path.basename(args.model_path or ''))[0]
                      or 'model')
        args.output_dir = os.path.join('.', 'data', 'checkpoints', 'classification', output_folder,
                                       model_name, args.date)
    utils.mkdir(args.output_dir)
    log_file = os.path.join(args.output_dir, 'run.log')
    logger = Logger(log_file=args.lis or log_file, DEBUG=args.DEBUG, name="root", append_log=True, console_log=True)
    utils.seed_everything(args.seed)
    from tinyml_tinyverse.version import get_version_str
    logger.info(f"TinyVerse Toolchain Version: {get_version_str()}")
    logger.info("Script: {}".format(os.path.relpath(__file__)))
    utils.init_distributed_mode(args)
    logger.debug("Args: {}".format(args))
    device = torch.device(args.device)

    if isinstance(args.data_proc_transforms, list):
        if len(args.data_proc_transforms) and isinstance(args.data_proc_transforms[0], list):
            # args.data_proc_transforms is a list of lists; only its first entry is used
            args.transforms = args.data_proc_transforms[0] + args.feat_ext_transform
        else:
            args.transforms = args.data_proc_transforms + args.feat_ext_transform

    # With test_only=True, the first dataset/sampler pair is what gets evaluated
    # (presumably the test split; confirm in utils.load_data).
    dataset, _, train_sampler, _ = utils.load_data(args.data_path, args, dataset_loader_dict, test_only=True)
    num_classes = len(dataset.classes)
    logger.info("Loading data:")
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
        collate_fn=utils.collate_fn)
    try:
        logger.info(f"Loading ONNX model: {args.model_path}")
        ort_sess = _create_onnx_session(args.model_path, args.generic_model)
        predicted, ground_truth = _run_onnx_inference(
            ort_sess, data_loader, device, args.nn_for_feature_extraction, transform)

        analysis_dir = os.path.join(args.output_dir, 'post_training_analysis')
        try:
            mdcl_utils.create_dir(analysis_dir)
            logger.info("Plotting OvR Multiclass ROC score")
            utils.plot_multiclass_roc(ground_truth, predicted, analysis_dir,
                                      label_map=dataset.inverse_label_map, phase='test')
            logger.info("Plotting Class difference scores")
            utils.plot_pairwise_differenced_class_scores(ground_truth, predicted, analysis_dir,
                                                         label_map=dataset.inverse_label_map, phase='test')
        except Exception as e:
            # Plots are best-effort; metrics are still logged below.
            logger.warning(f"Post Training Analysis plots will not be generated because: {e}")

        _log_test_metrics(predicted, ground_truth, num_classes, dataset.inverse_label_map)
    finally:
        # Release data-loader workers even when model loading or inference fails.
        shutdown_data_loaders(data_loader)
def run(args):
    """Entry point: call ``main`` directly, or spawn one process per GPU in distributed mode.

    Distributed mode is used only when ``args.device`` is not ``'cpu'`` and
    ``args.distributed`` is exactly ``True``.
    """
    spawn_per_gpu = args.device != 'cpu' and args.distributed is True
    if not spawn_per_gpu:
        main(0, args)
        return
    # Single-machine, multi-GPU rendezvous settings; background:
    # https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html
    # NOTE(review): RANK is '0' for every spawned process (single machine assumed);
    # presumably utils.init_distributed_mode derives the per-process rank. Confirm.
    os.environ.update({
        'RANK': '0',
        'WORLD_SIZE': str(args.gpus),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '29500',
    })
    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args,))
if __name__ == "__main__":
    # Parse the CLI and run the ONNX evaluation. With --distributed True on a non-CPU
    # device, run() spawns one process per GPU (count taken from --gpus).
    run(get_args_parser().parse_args())