Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 68 additions & 46 deletions api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
import traceback
import asyncio
import json
import base64 # 追加: Base64エンコード用
import mimetypes # 追加: MIMEタイプ判定用
import logging # 追加: Logging
import base64
import mimetypes
import logging
import enum
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request # Request を追加
from fastapi.responses import FileResponse, StreamingResponse # JSONResponse を削除
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from PIL import Image
import numpy as np
from typing import List, Optional # Import Optional (Dict removed as unused)
from watchdog.observers import Observer # 追加: Watchdog Observer
from typing import List, Optional
from watchdog.observers import Observer

# Import modules created earlier (relative imports)
from . import settings
Expand All @@ -39,15 +40,15 @@


# --- Lifespan Context Manager ---
@asynccontextmanager # Use the imported decorator directly
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup logic
global loaded_models, worker_running, worker_thread, observer, sse_clients # Add observer and sse_clients
global loaded_models, worker_running, worker_thread, observer, sse_clients
print("API starting up via lifespan...")
# Load models
try:
# Consider running blocking IO in a threadpool executor in async context
# e.g., await asyncio.to_thread(models.load_models) # lora_path removed
# e.g., await asyncio.to_thread(models.load_models)
# For simplicity now, keeping the direct call but be aware of potential blocking
loaded_models = models.load_models()
print("Models loaded successfully via lifespan.")
Expand Down Expand Up @@ -75,7 +76,7 @@ async def lifespan(app: FastAPI):
except Exception as e:
print(f"FATAL: Failed to start video watcher on startup: {e}")
traceback.print_exc()
observer = None # Ensure observer is None if startup failed
observer = None

yield

Expand Down Expand Up @@ -106,7 +107,7 @@ async def lifespan(app: FastAPI):
# Cleanup resources
print("Attempting to unload models...")
try:
models.unload_models(loaded_models) # Call the function from models module
models.unload_models(loaded_models)
print("Models unloaded successfully (or placeholder executed).")
except Exception as unload_e:
print(f"Error during model unloading: {unload_e}")
Expand All @@ -121,18 +122,18 @@ async def lifespan(app: FastAPI):
# --- CORS Middleware Configuration ---
# Register the CORS middleware on the app. Permitted origins are read from
# settings.ALLOWED_ORIGINS; credentials, every HTTP method and every request
# header are allowed for those origins.
# NOTE(review): credentialed CORS combined with a wildcard origin gets special
# handling — confirm ALLOWED_ORIGINS lists explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- End CORS Middleware Configuration ---


# --- Pydantic Models for API Requests/Responses ---
class GenerateRequest(BaseModel):
prompt: str = Field(..., description="Text prompt for video generation.")
# image: str = Field(..., description="Base64 encoded input image.") # Changed to use UploadFile
# image: str = Field(..., description="Base64 encoded input image.")
video_length: float = Field(5.0, description="Length of the video in seconds.", gt=0)
seed: int = Field(-1, description="Seed for generation. -1 for random.")
use_teacache: bool = Field(False, description="Enable TEACache optimization.")
Expand Down Expand Up @@ -177,22 +178,28 @@ class ResultResponse(BaseModel):
thumbnail_base64: Optional[str] = None


# --- Enum for Sampling Mode ---
class SamplingMode(str, enum.Enum):
    """Sampling loop direction accepted by the API.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, and ``SamplingMode("forward")`` looks a member up by that value.
    """

    reverse = "reverse"
    forward = "forward"


# --- Background Worker ---
def background_worker_task():
global worker_running, currently_processing_job_id
print("Background worker started.")
while worker_running:
next_job = queue_manager.get_next_job()
if next_job:
currently_processing_job_id = next_job.job_id # Set current job ID
currently_processing_job_id = next_job.job_id
print(f"Worker picked up job: {currently_processing_job_id}")
try:
# Ensure models are loaded before processing
if not loaded_models:
print("Error: Models not loaded. Cannot process job.")
queue_manager.update_job_status(currently_processing_job_id, "failed - models not loaded")
currently_processing_job_id = None # Clear current job ID on error
continue # Skip to next loop iteration
currently_processing_job_id = None
continue

