Skip to content

Commit 57c282b

Browse files
authored
feat: Add sampling mode and transformer model selection to QueuedJob (#14)
* feat: Add sampling mode and transformer model selection to QueuedJob
  - Introduced `sampling_mode` and `transformer_model` fields in the `QueuedJob` class.
  - Updated the `add_to_queue` function to accept new parameters for sampling mode and transformer model.
  - Enhanced the worker function to handle both forward and reverse sampling modes based on the new parameters.
  - Implemented logic for selecting the appropriate transformer model during processing.
  - Added detailed logging for job progress and sampling steps.
* feat: APIでのビデオ生成モード選択機能を追加
* feat: サンプリングモードに基づいてトランスフォーマーモデルを動的に決定する機能を追加
* feat: サンプリングモードの列挙型を追加し、ビデオ生成エンドポイントでの使用を更新
* feat: ジョブ生成テストにサンプリングモードとトランスフォーマーモデルの期待値を追加
1 parent 0ce151b commit 57c282b

6 files changed

Lines changed: 609 additions & 330 deletions

File tree

api/api.py

Lines changed: 68 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
import traceback
66
import asyncio
77
import json
8-
import base64 # 追加: Base64エンコード用
9-
import mimetypes # 追加: MIMEタイプ判定用
10-
import logging # 追加: Logging
8+
import base64
9+
import mimetypes
10+
import logging
11+
import enum
1112
from contextlib import asynccontextmanager
12-
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request # Request を追加
13-
from fastapi.responses import FileResponse, StreamingResponse # JSONResponse を削除
13+
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
14+
from fastapi.responses import FileResponse, StreamingResponse
1415
from fastapi.middleware.cors import CORSMiddleware
1516
from pydantic import BaseModel, Field
1617
from PIL import Image
1718
import numpy as np
18-
from typing import List, Optional # Import Optional (Dict removed as unused)
19-
from watchdog.observers import Observer # 追加: Watchdog Observer
19+
from typing import List, Optional
20+
from watchdog.observers import Observer
2021

2122
# Import modules created earlier (relative imports)
2223
from . import settings
@@ -39,15 +40,15 @@
3940

4041

4142
# --- Lifespan Context Manager ---
42-
@asynccontextmanager # Use the imported decorator directly
43+
@asynccontextmanager
4344
async def lifespan(app: FastAPI):
4445
# Startup logic
45-
global loaded_models, worker_running, worker_thread, observer, sse_clients # Add observer and sse_clients
46+
global loaded_models, worker_running, worker_thread, observer, sse_clients
4647
print("API starting up via lifespan...")
4748
# Load models
4849
try:
4950
# Consider running blocking IO in a threadpool executor in async context
50-
# e.g., await asyncio.to_thread(models.load_models) # lora_path removed
51+
# e.g., await asyncio.to_thread(models.load_models)
5152
# For simplicity now, keeping the direct call but be aware of potential blocking
5253
loaded_models = models.load_models()
5354
print("Models loaded successfully via lifespan.")
@@ -75,7 +76,7 @@ async def lifespan(app: FastAPI):
7576
except Exception as e:
7677
print(f"FATAL: Failed to start video watcher on startup: {e}")
7778
traceback.print_exc()
78-
observer = None # Ensure observer is None if startup failed
79+
observer = None
7980

8081
yield
8182

@@ -106,7 +107,7 @@ async def lifespan(app: FastAPI):
106107
# Cleanup resources
107108
print("Attempting to unload models...")
108109
try:
109-
models.unload_models(loaded_models) # Call the function from models module
110+
models.unload_models(loaded_models)
110111
print("Models unloaded successfully (or placeholder executed).")
111112
except Exception as unload_e:
112113
print(f"Error during model unloading: {unload_e}")
@@ -121,18 +122,18 @@ async def lifespan(app: FastAPI):
121122
# --- CORS Middleware Configuration ---
122123
app.add_middleware(
123124
CORSMiddleware,
124-
allow_origins=settings.ALLOWED_ORIGINS, # Use the loaded origins from settings
125-
allow_credentials=True, # Allow credentials (cookies, authorization headers, etc.)
126-
allow_methods=["*"], # Allow all methods (GET, POST, etc.)
127-
allow_headers=["*"], # Allow all headers (Content-Type, Authorization, etc.)
125+
allow_origins=settings.ALLOWED_ORIGINS,
126+
allow_credentials=True,
127+
allow_methods=["*"],
128+
allow_headers=["*"],
128129
)
129130
# --- End CORS Middleware Configuration ---
130131

131132

132133
# --- Pydantic Models for API Requests/Responses ---
133134
class GenerateRequest(BaseModel):
134135
prompt: str = Field(..., description="Text prompt for video generation.")
135-
# image: str = Field(..., description="Base64 encoded input image.") # Changed to use UploadFile
136+
# image: str = Field(..., description="Base64 encoded input image.")
136137
video_length: float = Field(5.0, description="Length of the video in seconds.", gt=0)
137138
seed: int = Field(-1, description="Seed for generation. -1 for random.")
138139
use_teacache: bool = Field(False, description="Enable TEACache optimization.")
@@ -177,22 +178,28 @@ class ResultResponse(BaseModel):
177178
thumbnail_base64: Optional[str] = None
178179

179180

181+
# --- Enum for Sampling Mode ---
182+
class SamplingMode(str, enum.Enum):
183+
reverse = "reverse"
184+
forward = "forward"
185+
186+
180187
# --- Background Worker ---
181188
def background_worker_task():
182189
global worker_running, currently_processing_job_id
183190
print("Background worker started.")
184191
while worker_running:
185192
next_job = queue_manager.get_next_job()
186193
if next_job:
187-
currently_processing_job_id = next_job.job_id # Set current job ID
194+
currently_processing_job_id = next_job.job_id
188195
print(f"Worker picked up job: {currently_processing_job_id}")
189196
try:
190197
# Ensure models are loaded before processing
191198
if not loaded_models:
192199
print("Error: Models not loaded. Cannot process job.")
193200
queue_manager.update_job_status(currently_processing_job_id, "failed - models not loaded")
194-
currently_processing_job_id = None # Clear current job ID on error
195-
continue # Skip to next loop iteration
201+
currently_processing_job_id = None
202+
continue
196203

197204
worker.worker(next_job, loaded_models)
198205
except Exception as e:
@@ -220,18 +227,19 @@ def background_worker_task():
220227
@app.post("/generate", response_model=GenerateResponse)
221228
async def generate_video(
222229
background_tasks: BackgroundTasks,
223-
prompt: str = Form("A character doing some simple body movements."), # Set default prompt
230+
prompt: str = Form("A character doing some simple body movements."),
224231
video_length: float = Form(5.0),
225232
seed: int = Form(-1),
226-
use_teacache: bool = Form(True), # Default to True (matching demo_gradio.py)
227-
gpu_memory_preservation: float = Form(6.0), # Default to 6.0 GB (matching demo_gradio.py)
233+
use_teacache: bool = Form(True),
234+
gpu_memory_preservation: float = Form(6.0),
228235
steps: int = Form(25),
229236
cfg: float = Form(1.0),
230237
gs: float = Form(10.0),
231238
rs: float = Form(0.0),
232239
mp4_crf: float = Form(16.0),
233-
lora_scale: float = Form(1.0), # 追加: LoRA強度パラメータ
234-
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."), # 追加: LoRAファイルパス
240+
lora_scale: float = Form(1.0),
241+
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."),
242+
sampling_mode: SamplingMode = Form(SamplingMode.reverse, description="Sampling loop direction."),
235243
image: UploadFile = File(...)
236244
):
237245
"""
@@ -258,12 +266,24 @@ async def generate_video(
258266
finally:
259267
await image.close()
260268

269+
# Determine the transformer model based on sampling_mode
270+
# Enum members are compared directly here; sampling_mode.value (the plain string) is what is passed to the queue below
271+
if sampling_mode == SamplingMode.forward:
272+
actual_transformer_model = "f1"
273+
elif sampling_mode == SamplingMode.reverse:
274+
actual_transformer_model = "base"
275+
else:
276+
# This 'else' block might be unreachable if using Enum correctly,
277+
# but kept for safety or future expansion. FastAPI handles invalid Enum values.
278+
print(f"Warning: Unexpected sampling_mode '{sampling_mode.value}'. Defaulting transformer_model to 'base'.")
279+
actual_transformer_model = "base"
280+
261281
# Add job to the queue using queue_manager
262282
try:
263283
job_id = queue_manager.add_to_queue(
264284
prompt=prompt,
265285
image=image_np,
266-
original_exif=original_exif, # Pass extracted Exif data
286+
original_exif=original_exif,
267287
video_length=video_length,
268288
seed=seed,
269289
use_teacache=use_teacache,
@@ -273,9 +293,11 @@ async def generate_video(
273293
gs=gs,
274294
rs=rs,
275295
mp4_crf=mp4_crf,
276-
lora_scale=lora_scale, # 追加: lora_scale を渡す
277-
lora_path=lora_path, # 追加: lora_path を渡す
278-
status="pending" # Explicitly set initial status
296+
lora_scale=lora_scale,
297+
lora_path=lora_path,
298+
sampling_mode=sampling_mode.value,
299+
transformer_model=actual_transformer_model,
300+
status="pending"
279301
)
280302
except Exception as e:
281303
print(f"Error adding job via queue_manager: {e}")
@@ -313,7 +335,7 @@ async def get_job_status(job_id: str):
313335
return JobStatusResponse(job_id=job_id, status="processing", progress_info="Details temporarily unavailable")
314336

315337
# 2. Check if the job exists in the queue file (pending, failed, potentially completed but file not checked yet)
316-
job_in_file = queue_manager.get_job_by_id(job_id) # Use the function that reads file
338+
job_in_file = queue_manager.get_job_by_id(job_id)
317339
if job_in_file:
318340
# Return the status and progress details from the file
319341
return JobStatusResponse(
@@ -347,7 +369,7 @@ async def stream_job_status(job_id: str, request: Request):
347369
"""
348370
async def event_generator():
349371
last_data_sent = None
350-
# terminal_statuses = {"completed", "cancelled"} # Unused variable removed
372+
# terminal_statuses = {"completed", "cancelled"}
351373

352374
while True:
353375
# Check if client disconnected
@@ -388,14 +410,14 @@ async def event_generator():
388410
# Send final status event if it hasn't been sent already
389411
if current_data_json != last_data_sent:
390412
yield f"event: progress\ndata: {current_data_json}\n\n"
391-
last_data_sent = current_data_json # Ensure last_data_sent is updated even for the final message
413+
last_data_sent = current_data_json
392414
print(f"Sent final progress update for job {job_id}: Status {job.status}")
393415

394416
# Send a dedicated 'status' event to signal completion/failure/cancellation
395417
final_status_data = json.dumps({"status": job.status, "message": "Job finished."})
396418
yield f"event: status\ndata: {final_status_data}\n\n"
397419
print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.")
398-
break # Exit loop after sending final status
420+
break
399421
else:
400422
# Wait before checking again only if not terminal
401423
await asyncio.sleep(1) # Check every 1 second
@@ -404,7 +426,7 @@ async def event_generator():
404426
final_data = json.dumps({"status": job.status, "message": "Job finished."})
405427
yield f"event: status\ndata: {final_data}\n\n"
406428
print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.")
407-
break # Exit loop after sending final status
429+
break
408430

409431
# Wait before checking again
410432
await asyncio.sleep(1) # Check every 1 second
@@ -413,7 +435,7 @@ async def event_generator():
413435

414436

415437
@app.get("/result/{job_id}", response_model=ResultResponse)
416-
async def get_job_result(job_id: str, request: Request): # requestを追加してURLを構築
438+
async def get_job_result(job_id: str, request: Request):
417439
"""
418440
Returns the download URL for the completed video and the Base64 encoded thumbnail.
419441
"""
@@ -440,7 +462,7 @@ async def get_job_result(job_id: str, request: Request): # requestを追加し
440462
thumbnail_base64 = f"data:{mime_type};base64,{thumbnail_base64_data}"
441463
else:
442464
# If the MIME type is unknown, fall back to a default (or handle it as an error)
443-
thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}" # デフォルトをJPEGに
465+
thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}"
444466
print(f"Job {job_id}: Encoded thumbnail from {job.thumbnail}")
445467
except Exception as e:
446468
print(f"Job {job_id}: Error reading or encoding thumbnail {job.thumbnail}: {e}")
@@ -479,7 +501,7 @@ async def get_input_image(job_id: str):
479501
Returns the input JPEG image file associated with a job, potentially including Exif metadata.
480502
"""
481503
job = queue_manager.get_job_by_id(job_id)
482-
filename_base = f"queue_image_{job_id}.jpg" # Changed extension to jpg
504+
filename_base = f"queue_image_{job_id}.jpg"
483505
input_image_path_in_temp = os.path.join(settings.TEMP_QUEUE_IMAGES_DIR, filename_base)
484506

485507
if not job:
@@ -563,16 +585,16 @@ async def get_worker_status():
563585

564586
@app.post("/cleanup_jobs", status_code=200)
565587
async def trigger_cleanup_jobs():
566-
""" # Correct indentation for docstring
588+
"""
567589
Manually triggers the cleanup of old completed, cancelled, or failed jobs
568590
based on the MAX_COMPLETED_JOBS setting.
569591
"""
570-
try: # Correct indentation for try block
592+
try:
571593
removed_count = queue_manager.cleanup_jobs_by_max_count()
572594
return {"message": f"Cleanup process completed. Removed {removed_count} old job entries."}
573595
except Exception as e:
574-
print(f"Error during manual job cleanup: {e}") # Correct indentation
575-
traceback.print_exc() # Correct indentation
596+
print(f"Error during manual job cleanup: {e}")
597+
traceback.print_exc()
576598
raise HTTPException(status_code=500, detail=f"Failed to perform job cleanup: {e}")
577599

