Skip to content

Commit 8d59718

Browse files
committed
fix
1 parent c9e609c commit 8d59718

4 files changed

Lines changed: 503 additions & 128 deletions

File tree

api/api.py

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ async def lifespan(app: FastAPI):
126126
# Use the lifespan context manager
127127
app = FastAPI(title="FramePack API", version="0.1.0", lifespan=lifespan)
128128

129+
# Set maximum file size (100MB)
130+
app.state.max_request_size = 100 * 1024 * 1024
131+
129132
# --- CORS Middleware Configuration ---
130133
app.add_middleware(
131134
CORSMiddleware,
@@ -232,6 +235,8 @@ class ImageGenerationResponse(BaseModel):
232235
message: str
233236

234237

238+
239+
235240
# --- Enum for Sampling Mode ---
236241
class SamplingMode(str, enum.Enum):
237242
reverse = "reverse"
@@ -261,7 +266,23 @@ def background_worker_task():
261266
currently_processing_job_id = None
262267
continue
263268

264-
worker.worker(next_job, loaded_models)
269+
# Check job type and route to appropriate worker
270+
job_data = queue_manager.get_job_data_by_id(currently_processing_job_id)
271+
if job_data and isinstance(job_data, dict):
272+
job_type = job_data.get("type", "video")
273+
274+
if job_type == "image_generation":
275+
worker_image.process_image_generation(currently_processing_job_id, job_data, loaded_models)
276+
elif job_type == "batch_images":
277+
worker_image.process_batch_images(currently_processing_job_id, job_data, loaded_models)
278+
elif job_type == "image_transfer":
279+
worker_image.process_image_transfer(currently_processing_job_id, job_data, loaded_models)
280+
else:
281+
# Default to video generation
282+
worker.worker(next_job, loaded_models)
283+
else:
284+
# Fallback to video generation for compatibility
285+
worker.worker(next_job, loaded_models)
265286
except Exception as e:
266287
print(f"Unhandled exception in worker for job {currently_processing_job_id}: {e}")
267288
traceback.print_exc()
@@ -781,7 +802,7 @@ async def generate_image(
781802
gpu_memory_preservation: float = Form(10.0, description="GPU memory to preserve (GB) in low VRAM mode."),
782803
sampling_mode: str = Form("dpm-solver++", description="Sampling mode."),
783804
transformer_model: str = Form("base", description="Transformer model (base or f1)."),
784-
input_image: UploadFile = File(..., description="Input image file (required).")
805+
input_image: UploadFile = File(..., description="Input image file (required).", media_type="image/*")
785806
):
786807
"""
787808
Generate a single high-quality image from input image and text prompt (Image-to-Image).
@@ -791,21 +812,50 @@ async def generate_image(
791812

792813
# Read and encode input image
793814
try:
794-
contents = await input_image.read()
795-
import base64
796-
input_image_b64 = base64.b64encode(contents).decode('utf-8')
815+
print(f"Debug: input_image type = {type(input_image)}")
816+
817+
# Handle different types of input_image
818+
if isinstance(input_image, str):
819+
# If it's already a string, assume it's base64 encoded
820+
print("Debug: input_image is a string, treating as base64")
821+
input_image_b64 = input_image
822+
elif hasattr(input_image, 'read'):
823+
# It's an UploadFile
824+
print(f"Debug: input_image is UploadFile")
825+
if hasattr(input_image, 'filename'):
826+
print(f"Debug: filename = {input_image.filename}")
827+
if hasattr(input_image, 'content_type'):
828+
print(f"Debug: content_type = {input_image.content_type}")
829+
830+
contents = await input_image.read()
831+
print(f"Debug: Read {len(contents)} bytes from UploadFile")
832+
833+
import base64
834+
input_image_b64 = base64.b64encode(contents).decode('utf-8')
835+
print(f"Debug: Encoded to base64, length = {len(input_image_b64)}")
836+
else:
837+
raise HTTPException(status_code=400, detail=f"Invalid input_image type: {type(input_image)}. Expected UploadFile or base64 string.")
838+
797839
except Exception as e:
798-
raise HTTPException(status_code=400, detail=f"Failed to read input image: {e}")
840+
print(f"Debug: Exception occurred: {e}")
841+
print(f"Debug: Exception type: {type(e)}")
842+
import traceback
843+
traceback.print_exc()
844+
raise HTTPException(status_code=400, detail=f"Error loading image: {e}")
799845
finally:
800-
await input_image.close()
846+
if hasattr(input_image, 'close'):
847+
try:
848+
await input_image.close()
849+
except:
850+
pass
801851

802-
# Create a unique job ID
852+
# Create a unique job ID (8-digit hex like video generation)
803853
import uuid
804-
job_id = str(uuid.uuid4())
854+
job_id = uuid.uuid4().hex[:8]
805855

806856
# Prepare job data
807857
job_data = {
808-
"type": "image",
858+
"type": "image_generation",
809859
"data": {
810860
"prompt": prompt,
811861
"negative_prompt": negative_prompt,
@@ -832,19 +882,15 @@ async def generate_image(
832882
print(f"Error adding image job to queue: {e}")
833883
raise HTTPException(status_code=500, detail=f"Failed to add job to queue: {e}")
834884

835-
# Start processing in background
836-
threading.Thread(
837-
target=worker_image.process_image_generation,
838-
args=(job_id, job_data, loaded_models),
839-
daemon=True
840-
).start()
885+
# Jobs will be processed by the background worker
841886

842887
return ImageGenerationResponse(
843888
job_id=job_id,
844889
message="Image generation job added to queue."
845890
)
846891

847892

893+
848894
@app.post("/api/batch-images", response_model=ImageGenerationResponse)
849895
async def batch_images(request: BatchImageRequest):
850896
"""
@@ -853,9 +899,9 @@ async def batch_images(request: BatchImageRequest):
853899
if not loaded_models:
854900
raise HTTPException(status_code=503, detail="Models are not loaded or failed to load. API is not ready.")
855901

856-
# Create a unique job ID
902+
# Create a unique job ID (8-digit hex like video generation)
857903
import uuid
858-
job_id = str(uuid.uuid4())
904+
job_id = uuid.uuid4().hex[:8]
859905

860906
# Prepare job data
861907
job_data = {
@@ -872,12 +918,7 @@ async def batch_images(request: BatchImageRequest):
872918
print(f"Error adding batch image job to queue: {e}")
873919
raise HTTPException(status_code=500, detail=f"Failed to add job to queue: {e}")
874920

875-
# Start processing in background
876-
threading.Thread(
877-
target=worker_image.process_batch_images,
878-
args=(job_id, job_data, loaded_models),
879-
daemon=True
880-
).start()
921+
# Jobs will be processed by the background worker
881922

882923
return ImageGenerationResponse(
883924
job_id=job_id,
@@ -893,9 +934,9 @@ async def transfer_image(request: ImageTransferRequest):
893934
if not loaded_models:
894935
raise HTTPException(status_code=503, detail="Models are not loaded or failed to load. API is not ready.")
895936

896-
# Create a unique job ID
937+
# Create a unique job ID (8-digit hex like video generation)
897938
import uuid
898-
job_id = str(uuid.uuid4())
939+
job_id = uuid.uuid4().hex[:8]
899940

900941
# Prepare job data
901942
job_data = {
@@ -912,12 +953,7 @@ async def transfer_image(request: ImageTransferRequest):
912953
print(f"Error adding image transfer job to queue: {e}")
913954
raise HTTPException(status_code=500, detail=f"Failed to add job to queue: {e}")
914955

915-
# Start processing in background
916-
threading.Thread(
917-
target=worker_image.process_image_transfer,
918-
args=(job_id, job_data, loaded_models),
919-
daemon=True
920-
).start()
956+
# Jobs will be processed by the background worker
921957

922958
return ImageGenerationResponse(
923959
job_id=job_id,

0 commit comments

Comments
 (0)