This repository was archived by the owner on May 12, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
41 lines (37 loc) · 1.98 KB
/
predict.py
File metadata and controls
41 lines (37 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import cv2
import argparse
import numpy as np
from gan_ce.network import Network
# Command-line entry point: inpaint the masked region of an image with the
# GAN context-encoder generator from gan_ce and write the result to disk.
#
# Usage:
#   python predict.py -i image.png -m mask.png -o out.png \
#       [-t "(2,2)"] [-sh "(256,256,3)"] [-w ./weights/weights.ckpt]


def parse_int_tuple(value):
    """Convert a command-line string such as ``"(2,2)"`` into a tuple of ints.

    argparse delivers every command-line value as a ``str``; without this
    converter a user-supplied ``--tiles``/``--shape`` could never pass the
    ``isinstance(..., tuple)`` check in ``validate_args``. Surrounding
    whitespace and ``()``/``[]`` are ignored, as are empty items (so a
    trailing comma is accepted). argparse does not run ``type`` on the
    non-string tuple defaults, so those are used as-is.

    Args:
        value: Comma-separated integers, optionally wrapped in ``()`` or ``[]``.

    Returns:
        The parsed integers as a tuple; its length is checked by ``validate_args``.

    Raises:
        argparse.ArgumentTypeError: If an item is not an integer (argparse
            reports this as a usage error).
    """
    items = [item for item in value.strip().strip("()[]").split(",") if item.strip()]
    try:
        return tuple(int(item) for item in items)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "invalid integer tuple {!r}; should be something like '(2,2)'".format(value)
        ) from None


def build_parser():
    """Return the argument parser for this script with its flags and their default values."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-i", "--image", required=True, help="Path to the image.")
    ap.add_argument("-m", "--mask", required=True, help="Path to the mask.")
    ap.add_argument("-o", "--output", required=True, help="Path to save prediction.")
    ap.add_argument(
        "-t", "--tiles", required=False, default=(2, 2), type=parse_int_tuple,
        help="How many tiles should the picture be divided into (default=(2,2)).",
    )
    ap.add_argument(
        "-sh", "--shape", required=False, default=(256, 256, 3), type=parse_int_tuple,
        help="Define the shape of a tile (default=(256,256,3)).",
    )
    ap.add_argument(
        "-w", "--weights", required=False, default='./weights/weights.ckpt',
        help="Path to the weights (default='./weights/weights.ckpt').",
    )
    return ap


def validate_args(args):
    """Check the parsed arguments and raise on the first unusable one.

    Args:
        args: ``vars()`` of the namespace returned by ``build_parser().parse_args``.

    Raises:
        ValueError: If ``tiles`` is not a 2-tuple or ``shape`` is not a 3-tuple.
        FileNotFoundError: If the weights directory, the image or the mask does not exist.
    """
    if not isinstance(args["tiles"], tuple) or len(args["tiles"]) != 2:
        raise ValueError("Tiles parameter is invalid. Should be something like '(2,2)'.")
    if not isinstance(args["shape"], tuple) or len(args["shape"]) != 3:
        raise ValueError("Shape parameter is invalid. Should be something like '(256,256,3)'.")
    # Only the directory is checked — presumably because the weights path is a
    # checkpoint prefix rather than an existing file (NOTE(review): confirm
    # against Network.load_weights_generator). A bare name such as
    # "weights.ckpt" has an empty dirname, meaning the current directory.
    if not os.path.isdir(os.path.dirname(args["weights"]) or "."):
        raise FileNotFoundError("Path to weights is invalid.")
    if not os.path.isfile(args["image"]):
        raise FileNotFoundError("Path to image is invalid.")
    if not os.path.isfile(args["mask"]):
        raise FileNotFoundError("Path to mask is invalid.")


def load_color_image(path):
    """Read *path* with OpenCV as a 3-channel colour image.

    The flag equals the literal ``3`` used previously
    (``IMREAD_COLOR | IMREAD_ANYDEPTH``): colour is forced, bit depth is kept.

    Raises:
        ValueError: If OpenCV cannot decode the file.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise ValueError("Could not read image: {}".format(path))
    return image


def binarize_mask(mask):
    """Normalize a mask in place so that [0, 0, 0] stays 0 and every other pixel becomes [1, 1, 1].

    Args:
        mask: (H, W, C) array, e.g. as returned by ``load_color_image``.

    Returns:
        The same array, now containing only 0s and 1s.
    """
    # any() rather than all(): a partially coloured pixel such as [0, 0, 255]
    # must also become [1, 1, 1] instead of keeping its raw channel values.
    mask[(mask != 0).any(axis=2)] = 1
    return mask


def main(argv=None):
    """Parse *argv* (default ``sys.argv[1:]``), inpaint the image and save the prediction.

    Raises:
        ValueError / FileNotFoundError: From argument validation or image loading.
        OSError: If OpenCV fails to write the output file.
    """
    args = vars(build_parser().parse_args(argv))
    validate_args(args)
    image = load_color_image(args["image"])
    mask = binarize_mask(load_color_image(args["mask"]))
    # Initialize the GAN (Context Encoder (generator) and Discriminator)
    network = Network(tiles=args["tiles"], shape=args["shape"])
    network.load_weights_generator(weights_path=args["weights"])
    prediction = network.predict(image, mask)
    # cv2.imwrite reports failure (unknown extension, missing directory, ...) by returning False.
    if not cv2.imwrite(args["output"], prediction):
        raise OSError("Could not save prediction to: {}".format(args["output"]))


if __name__ == "__main__":
    main()