Skip to content

Commit c9e609c

Browse files
committed
fix
1 parent 04b5a56 commit c9e609c

5 files changed

Lines changed: 2677 additions & 364 deletions

File tree

api/api.py

Lines changed: 68 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,7 @@ class GenerateRequest(BaseModel):
144144
video_length: float = Field(5.0, description="Length of the video in seconds.", gt=0)
145145
seed: int = Field(-1, description="Seed for generation. -1 for random.")
146146
use_teacache: bool = Field(False, description="Enable TEACache optimization.")
147-
gpu_memory_preservation: float = Field(0.0, description="GPU memory to preserve (GB) in low VRAM mode.", ge=0)
147+
gpu_memory_preservation: float = Field(10.0, description="GPU memory to preserve (GB) in low VRAM mode.", ge=0)
148148
steps: int = Field(20, description="Number of diffusion steps.", gt=0)
149149
cfg: float = Field(7.0, description="Classifier-Free Guidance scale.", ge=1.0)
150150
gs: float = Field(1.0, description="Guidance scale for start latent.", ge=0)
@@ -189,24 +189,28 @@ class ResultResponse(BaseModel):
189189
class ImageGenerationRequest(BaseModel):
190190
prompt: str = Field(..., description="Text prompt for image generation.")
191191
negative_prompt: Optional[str] = Field("", description="Negative prompt.")
192+
input_image: str = Field(..., description="Base64 encoded input image (required).")
192193
seed: Optional[int] = Field(-1, description="Seed for generation. -1 for random.")
193194
steps: int = Field(30, description="Number of diffusion steps.", gt=0)
194195
cfg: float = Field(1.0, description="Classifier-Free Guidance scale.", ge=0.0)
195196
width: int = Field(1216, description="Image width.", gt=0)
196197
height: int = Field(704, description="Image height.", gt=0)
198+
gpu_memory_preservation: float = Field(10.0, description="GPU memory to preserve (GB) in low VRAM mode.", ge=0)
197199
lora_paths: Optional[List[str]] = Field(None, description="List of LoRA file paths.")
198200
lora_scales: Optional[List[float]] = Field(None, description="List of LoRA scales.")
199201

200202

201203
class BatchImageRequest(BaseModel):
202204
prompts: List[str] = Field(..., description="List of text prompts for batch generation.")
205+
input_image: str = Field(..., description="Base64 encoded input image (used for all batch items).")
203206
negative_prompt: Optional[str] = Field("", description="Negative prompt for all images.")
204207
seeds: Optional[List[int]] = Field(None, description="List of seeds (one per prompt).")
205208
batch_size: int = Field(4, description="Number of images to generate in parallel.", gt=0, le=8)
206209
steps: int = Field(30, description="Number of diffusion steps.", gt=0)
207210
cfg: float = Field(1.0, description="Classifier-Free Guidance scale.", ge=0.0)
208211
width: int = Field(1216, description="Image width.", gt=0)
209212
height: int = Field(704, description="Image height.", gt=0)
213+
gpu_memory_preservation: float = Field(10.0, description="GPU memory to preserve (GB) in low VRAM mode.", ge=0)
210214
lora_paths: Optional[List[str]] = Field(None, description="List of LoRA file paths.")
211215
lora_scales: Optional[List[float]] = Field(None, description="List of LoRA scales.")
212216

@@ -220,6 +224,7 @@ class ImageTransferRequest(BaseModel):
220224
seed: Optional[int] = Field(-1, description="Seed for generation. -1 for random.")
221225
steps: int = Field(30, description="Number of diffusion steps.", gt=0)
222226
cfg: float = Field(1.0, description="Classifier-Free Guidance scale.", ge=0.0)
227+
gpu_memory_preservation: float = Field(10.0, description="GPU memory to preserve (GB) in low VRAM mode.", ge=0)
223228

224229

225230
class ImageGenerationResponse(BaseModel):
@@ -233,6 +238,12 @@ class SamplingMode(str, enum.Enum):
233238
forward = "forward"
234239

235240

241+
# --- Enum for Output Type ---
242+
class OutputType(str, enum.Enum):
243+
video = "video"
244+
image = "image"
245+
246+
236247
# --- Background Worker ---
237248
def background_worker_task():
238249
global worker_running, currently_processing_job_id
@@ -757,21 +768,59 @@ async def list_videos():
757768
# --- Image Generation Endpoints ---
758769