worker.worker(next_job, loaded_models)
except Exception as e:
Expand Down Expand Up @@ -220,18 +227,19 @@ def background_worker_task():
@app.post("/generate", response_model=GenerateResponse)
async def generate_video(
background_tasks: BackgroundTasks,
prompt: str = Form("A character doing some simple body movements."), # Set default prompt
prompt: str = Form("A character doing some simple body movements."),
video_length: float = Form(5.0),
seed: int = Form(-1),
use_teacache: bool = Form(True), # Default to True (matching demo_gradio.py)
gpu_memory_preservation: float = Form(6.0), # Default to 6.0 GB (matching demo_gradio.py)
use_teacache: bool = Form(True),
gpu_memory_preservation: float = Form(6.0),
steps: int = Form(25),
cfg: float = Form(1.0),
gs: float = Form(10.0),
rs: float = Form(0.0),
mp4_crf: float = Form(16.0),
lora_scale: float = Form(1.0), # 追加: LoRA強度パラメータ
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."), # 追加: LoRAファイルパス
lora_scale: float = Form(1.0),
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."),
sampling_mode: SamplingMode = Form(SamplingMode.reverse, description="Sampling loop direction."),
image: UploadFile = File(...)
):
"""
Expand All @@ -258,12 +266,24 @@ async def generate_video(
finally:
await image.close()

# Determine the transformer model based on sampling_mode
# Use sampling_mode.value to get the string value from the Enum
if sampling_mode == SamplingMode.forward:
actual_transformer_model = "f1"
elif sampling_mode == SamplingMode.reverse:
actual_transformer_model = "base"
else:
# This 'else' block might be unreachable if using Enum correctly,
# but kept for safety or future expansion. FastAPI handles invalid Enum values.
print(f"Warning: Unexpected sampling_mode '{sampling_mode.value}'. Defaulting transformer_model to 'base'.")
actual_transformer_model = "base"

# Add job to the queue using queue_manager
try:
job_id = queue_manager.add_to_queue(
prompt=prompt,
image=image_np,
original_exif=original_exif, # Pass extracted Exif data
original_exif=original_exif,
video_length=video_length,
seed=seed,
use_teacache=use_teacache,
Expand All @@ -273,9 +293,11 @@ async def generate_video(
gs=gs,
rs=rs,
mp4_crf=mp4_crf,
lora_scale=lora_scale, # 追加: lora_scale を渡す
lora_path=lora_path, # 追加: lora_path を渡す
status="pending" # Explicitly set initial status
lora_scale=lora_scale,
lora_path=lora_path,
sampling_mode=sampling_mode.value,
transformer_model=actual_transformer_model,
status="pending"
)
except Exception as e:
print(f"Error adding job via queue_manager: {e}")
Expand Down Expand Up @@ -313,7 +335,7 @@ async def get_job_status(job_id: str):
return JobStatusResponse(job_id=job_id, status="processing", progress_info="Details temporarily unavailable")

# 2. Check if the job exists in the queue file (pending, failed, potentially completed but file not checked yet)
job_in_file = queue_manager.get_job_by_id(job_id) # Use the function that reads file
job_in_file = queue_manager.get_job_by_id(job_id)
if job_in_file:
# Return the status and progress details from the file
return JobStatusResponse(
Expand Down Expand Up @@ -347,7 +369,7 @@ async def stream_job_status(job_id: str, request: Request):
"""
async def event_generator():
last_data_sent = None
# terminal_statuses = {"completed", "cancelled"} # Unused variable removed
# terminal_statuses = {"completed", "cancelled"}

while True:
# Check if client disconnected
Expand Down Expand Up @@ -388,14 +410,14 @@ async def event_generator():
# Send final status event if it hasn't been sent already
if current_data_json != last_data_sent:
yield f"event: progress\ndata: {current_data_json}\n\n"
last_data_sent = current_data_json # Ensure last_data_sent is updated even for the final message
last_data_sent = current_data_json
print(f"Sent final progress update for job {job_id}: Status {job.status}")

# Send a dedicated 'status' event to signal completion/failure/cancellation
final_status_data = json.dumps({"status": job.status, "message": "Job finished."})
yield f"event: status\ndata: {final_status_data}\n\n"
print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.")
break # Exit loop after sending final status
break
else:
# Wait before checking again only if not terminal
await asyncio.sleep(1) # Check every 1 second
Expand All @@ -404,7 +426,7 @@ async def event_generator():
final_data = json.dumps({"status": job.status, "message": "Job finished."})
yield f"event: status\ndata: {final_data}\n\n"
print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.")
break # Exit loop after sending final status
break

