Skip to content

Commit f35e4d1

Browse files
black-eleven and yihuiwen authored
support openai interface (#1037)
Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
1 parent fe8e642 commit f35e4d1

3 files changed

Lines changed: 372 additions & 0 deletions

File tree

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import asyncio
2+
import base64
3+
import re
4+
import time
5+
import uuid
6+
from pathlib import Path
7+
from typing import Literal, Optional
8+
9+
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile
10+
from loguru import logger
11+
from pydantic import BaseModel, Field
12+
13+
from ..schema import ImageTaskRequest
14+
from ..task_manager import TaskStatus, task_manager
15+
from .deps import get_services
16+
17+
# Router for the OpenAI-compatible image endpoints; mounted under /v1/images.
router = APIRouter()
18+
19+
_SIZE_PATTERN = re.compile(r"^\s*(\d+)\s*x\s*(\d+)\s*$", re.IGNORECASE)
20+
21+
22+
class OpenAIImageGenerationRequest(BaseModel):
    """Request body for the OpenAI-compatible ``/v1/images/generations`` endpoint.

    ``model`` and ``user`` are accepted purely for client compatibility and are
    ignored by the server.
    """

    prompt: str = Field(..., description="Text prompt")
    model: Optional[str] = Field(default=None, description="Ignored for compatibility")
    n: int = Field(default=1, description="Number of images, currently only supports 1")
    size: Optional[str] = Field(default=None, description="Image size, e.g. 1024x1024")
    # Widened from Literal["b64_json"]: the shared response builder already
    # supports "url" (persist the image, return a download link) and the
    # /edits endpoint accepts both formats. Default stays "b64_json" so
    # existing callers are unaffected.
    response_format: Literal["url", "b64_json"] = Field(default="b64_json")
    user: Optional[str] = Field(default=None, description="Ignored for compatibility")
    seed: Optional[int] = Field(default=None, description="Optional random seed")
30+
31+
32+
class OpenAIImageResponse(BaseModel):
    """OpenAI images API response envelope.

    Mirrors the OpenAI schema: a Unix-seconds creation timestamp and a list of
    items, each carrying either a "b64_json" or a "url" key.
    """

    # Unix timestamp (seconds) of when the response was built.
    created: int
    # One dict per generated image, keyed "b64_json" or "url".
    data: list[dict[str, str]]
35+
36+
37+
def _write_file_sync(file_path: Path, content: bytes) -> None:
38+
with open(file_path, "wb") as buffer:
39+
buffer.write(content)
40+
41+
42+
def _shape_from_size(size: str) -> tuple[int, int]:
43+
match = _SIZE_PATTERN.match(size)
44+
if not match:
45+
raise ValueError("size must be in WxH format, e.g. 1024x1024")
46+
width = int(match.group(1))
47+
height = int(match.group(2))
48+
if width <= 0 or height <= 0:
49+
raise ValueError("size width and height must be positive")
50+
return width, height
51+
52+
53+
async def _wait_task_result_png(task_id: str, timeout_seconds: int, poll_interval_seconds: float) -> bytes:
    """Poll the task manager until *task_id* finishes, then return its PNG bytes.

    Maps terminal states onto HTTP errors: missing status or failure -> 500,
    cancellation -> 409, timeout -> 504 (the task is cancelled first).
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        info = task_manager.get_task_status(task_id)
        if not info:
            raise HTTPException(status_code=500, detail=f"Task status not found: {task_id}")

        current = info.get("status")
        if current == TaskStatus.COMPLETED.value:
            png_bytes = task_manager.get_task_result_png(task_id)
            if png_bytes:
                return png_bytes
            raise HTTPException(status_code=500, detail=f"Task completed but no in-memory image found: {task_id}")
        if current == TaskStatus.FAILED.value:
            raise HTTPException(status_code=500, detail=info.get("error", "Task failed"))
        if current == TaskStatus.CANCELLED.value:
            raise HTTPException(status_code=409, detail=info.get("error", "Task cancelled"))

        if time.monotonic() > deadline:
            # Give up: stop the backend task and surface a gateway timeout.
            task_manager.cancel_task(task_id)
            raise HTTPException(status_code=504, detail=f"Task {task_id} timed out after {timeout_seconds} seconds")

        await asyncio.sleep(poll_interval_seconds)
78+
79+
80+
async def _watch_client_disconnect(request: Request, task_id: str, poll_interval_seconds: float = 0.2) -> bool:
    """Return True once the HTTP client disconnects, cancelling *task_id* first.

    Runs until a disconnect is observed; the caller is expected to race this
    against the result-waiting coroutine and cancel whichever loses.
    """
    while not await request.is_disconnected():
        await asyncio.sleep(poll_interval_seconds)
    task_manager.cancel_task(task_id)
    logger.info(f"Client disconnected, task {task_id} cancelled")
    return True
87+
88+
89+
async def _run_sync_image_task(request: Request, message: ImageTaskRequest) -> bytes:
    """Create an image task and block until it yields PNG bytes.

    Races task completion against client disconnect: whichever finishes first
    wins, and the loser is cancelled and drained. Error mapping: RuntimeError
    -> 503, client disconnect -> 499, anything unexpected -> 500; HTTPExceptions
    raised by the waiters pass through unchanged.
    """
    task_id = None
    timeout_seconds = 600  # overall wall-clock budget for the task
    poll_interval_seconds = 0.5

    try:
        # Ask the worker to keep the result in memory so it can be returned
        # inline instead of being read back from disk.
        message.prefer_memory_result = True
        task_id = task_manager.create_task(message)
        message.task_id = task_id

        wait_task = asyncio.create_task(_wait_task_result_png(task_id, timeout_seconds, poll_interval_seconds))
        disconnect_task = asyncio.create_task(_watch_client_disconnect(request, task_id))

        # First of (result ready, client gone) wins; cancel and await the rest
        # so no task is left dangling.
        done, pending = await asyncio.wait({wait_task, disconnect_task}, return_when=asyncio.FIRST_COMPLETED)
        for pending_task in pending:
            pending_task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

        if disconnect_task in done and disconnect_task.result():
            # Client went away; the watcher already cancelled the backend task.
            if not wait_task.done():
                wait_task.cancel()
                await asyncio.gather(wait_task, return_exceptions=True)
            raise HTTPException(status_code=499, detail=f"Client disconnected, task {task_id} cancelled")

        return wait_task.result()
    except RuntimeError as e:
        # NOTE(review): presumably create_task raises RuntimeError when the
        # queue/capacity is exhausted, hence 503 — confirm against task_manager.
        raise HTTPException(status_code=503, detail=str(e))
    except HTTPException:
        raise
    except asyncio.CancelledError:
        # This coroutine itself was cancelled (e.g. server shutdown): stop the
        # backend task before propagating.
        if task_id:
            task_manager.cancel_task(task_id)
        raise
    except Exception as e:
        logger.error(f"Failed to run OpenAI-compatible image task: {e}")
        raise HTTPException(status_code=500, detail=str(e))
125+
126+
127+
def _build_url_response(request: Request, task_id: str, image_bytes: bytes) -> str:
    """Persist *image_bytes* as ``<task_id>.png`` and return its download URL."""
    services = get_services()
    assert services.file_service is not None, "File service is not initialized"

    # NOTE(review): images land in output_video_dir — presumably the shared
    # output directory served by /v1/files/download; confirm the naming.
    target = services.file_service.output_video_dir / f"{task_id}.png"
    target.parent.mkdir(parents=True, exist_ok=True)
    _write_file_sync(target, image_bytes)

    origin = str(request.base_url).rstrip("/")
    return f"{origin}/v1/files/download/{target.name}"
138+
139+
140+
def _build_openai_response(request: Request, task_id: str, image_bytes: bytes, response_format: Literal["url", "b64_json"]):
    """Wrap raw PNG bytes in an OpenAI-style response (inline base64 or URL)."""
    created_at = int(time.time())
    if response_format == "b64_json":
        encoded = base64.b64encode(image_bytes).decode("utf-8")
        payload = [{"b64_json": encoded}]
    else:
        payload = [{"url": _build_url_response(request, task_id, image_bytes)}]
    return OpenAIImageResponse(created=created_at, data=payload)
145+
146+
147+
def _build_image_task_request(
    prompt: str,
    *,
    negative_prompt: str = "",
    seed: Optional[int] = None,
    target_shape: Optional[list[int]] = None,
    image_path: str = "",
    image_mask_path: str = "",
) -> ImageTaskRequest:
    """Assemble an ImageTaskRequest, including optional fields only when set."""
    fields = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image_path=image_path,
        image_mask_path=image_mask_path,
    )
    # Omit unset optionals so the schema's own defaults apply.
    if target_shape:
        fields["target_shape"] = target_shape
    if seed is not None:
        fields["seed"] = seed
    return ImageTaskRequest(**fields)
167+
168+
169+
@router.post("/generations", response_model=OpenAIImageResponse)
async def create_openai_image_generation(request: Request, body: OpenAIImageGenerationRequest):
    """OpenAI-compatible text-to-image endpoint (``POST /v1/images/generations``)."""
    if body.n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is currently supported")
    if not body.prompt.strip():
        raise HTTPException(status_code=400, detail="prompt is required")

    target_shape = None
    if body.size:
        try:
            width, height = _shape_from_size(body.size)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        target_shape = [height, width]  # stored as [height, width]

    message = _build_image_task_request(prompt=body.prompt, seed=body.seed, target_shape=target_shape)

    result_png = await _run_sync_image_task(request, message)
    return _build_openai_response(request, message.task_id, result_png, body.response_format)
192+
193+
194+
async def _save_upload_file(file: UploadFile, target_dir: Path) -> str:
    """Persist an upload under *target_dir* with a unique name; return its path.

    Rejects uploads with no filename or empty content (400).
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="Uploaded file has no filename")

    suffix = Path(file.filename).suffix or ".png"
    destination = target_dir / f"{uuid.uuid4()}{suffix}"

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail=f"Uploaded file is empty: {file.filename}")
    # Write off the event loop to avoid blocking on disk I/O.
    await asyncio.to_thread(_write_file_sync, destination, payload)
    return str(destination)