578600

@@ -582,22 +604,22 @@ async def trigger_cleanup_jobs():
582604
async def list_loras():
583605
"""Lists available LoRA files from the configured directory."""
584606
lora_files = []
585-
allowed_extensions = {".safetensors", ".pt", ".bin"} # Common LoRA extensions
607+
allowed_extensions = {".safetensors", ".pt", ".bin"}
586608
try:
587609
if os.path.isdir(settings.LORA_DIR):
588610
for filename in os.listdir(settings.LORA_DIR):
589611
if os.path.isfile(os.path.join(settings.LORA_DIR, filename)):
590612
_, ext = os.path.splitext(filename)
591613
if ext.lower() in allowed_extensions:
592614
lora_files.append(filename)
593-
lora_files.sort() # Sort alphabetically
615+
lora_files.sort()
594616
else:
595617
print(f"Warning: LORA_DIR '{settings.LORA_DIR}' is not a valid directory.")
596618
except Exception as e:
597619
print(f"Error listing LoRA files: {e}")
598620
# Return empty list on error, or raise HTTPException
599621
# raise HTTPException(status_code=500, detail=f"Failed to list LoRA files: {e}")
600-
return LoraListResponse(loras=lora_files) # Correct indentation for return
622+
return LoraListResponse(loras=lora_files)
601623

602624

603625
# === Video Streaming Endpoints ===
@@ -632,7 +654,7 @@ async def event_generator():
632654
logging.error(f"Error in SSE generator: {e}")
633655
# Optionally send an error event to the client
634656
# yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n"
635-
break # Stop streaming on unexpected errors
657+
break
636658
finally:
637659
# Cleanup when client disconnects or loop breaks
638660
if client_queue in sse_clients:
@@ -689,4 +711,4 @@ async def list_videos():
689711
# Configure logging for the main execution context as well
690712
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
691713
print(f"Starting Uvicorn server on {settings.API_HOST}:{settings.API_PORT}")
692-
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) # Use string import for reload
714+
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True)

0 commit comments

Comments
 (0)