# Wait before checking again
await asyncio.sleep(1) # Check every 1 second
Expand All @@ -413,7 +435,7 @@ async def event_generator():


@app.get("/result/{job_id}", response_model=ResultResponse)
async def get_job_result(job_id: str, request: Request): # requestを追加してURLを構築
async def get_job_result(job_id: str, request: Request):
"""
Returns the download URL for the completed video and the Base64 encoded thumbnail.
"""
Expand All @@ -440,7 +462,7 @@ async def get_job_result(job_id: str, request: Request): # requestを追加し
thumbnail_base64 = f"data:{mime_type};base64,{thumbnail_base64_data}"
else:
# If the MIME type is unknown, fall back to a default (or handle it as an error)
thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}" # デフォルトをJPEGに
thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}"
print(f"Job {job_id}: Encoded thumbnail from {job.thumbnail}")
except Exception as e:
print(f"Job {job_id}: Error reading or encoding thumbnail {job.thumbnail}: {e}")
Expand Down Expand Up @@ -479,7 +501,7 @@ async def get_input_image(job_id: str):
Returns the input JPEG image file associated with a job, potentially including Exif metadata.
"""
job = queue_manager.get_job_by_id(job_id)
filename_base = f"queue_image_{job_id}.jpg" # Changed extension to jpg
filename_base = f"queue_image_{job_id}.jpg"
input_image_path_in_temp = os.path.join(settings.TEMP_QUEUE_IMAGES_DIR, filename_base)

if not job:
Expand Down Expand Up @@ -563,16 +585,16 @@ async def get_worker_status():

@app.post("/cleanup_jobs", status_code=200)
async def trigger_cleanup_jobs():
    """
    Manually triggers the cleanup of old completed, cancelled, or failed jobs
    based on the MAX_COMPLETED_JOBS setting.

    Returns:
        A dict with a single "message" key reporting how many job entries
        were removed.

    Raises:
        HTTPException: with status 500 when the cleanup call raises.
    """
    try:
        removed_count = queue_manager.cleanup_jobs_by_max_count()
    except Exception as e:
        # Log the full traceback server-side, then surface a 500 to the caller
        # while keeping the original exception chained for debugging.
        print(f"Error during manual job cleanup: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed to perform job cleanup: {e}") from e
    return {"message": f"Cleanup process completed. Removed {removed_count} old job entries."}


Expand All @@ -582,22 +604,22 @@ async def trigger_cleanup_jobs():
async def list_loras():
    """Lists available LoRA files from the configured directory.

    Scans settings.LORA_DIR (non-recursively) for regular files whose
    extension, compared case-insensitively, is one of ``.safetensors``,
    ``.pt`` or ``.bin``.

    Returns:
        LoraListResponse whose ``loras`` field holds the matching file names
        in sorted order. The list is empty when the directory is missing or
        when listing fails.
    """
    lora_files = []
    allowed_extensions = {".safetensors", ".pt", ".bin"}
    try:
        if os.path.isdir(settings.LORA_DIR):
            # Build the result in one expression so a failure part-way through
            # cannot leave a partial, unsorted list behind.
            lora_files = sorted(
                filename
                for filename in os.listdir(settings.LORA_DIR)
                if os.path.isfile(os.path.join(settings.LORA_DIR, filename))
                and os.path.splitext(filename)[1].lower() in allowed_extensions
            )
        else:
            print(f"Warning: LORA_DIR '{settings.LORA_DIR}' is not a valid directory.")
    except Exception as e:
        # Best-effort endpoint: report the problem and fall through to return
        # an empty list instead of failing the request.
        print(f"Error listing LoRA files: {e}")
    return LoraListResponse(loras=lora_files)


# === Video Streaming Endpoints ===
Expand Down Expand Up @@ -632,7 +654,7 @@ async def event_generator():
logging.error(f"Error in SSE generator: {e}")
# Optionally send an error event to the client
# yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n"
break # Stop streaming on unexpected errors
break
finally:
# Cleanup when client disconnects or loop breaks
if client_queue in sse_clients:
Expand Down Expand Up @@ -689,4 +711,4 @@ async def list_videos():
# Configure logging for the main execution context as well
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
print(f"Starting Uvicorn server on {settings.API_HOST}:{settings.API_PORT}")
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) # Use string import for reload
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True)
Loading