759770
@app.post("/api/generate-image", response_model=ImageGenerationResponse)
760-
async def generate_image(request: ImageGenerationRequest):
771+
async def generate_image(
772+
prompt: str = Form(..., description="Text prompt for image generation."),
773+
negative_prompt: str = Form("", description="Negative prompt."),
774+
seed: int = Form(-1, description="Seed for generation. -1 for random."),
775+
steps: int = Form(30, description="Number of diffusion steps."),
776+
cfg: float = Form(1.0, description="Classifier-Free Guidance scale."),
777+
width: int = Form(1216, description="Image width."),
778+
height: int = Form(704, description="Image height."),
779+
lora_path: Optional[str] = Form("", description="LoRA file path."),
780+
lora_scale: float = Form(1.0, description="LoRA scale."),
781+
gpu_memory_preservation: float = Form(10.0, description="GPU memory to preserve (GB) in low VRAM mode."),
782+
sampling_mode: str = Form("dpm-solver++", description="Sampling mode."),
783+
transformer_model: str = Form("base", description="Transformer model (base or f1)."),
784+
input_image: UploadFile = File(..., description="Input image file (required).")
785+
):
761786
"""
762-
Generate a single high-quality image from a text prompt.
787+
Generate a single high-quality image from input image and text prompt (Image-to-Image).
763788
"""
764789
if not loaded_models:
765790
raise HTTPException(status_code=503, detail="Models are not loaded or failed to load. API is not ready.")
766791

792+
# Read and encode input image
793+
try:
794+
contents = await input_image.read()
795+
import base64
796+
input_image_b64 = base64.b64encode(contents).decode('utf-8')
797+
except Exception as e:
798+
raise HTTPException(status_code=400, detail=f"Failed to read input image: {e}")
799+
finally:
800+
await input_image.close()
801+
767802
# Create a unique job ID
768803
import uuid
769804
job_id = str(uuid.uuid4())
770805

771806
# Prepare job data
772807
job_data = {
773808
"type": "image",
774-
"data": request.dict(),
809+
"data": {
810+
"prompt": prompt,
811+
"negative_prompt": negative_prompt,
812+
"input_image": input_image_b64,
813+
"seed": seed,
814+
"steps": steps,
815+
"cfg": cfg,
816+
"width": width,
817+
"height": height,
818+
"lora_path": lora_path,
819+
"lora_scale": lora_scale,
820+
"gpu_memory_preservation": gpu_memory_preservation,
821+
"sampling_mode": sampling_mode,
822+
"transformer_model": transformer_model
823+
},
775824
"created_at": datetime.now().isoformat()
776825
}
777826

@@ -784,7 +833,11 @@ async def generate_image(request: ImageGenerationRequest):
784833
raise HTTPException(status_code=500, detail=f"Failed to add job to queue: {e}")
785834

786835
# Start processing in background
787-
asyncio.create_task(worker_image.process_image_generation(job_id, job_data, loaded_models))
836+
threading.Thread(
837+
target=worker_image.process_image_generation,
838+
args=(job_id, job_data, loaded_models),
839+
daemon=True
840+
).start()
788841

789842
return ImageGenerationResponse(
790843
job_id=job_id,
@@ -820,7 +873,11 @@ async def batch_images(request: BatchImageRequest):
820873
raise HTTPException(status_code=500, detail=f"Failed to add job to queue: {e}")
821874

822875
# Start processing in background
823-
asyncio.create_task(worker_image.process_batch_images(job_id, job_data, loaded_models))
876+
threading.Thread(
877+
target=worker_image.process_batch_images,
878+
args=(job_id, job_data, loaded_models),
879+
daemon=True
880+
).start()
824881

825882
return ImageGenerationResponse(
826883
job_id=job_id,
@@ -856,7 +913,11 @@ async def transfer_image(request: ImageTransferRequest):
856913
raise HTTPException(status_code=500, detail=f"Failed to add job to queue: {e}")
857914

858915
# Start processing in background
859-
asyncio.create_task(worker_image.process_image_transfer(job_id, job_data, loaded_models))
916+
threading.Thread(
917+
target=worker_image.process_image_transfer,
918+
args=(job_id, job_data, loaded_models),
919+
daemon=True
920+
).start()
860921

861922
return ImageGenerationResponse(
862923
job_id=job_id,

0 commit comments

Comments (0)