Skip to content

Commit d6cbf32

Browse files
committed
xception infer can resume from pred file
1 parent 76863ca commit d6cbf32

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

examples/xception/infer.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
parser.add_argument("--batch_size", type=int, default=128)
1515
parser.add_argument("--subset", type=str, choices=["train", "val", "test", "testA", "testB"])
1616
parser.add_argument("--gpus", type=int, default=1)
17+
parser.add_argument("--resume", type=str, default=None)
1718
parser.add_argument("--take_num", type=int, default=None)
1819

1920
if __name__ == '__main__':
@@ -32,9 +33,21 @@
3233

3334
save_path = f"output/{args.model}_{args.subset}.txt"
3435
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
35-
with open(save_path, "w") as f:
36+
37+
processed_files = set()
38+
if args.resume is not None:
39+
with open(args.resume, "r") as f:
40+
for line in f:
41+
processed_files.add(line.split(";")[0])
42+
43+
with open(save_path, "a") as f:
3644
with torch.inference_mode():
37-
for i, (video, _, _) in enumerate(tqdm(test_dataset)):
45+
for i in tqdm(range(len(test_dataset))):
46+
file_name = test_dataset.metadata[i].file
47+
if file_name in processed_files:
48+
continue
49+
50+
video, _, _ = test_dataset[i]
3851
# batch video as frames use batch_size
3952
preds_video = []
4053
for j in range(0, len(video), args.batch_size):
@@ -45,6 +58,5 @@
4558
# choose the max prediction
4659
pred = preds_video.max().item()
4760

48-
file_name = test_dataset.metadata[i].file
4961
f.write(f"{file_name};{pred}\n")
5062
f.flush()

0 commit comments

Comments
 (0)