207+
208+
209+
@router.post("/edits", response_model=OpenAIImageResponse)
async def create_openai_image_edit(
    request: Request,
    image: UploadFile = File(...),
    prompt: str = Form(...),
    mask: UploadFile | None = File(default=None),
    model: str | None = Form(default=None),
    n: int = Form(default=1),
    size: str | None = Form(default=None),
    response_format: Literal["url", "b64_json"] = Form(default="url"),
    user: str | None = Form(default=None),
    negative_prompt: str = Form(default=""),
    seed: int | None = Form(default=None),
):
    """OpenAI-compatible image edit endpoint (``POST /v1/images/edits``)."""
    _ = model, user  # accepted for API compatibility only
    if n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is currently supported")
    if not prompt.strip():
        raise HTTPException(status_code=400, detail="prompt is required")

    services = get_services()
    assert services.file_service is not None, "File service is not initialized"

    target_shape = None
    if size:
        try:
            width, height = _shape_from_size(size)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        target_shape = [height, width]

    input_dir = services.file_service.input_image_dir
    image_path = await _save_upload_file(image, input_dir)
    image_mask_path = await _save_upload_file(mask, input_dir) if mask is not None else ""

    message = _build_image_task_request(
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        target_shape=target_shape,
        image_path=image_path,
        image_mask_path=image_mask_path,
    )

    result_png = await _run_sync_image_task(request, message)
    return _build_openai_response(request, message.task_id, result_png, response_format)

