-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
110 lines (82 loc) · 2.98 KB
/
inference.py
File metadata and controls
110 lines (82 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
## IMPORTS
import os, argparse, cv2
from model import MotivNet
from torch.utils.data import DataLoader, Dataset
import lightning as L
import torch
class ImageDataset(Dataset):
    """Dataset that serves every file in a folder as a normalized image tensor.

    Each item is a ``(path, image)`` pair. ``image`` is a float tensor of
    shape ``(3, 1024, 768)``: the picture is resized to 768x768, padded with
    white to a height of 1024, converted from OpenCV's BGR order to RGB and
    normalized per channel with ``transform_mean`` / ``transform_std``.
    """

    def __init__(self, folder_path):
        """Collect the regular files located directly inside *folder_path*.

        Args:
            folder_path: Directory containing the images to run inference on.
        """
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories. Sort for reproducible output and keep only regular
        # files, since cv2.imread() cannot decode a directory.
        candidates = (
            os.path.join(folder_path, name) for name in os.listdir(folder_path)
        )
        self.image_paths = sorted(p for p in candidates if os.path.isfile(p))
        # Per-channel statistics, applied to unscaled (0-255) pixel values
        # after the channels have been reordered in __getitem__.
        self.transform_mean = torch.tensor([123.5, 116.5, 103.5])
        self.transform_std = torch.tensor([58.5, 57.0, 57.5])

    def __len__(self):
        """Return the number of files that will be served."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, resize, pad and normalize the image at position *idx*.

        Args:
            idx: Index into ``image_paths``.

        Returns:
            Tuple ``(path, image)`` with ``image`` a float tensor of shape
            ``(3, 1024, 768)``.

        Raises:
            ValueError: If OpenCV cannot decode the file as an image.
        """
        path = self.image_paths[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread() signals failure by returning None instead of
            # raising; fail here with the offending path rather than with an
            # opaque assertion inside cv2.resize().
            raise ValueError(f"Could not read image: {path}")

        image = cv2.resize(image, (768, 768), interpolation=cv2.INTER_LINEAR)
        # Pad top and bottom equally with white so the height becomes 1024.
        padding = (1024 - 768) // 2
        image = cv2.copyMakeBorder(
            image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=(255, 255, 255)
        )

        # HWC -> CHW, then reverse the channel axis (BGR -> RGB).
        image = torch.from_numpy(image.transpose(2, 0, 1))
        image = image[[2, 1, 0], ...].float()

        # Reshape the statistics to (3, 1, 1) so they broadcast over H and W.
        mean = self.transform_mean.view(-1, 1, 1)
        std = self.transform_std.view(-1, 1, 1)
        image = (image - mean) / std
        return path, image
def collect_top_k(batches, top_k):
    """Flatten per-batch predictions into ``(file_name, top_k_indices)`` pairs.

    Args:
        batches: Iterable of ``(file_names, scores)`` pairs, one per batch,
            where ``scores`` yields one score tensor per file name. ``None``
            is treated as "no batches".
        top_k: Number of highest-scoring class indices to keep per image. It
            is clamped to the number of scores available for the image.

    Returns:
        List of ``(file_name, indices)`` tuples covering every batch, with
        ``indices`` ordered from highest to lowest score.
    """
    results = []
    for file_names, scores in batches or []:
        for file_name, pred in zip(file_names, scores):
            # torch.topk raises when k exceeds the dimension size, so clamp.
            k = min(top_k, pred.shape[-1])
            results.append((file_name, torch.topk(pred, k).indices.tolist()))
    return results


if __name__ == "__main__":
    # parse cli arguments
    parser = argparse.ArgumentParser(
        prog="MotivNet",
        description="Enhancing Meta-Sapiens as a emotionally intelligent foundational model",
        epilog="Inference script for MotivNet",
    )
    parser.add_argument(
        "--input",
        type=str,
        default="",
        help="Path to folder containing inference images",
    )
    parser.add_argument(
        "--GPUS",
        type=int,
        default=1,
        help="Number of GPUs to use for inference",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for inference",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=1,
        help="Number of top predictions to return",
    )
    args = parser.parse_args()

    # validate arguments before any expensive work is started
    if args.input == "":
        raise ValueError("Please provide a path to the inference data")
    if not os.path.exists(args.input):
        raise FileNotFoundError("The provided path does not exist")
    if not os.path.isdir(args.input):
        # The dataset lists the folder's entries, so a plain file is invalid.
        raise NotADirectoryError("The provided path is not a folder")
    if args.top_k < 1:
        raise ValueError("--top_k must be at least 1")

    # load data
    print("Loading data...")
    print("Input path:", args.input)
    dataset = ImageDataset(args.input)
    # Report what the dataset will actually serve.
    print(f"{len(dataset)} file(s) found")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # load model
    print("Loading model...")
    model = MotivNet(7, 0, 0)
    checkpoint_path = os.path.join(os.getcwd(), "checkpoints", "MotivNet.pth")
    # map_location="cpu" lets a checkpoint saved from a GPU load on any host;
    # the trainer moves the model to the selected device(s) afterwards.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    trainer = L.Trainer(devices=args.GPUS)

    # inference
    print("Running inference...")
    preds = trainer.predict(model, dataloader)
    print("Predictions:")
    # trainer.predict returns one entry per batch; aggregate all of them
    # instead of reading only the first batch.
    results = collect_top_k(preds, args.top_k)
    print(results)