-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
121 lines (98 loc) · 3.61 KB
/
app.py
File metadata and controls
121 lines (98 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from base64 import b64decode
import base64
import io
import asyncio
import os
# from image_check.has_human import has_human
from fastapi.exceptions import RequestValidationError
from concurrent.futures import ThreadPoolExecutor
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous
# (a debugging aid per the CUDA/PyTorch docs) and can slow inference; confirm
# it is still intended here. It is assigned before `import torch` below so the
# value is already in the environment when CUDA gets initialized.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch, gc
from typing import Annotated,Union
from PIL import Image
from fastapi import FastAPI, File, HTTPException, Form
from fastapi.responses import Response, JSONResponse
from starlette.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import requests
import cv2
import numpy as np
from iharm.inference.evaluation import evaluate
from iharm.inference.predictor import Predictor
from iharm.inference.utils import load_model, find_checkpoint
import torch
# Browser origins allowed to call this API cross-origin (with credentials).
_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://localhost",
    "https://dev-app.photio.io",
    "https://app.photio.io",
    "https://cafe24.photio.io",
]

# Keyword arguments forwarded to Starlette's CORSMiddleware: any method and
# header is accepted, but only from the origins listed above.
cors_options = dict(
    allow_origins=_ALLOWED_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

app = FastAPI()
app.add_middleware(CORSMiddleware, **cors_options)
@app.get("/")
def read_root():
    """Return a fixed greeting payload for the service root ``/``."""
    payload = {"Hello": "World!"}
    return payload
@app.get("/ping")
def ping():
    """Return the string ``"pong"`` for ``GET /ping``.

    The handler has its own name: it was previously defined as a second
    ``read_root``, which silently rebound the module-level ``read_root`` to
    this function. The HTTP route and response are unchanged.
    """
    return "pong"
def decode_base64_to_image(encoding: str) -> np.ndarray:
    """Decode an image given as base64 text, a ``data:`` URL, or an http(s) URL.

    Every input form is decoded with ``cv2.imdecode(..., IMREAD_COLOR)``, so the
    caller always gets the same kind of object: a 3-channel BGR uint8 numpy
    array. Downloaded images are decoded the same way as inline base64 ones.

    Args:
        encoding: raw base64 text, a ``data:<mime>[;params],<payload>`` URL, or
            an ``http://`` / ``https://`` URL to download.

    Returns:
        The decoded image array.

    Raises:
        HTTPException: 408 if the download times out; 422 if the input is
            missing, the download fails or returns an error status, the payload
            is not valid base64, or the bytes are not a decodable image.
    """
    if not encoding:
        # A missing value (None / "") cannot be decoded; report it as a client error.
        raise HTTPException(status_code=422)

    if encoding.startswith(("http://", "https://")):
        try:
            # NOTE(review): verify=False disables TLS certificate checks, and the
            # URL is client-controlled (possible SSRF). Kept for compatibility;
            # confirm this is intended.
            response = requests.get(encoding, timeout=30, verify=False)
            # Without this, a 404/500 error page would be handed to the decoder.
            response.raise_for_status()
        except requests.exceptions.Timeout as e:
            raise HTTPException(status_code=408) from e
        except requests.exceptions.RequestException as e:
            raise HTTPException(status_code=422) from e
        im_bytes = response.content
    else:
        payload = encoding
        if payload.startswith("data:"):
            # "data:<mime>[;params];base64,<payload>": the payload is everything
            # after the first comma, however many ';' parameters precede it.
            _, sep, payload = payload.partition(",")
            if not sep:
                raise HTTPException(status_code=422)
        try:
            im_bytes = base64.b64decode(payload)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            raise HTTPException(status_code=422) from e

    try:
        img = cv2.imdecode(np.frombuffer(im_bytes, dtype=np.uint8), flags=cv2.IMREAD_COLOR)
    except cv2.error as e:  # e.g. an empty buffer
        raise HTTPException(status_code=422) from e
    if img is None:
        # imdecode reports "not an image" by returning None instead of raising.
        raise HTTPException(status_code=422)
    return img
class HarmonizeRequest(BaseModel):
    """Request body for ``POST /harmonize``.

    Both fields are required (``...``). The handler must decode each of them,
    so a ``None`` default would only fail later, inside the decoder.
    """

    # Image to harmonize: base64 text, a data: URL, or an http(s) URL.
    image: str = Field(
        ...,
        title="Image",
        description="base64 or url",
    )
    # Mask that is passed to the model together with `image`; same accepted formats.
    mask: str = Field(
        ...,
        title="Mask",
        description="base64 or url",
    )
# Inference options: attention input disabled, and mean 0 / std 1 on every
# channel (presumably an identity normalization inside Predictor; confirm).
use_attn = False
normalization = dict(mean=[0, 0, 0], std=[1, 1, 1])

# Load the PCTNet ViT checkpoint once at import time and wrap it in a single
# Predictor on device index 0 (assumed to be a CUDA GPU); the /harmonize
# handler reuses this one module-level predictor for every request.
device = torch.device(0)
checkpoint_path = find_checkpoint('', './pretrained_models/PCTNet_ViT.pth')
net = load_model('ViT_pct', checkpoint_path, verbose=False)
predictor = Predictor(
    net,
    device,
    with_flip=False,
    hsv=False,
    use_attn=use_attn,
    mean=normalization['mean'],
    std=normalization['std'],
)
# A single worker thread keeps model inference serialized on the shared
# `predictor` (as it effectively was when this handler blocked the event loop),
# while leaving the loop free to serve other requests such as /ping meanwhile.
_inference_executor = ThreadPoolExecutor(max_workers=1)


def _run_harmonization(image_src: str, mask_src: str) -> str:
    """Decode both inputs, run `evaluate` on them, and return the result as base64 PNG text.

    Blocking (network download + model inference); call it off the event loop.
    """
    image = decode_base64_to_image(image_src)
    mask = decode_base64_to_image(mask_src)
    pred_fullres = evaluate(image, mask, predictor)
    ok, im_arr = cv2.imencode('.png', pred_fullres)
    if not ok:
        raise HTTPException(status_code=500, detail="failed to encode result as PNG")
    # base64 output is pure ASCII; decode so the JSON payload is a plain string.
    return base64.b64encode(im_arr.tobytes()).decode("ascii")


@app.post("/harmonize")
async def harmonize(
    request: HarmonizeRequest
):
    """Harmonize ``request.image`` using ``request.mask`` and return the result.

    Returns:
        ``{"result": <base64-encoded PNG of the model output>}``.

    Raises:
        HTTPException: 422 when evaluation raises ``RuntimeError``; decoding
            errors propagate from ``decode_base64_to_image`` unchanged.
    """
    loop = asyncio.get_running_loop()
    try:
        # Blocking work runs in the executor so the event loop is not stalled.
        im_b64 = await loop.run_in_executor(
            _inference_executor, _run_harmonization, request.image, request.mask
        )
    except RuntimeError as e:
        raise HTTPException(status_code=422, detail=f"error occur: {e}") from e
    return {"result": im_b64}