lightx2v/server/api/router.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from fastapi import APIRouter
22

33
from .files import router as files_router
4+
from .openai_images import router as openai_images_router
45
from .service_routes import router as service_router
56
from .tasks import common_router, image_router, video_router
67

@@ -19,6 +20,7 @@ def create_api_router() -> APIRouter:
1920
tasks_router.post("/", response_model_exclude_unset=True, deprecated=True)(create_video_task)
2021

2122
api_router.include_router(tasks_router)
23+
api_router.include_router(openai_images_router, prefix="/v1/images", tags=["openai-images"])
2224
api_router.include_router(files_router, prefix="/v1/files", tags=["files"])
2325
api_router.include_router(service_router, prefix="/v1/service", tags=["service"])
2426

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import argparse
2+
import base64
3+
from pathlib import Path
4+
from typing import Any
5+
6+
import requests
7+
8+
# The OpenAI SDK is optional at import time; main() raises a clear error when
# it is missing instead of failing here.
try:
    from openai import OpenAI  # pyright: ignore[reportMissingImports]
except ImportError:
    OpenAI = None  # type: ignore[assignment]
12+
13+
14+
def _extract_data_item(response: Any) -> dict[str, Any]:
15+
if not hasattr(response, "data") or not response.data:
16+
raise RuntimeError(f"Invalid OpenAI images response: {response}")
17+
item = response.data[0]
18+
if hasattr(item, "model_dump"):
19+
return item.model_dump() # openai pydantic object
20+
if isinstance(item, dict):
21+
return item
22+
raise RuntimeError(f"Unsupported data item type: {type(item)!r}")
23+
24+
25+
def _save_image_from_item(item: dict[str, Any], output_path: Path) -> Path:
26+
output_path.parent.mkdir(parents=True, exist_ok=True)
27+
28+
if "b64_json" in item and item["b64_json"]:
29+
image_bytes = base64.b64decode(item["b64_json"])
30+
output_path.write_bytes(image_bytes)
31+
return output_path
32+
33+
if "url" in item and item["url"]:
34+
resp = requests.get(item["url"], timeout=120)
35+
resp.raise_for_status()
36+
output_path.write_bytes(resp.content)
37+
return output_path
38+
39+
raise RuntimeError(f"Response item has neither b64_json nor url: {item}")
40+
41+
42+
def run_generate(client: Any, args: argparse.Namespace) -> Path:
    """Exercise the generations endpoint and save the returned image."""
    response = client.images.generate(
        model=args.model,
        prompt=args.prompt,
        size=args.size,
        response_format=args.response_format,
    )
    first_item = _extract_data_item(response)
    print(f"[generate] response item: {first_item}")
    destination = Path(args.output_dir) / "generate.png"
    return _save_image_from_item(first_item, destination)
