-
Notifications
You must be signed in to change notification settings - Fork 159
Expand file tree
/
Copy pathapi_core.py
More file actions
81 lines (63 loc) · 2.57 KB
/
api_core.py
File metadata and controls
81 lines (63 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
from PIL import PngImagePlugin, Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from src.core import core_generation_funnel
from src import backbone
from src.api.api_constants import api_defaults, api_forced, models_to_index
def decode_base64_to_image(encoding):
    """Decode a base64 image string into a PIL image.

    Args:
        encoding: Base64 text, optionally wrapped in a data URL such as
            ``"data:image/png;base64,<payload>"``.

    Returns:
        The decoded ``PIL.Image.Image`` with its pixel data already loaded.

    Raises:
        HTTPException: (status 500) if the data-URL header is malformed, the
            payload is not decodable base64, or the bytes are not a readable
            image.
    """
    try:
        if encoding.startswith("data:image/"):
            # The payload is everything after the first comma. Partitioning on
            # the comma (instead of indexing split(";")/split(",") results)
            # copes with extra media-type parameters. It also lets a header
            # without a comma be reported as an invalid image rather than
            # escaping as an IndexError.
            _header, comma, encoding = encoding.partition(",")
            if not comma:
                raise ValueError("data URL has no ',' separator")
        image = Image.open(BytesIO(base64.b64decode(encoding)))
        # Image.open only parses the header. Force the full decode here so that
        # corrupt or truncated data is rejected with the same API error instead
        # of failing later in the pipeline.
        image.load()
        return image
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e


# TODO check that internally we always use png.
def encode_pil_to_base64(image, image_type='png'):
    """Serialize a PIL image to base64-encoded PNG bytes.

    String-to-string entries of ``image.info`` are written as PNG text chunks
    so that metadata survives the round trip. All other entries are skipped.

    Args:
        image: The PIL image to encode.
        image_type: Output format; only ``'png'`` is supported.

    Returns:
        The base64 encoding of the PNG file, as ``bytes``.

    Raises:
        HTTPException: (status 500) for any ``image_type`` other than ``'png'``.
    """
    with BytesIO() as buffer:
        if image_type != 'png':
            raise HTTPException(status_code=500, detail="Invalid image format")
        text_chunks = PngImagePlugin.PngInfo()
        has_text = False
        for name, text in image.info.items():
            # PngInfo.add_text only takes string keys/values; skip the rest.
            if not (isinstance(name, str) and isinstance(text, str)):
                continue
            text_chunks.add_text(name, text)
            has_text = True
        # Pass pnginfo only when there is something to write.
        image.save(buffer, format="PNG", pnginfo=text_chunks if has_text else None)
        encoded = base64.b64encode(buffer.getvalue())
    return encoded


def encode_to_base64(image):
    """Convert an image-like value to base64 for an API response.

    Strings are returned unchanged (presumably they are already base64, so
    confirm with callers). PIL images and numpy arrays are encoded as PNG.

    Returns:
        The input string unchanged, base64 PNG ``bytes`` for images/arrays, or
        ``""`` for any other value.
    """
    # Use isinstance rather than an exact type() match so subclasses are
    # accepted. Images produced by Image.open are e.g. PngImageFile, a subclass
    # of Image.Image, and previously fell through to "" silently.
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""


def encode_np_to_base64(image):
    """Encode a numpy image array as base64 PNG bytes.

    The array is wrapped with ``Image.fromarray`` and handed to
    ``encode_pil_to_base64``. Its dtype/shape must be one Pillow can map to an
    image mode.
    """
    return encode_pil_to_base64(Image.fromarray(image))


def to_base64_PIL(encoding: str):
    """Decode a base64 image string into a new PIL image with uint8 pixels.

    Despite the name, this converts *from* base64 *to* PIL. The decoded image
    is round-tripped through a numpy ``uint8`` array, which yields a plain
    in-memory ``Image.Image`` detached from the decoder.

    Raises:
        HTTPException: propagated from ``decode_base64_to_image`` for invalid
            input.
    """
    image = decode_base64_to_image(encoding)
    # For some modes the numpy view is not pixel colour. A palette ("P") image
    # would become its raw palette indices (rebuilt as greyscale "L"), and a
    # 1-bit ("1") image would become 0/1 values (nearly black). Expand these
    # first.
    if image.mode == "P":
        image = image.convert("RGBA" if "transparency" in image.info else "RGB")
    elif image.mode == "1":
        image = image.convert("L")
    # NOTE(review): high-bit-depth modes (e.g. "I;16", "I", "F") are still cast
    # straight to uint8, so out-of-range values are not rescaled.
    return Image.fromarray(np.array(image).astype('uint8'))


def api_gen(input_images, client_options):
    """Run a batch of base64-encoded images through the generation pipeline.

    Args:
        input_images: Sequence of base64 image strings (data URLs accepted).
        client_options: Mapping of option name -> value supplied by the client.
            Keys are upper-cased and merged over ``api_defaults``. The value
            of ``model_type`` is translated through ``models_to_index``. Every
            entry of ``api_forced`` is then applied last and overrides both.

    Returns:
        Whatever ``core_generation_funnel`` returns for the batch.

    Raises:
        HTTPException: (422) if ``model_type`` is not a known model, or (500)
            if an input image cannot be decoded.
    """
    # NOTE(review): shallow copy. If any default value is mutable and mutated
    # downstream, it would leak between requests; confirm.
    options = api_defaults.copy()
    # TODO: validate the types of the remaining client-supplied values.
    for key, value in client_options.items():
        if key == "model_type":
            try:
                value = models_to_index[value]
            except (KeyError, TypeError) as e:
                # Without this, an unknown (or unhashable) model name surfaced
                # as an unhandled KeyError/TypeError, i.e. an opaque 500.
                raise HTTPException(status_code=422, detail=f"Unknown model_type: {value!r}") from e
        options[key.upper()] = value
    # Server-enforced settings take precedence over defaults and client input.
    options.update(api_forced)
    print(f"Processing {len(input_images)} images through the API")
    print(options)
    pil_images = [to_base64_PIL(input_image) for input_image in input_images]
    outpath = backbone.get_outpath()
    return core_generation_funnel(outpath, pil_images, None, None, options)