52+
53+
54+
def run_edit(client: Any, args: argparse.Namespace) -> Path:
    """Exercise the edits endpoint (optionally with a mask) and save the result."""
    if not args.image:
        raise ValueError("--image is required for edit mode")

    source_image = Path(args.image)
    if not source_image.exists():
        raise FileNotFoundError(f"Image file not found: {source_image}")

    with source_image.open("rb") as image_file:
        request_kwargs = {
            "model": args.model,
            "image": image_file,
            "prompt": args.edit_prompt or args.prompt,
            "size": args.size,
            "response_format": args.response_format,
        }
        if args.mask:
            mask_source = Path(args.mask)
            if not mask_source.exists():
                raise FileNotFoundError(f"Mask file not found: {mask_source}")
            with mask_source.open("rb") as mask_file:
                response = client.images.edit(mask=mask_file, **request_kwargs)
        else:
            response = client.images.edit(**request_kwargs)

    first_item = _extract_data_item(response)
    print(f"[edit] response item: {first_item}")
    return _save_image_from_item(first_item, Path(args.output_dir) / "edit.png")
82+
83+
84+
def _build_arg_parser() -> argparse.ArgumentParser:
    """CLI options for the OpenAI-compatible image API smoke test."""
    parser = argparse.ArgumentParser(description="Test OpenAI-compatible image APIs on LightX2V server.")
    parser.add_argument("--base_url", type=str, default="http://127.0.0.1:8000/v1", help="OpenAI-compatible base URL")
    parser.add_argument("--api_key", type=str, default="dummy-key", help="OpenAI API key placeholder")
    parser.add_argument("--model", type=str, default="gpt-image-1", help="Model name (for compatibility only)")
    parser.add_argument("--mode", choices=["generate", "edit", "all"], default="all", help="Test mode")
    parser.add_argument("--prompt", type=str, default="a futuristic city at sunset", help="Prompt for generation")
    parser.add_argument("--edit_prompt", type=str, default="", help="Prompt for edit (defaults to --prompt)")
    parser.add_argument("--size", type=str, default="1024x1024", help="Image size, e.g. 1024x1024")
    parser.add_argument("--response_format", choices=["url", "b64_json"], default="url", help="OpenAI response format")
    parser.add_argument("--image", type=str, default="", help="Input image path for edit mode")
    parser.add_argument("--mask", type=str, default="", help="Optional mask image path for edit mode")
    parser.add_argument("--output_dir", type=str, default="outputs/openai_images_test", help="Directory to save outputs")
    return parser


def main() -> None:
    """Run the selected smoke tests and print where outputs were saved."""
    args = _build_arg_parser().parse_args()

    if OpenAI is None:
        raise RuntimeError("Missing dependency: openai. Please install it with `pip install openai`.")

    client = OpenAI(api_key=args.api_key, base_url=args.base_url)

    saved_paths: list[Path] = []
    if args.mode in ("generate", "all"):
        saved_paths.append(run_generate(client, args))
    if args.mode in ("edit", "all"):
        saved_paths.append(run_edit(client, args))

    for path in saved_paths:
        print(f"[saved] {path}")
112+
113+
114+
# Script entry point: run the configured smoke tests.
if __name__ == "__main__":
    main()

0 commit comments

Comments (0)