Skip to content

Commit 0c96db2

Browse files
committed
feat: Add sampling mode and transformer model selection to QueuedJob
- Introduced `sampling_mode` and `transformer_model` fields in the `QueuedJob` class.
- Updated the `add_to_queue` function to accept new parameters for sampling mode and transformer model.
- Enhanced the worker function to handle both forward and reverse sampling modes based on the new parameters.
- Implemented logic for selecting the appropriate transformer model during processing.
- Added detailed logging for job progress and sampling steps.
1 parent 0ce151b commit 0c96db2

4 files changed

Lines changed: 502 additions & 329 deletions

File tree

api/api.py

Lines changed: 50 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -5,18 +5,18 @@
55
import traceback
66
import asyncio
77
import json
8-
import base64 # 追加: Base64エンコード用
9-
import mimetypes # 追加: MIMEタイプ判定用
10-
import logging # 追加: Logging
8+
import base64
9+
import mimetypes
10+
import logging
1111
from contextlib import asynccontextmanager
12-
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request # Request を追加
13-
from fastapi.responses import FileResponse, StreamingResponse # JSONResponse を削除
12+
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
13+
from fastapi.responses import FileResponse, StreamingResponse
1414
from fastapi.middleware.cors import CORSMiddleware
1515
from pydantic import BaseModel, Field
1616
from PIL import Image
1717
import numpy as np
18-
from typing import List, Optional # Import Optional (Dict removed as unused)
19-
from watchdog.observers import Observer # 追加: Watchdog Observer
18+
from typing import List, Optional
19+
from watchdog.observers import Observer
2020

2121
# Import modules created earlier (relative imports)
2222
from . import settings
@@ -39,15 +39,15 @@
3939

4040

4141
# --- Lifespan Context Manager ---
42-
@asynccontextmanager # Use the imported decorator directly
42+
@asynccontextmanager
4343
async def lifespan(app: FastAPI):
4444
# Startup logic
45-
global loaded_models, worker_running, worker_thread, observer, sse_clients # Add observer and sse_clients
45+
global loaded_models, worker_running, worker_thread, observer, sse_clients
4646
print("API starting up via lifespan...")
4747
# Load models
4848
try:
4949
# Consider running blocking IO in a threadpool executor in async context
50-
# e.g., await asyncio.to_thread(models.load_models) # lora_path removed
50+
# e.g., await asyncio.to_thread(models.load_models)
5151
# For simplicity now, keeping the direct call but be aware of potential blocking
5252
loaded_models = models.load_models()
5353
print("Models loaded successfully via lifespan.")
@@ -75,7 +75,7 @@ async def lifespan(app: FastAPI):
7575
except Exception as e:
7676
print(f"FATAL: Failed to start video watcher on startup: {e}")
7777
traceback.print_exc()
78-
observer = None # Ensure observer is None if startup failed
78+
observer = None
7979

8080
yield
8181

@@ -106,7 +106,7 @@ async def lifespan(app: FastAPI):
106106
# Cleanup resources
107107
print("Attempting to unload models...")
108108
try:
109-
models.unload_models(loaded_models) # Call the function from models module
109+
models.unload_models(loaded_models)
110110
print("Models unloaded successfully (or placeholder executed).")
111111
except Exception as unload_e:
112112
print(f"Error during model unloading: {unload_e}")
@@ -121,18 +121,18 @@ async def lifespan(app: FastAPI):
121121
# --- CORS Middleware Configuration ---
122122
app.add_middleware(
123123
CORSMiddleware,
124-
allow_origins=settings.ALLOWED_ORIGINS, # Use the loaded origins from settings
125-
allow_credentials=True, # Allow credentials (cookies, authorization headers, etc.)
126-
allow_methods=["*"], # Allow all methods (GET, POST, etc.)
127-
allow_headers=["*"], # Allow all headers (Content-Type, Authorization, etc.)
124+
allow_origins=settings.ALLOWED_ORIGINS,
125+
allow_credentials=True,
126+
allow_methods=["*"],
127+
allow_headers=["*"],
128128
)
129129
# --- End CORS Middleware Configuration ---
130130

131131

132132
# --- Pydantic Models for API Requests/Responses ---
133133
class GenerateRequest(BaseModel):
134134
prompt: str = Field(..., description="Text prompt for video generation.")
135-
# image: str = Field(..., description="Base64 encoded input image.") # Changed to use UploadFile
135+
# image: str = Field(..., description="Base64 encoded input image.")
136136
video_length: float = Field(5.0, description="Length of the video in seconds.", gt=0)
137137
seed: int = Field(-1, description="Seed for generation. -1 for random.")
138138
use_teacache: bool = Field(False, description="Enable TEACache optimization.")
@@ -184,15 +184,15 @@ def background_worker_task():
184184
while worker_running:
185185
next_job = queue_manager.get_next_job()
186186
if next_job:
187-
currently_processing_job_id = next_job.job_id # Set current job ID
187+
currently_processing_job_id = next_job.job_id
188188
print(f"Worker picked up job: {currently_processing_job_id}")
189189
try:
190190
# Ensure models are loaded before processing
191191
if not loaded_models:
192192
print("Error: Models not loaded. Cannot process job.")
193193
queue_manager.update_job_status(currently_processing_job_id, "failed - models not loaded")
194-
currently_processing_job_id = None # Clear current job ID on error
195-
continue # Skip to next loop iteration
194+
currently_processing_job_id = None
195+
continue
196196

197197
worker.worker(next_job, loaded_models)
198198
except Exception as e:
@@ -220,18 +220,20 @@ def background_worker_task():
220220
@app.post("/generate", response_model=GenerateResponse)
221221
async def generate_video(
222222
background_tasks: BackgroundTasks,
223-
prompt: str = Form("A character doing some simple body movements."), # Set default prompt
223+
prompt: str = Form("A character doing some simple body movements."),
224224
video_length: float = Form(5.0),
225225
seed: int = Form(-1),
226-
use_teacache: bool = Form(True), # Default to True (matching demo_gradio.py)
227-
gpu_memory_preservation: float = Form(6.0), # Default to 6.0 GB (matching demo_gradio.py)
226+
use_teacache: bool = Form(True),
227+
gpu_memory_preservation: float = Form(6.0),
228228
steps: int = Form(25),
229229
cfg: float = Form(1.0),
230230
gs: float = Form(10.0),
231231
rs: float = Form(0.0),
232232
mp4_crf: float = Form(16.0),
233-
lora_scale: float = Form(1.0), # 追加: LoRA強度パラメータ
234-
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."), # 追加: LoRAファイルパス
233+
lora_scale: float = Form(1.0),
234+
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."),
235+
sampling_mode: str = Form("reverse", description="Sampling loop direction ('reverse' or 'forward')."),
236+
transformer_model: str = Form("base", description="Transformer model to use ('base' or 'f1')."),
235237
image: UploadFile = File(...)
236238
):
237239
"""
@@ -263,7 +265,7 @@ async def generate_video(
263265
job_id = queue_manager.add_to_queue(
264266
prompt=prompt,
265267
image=image_np,
266-
original_exif=original_exif, # Pass extracted Exif data
268+
original_exif=original_exif,
267269
video_length=video_length,
268270
seed=seed,
269271
use_teacache=use_teacache,
@@ -273,9 +275,11 @@ async def generate_video(
273275
gs=gs,
274276
rs=rs,
275277
mp4_crf=mp4_crf,
276-
lora_scale=lora_scale, # 追加: lora_scale を渡す
277-
lora_path=lora_path, # 追加: lora_path を渡す
278-
status="pending" # Explicitly set initial status
278+
lora_scale=lora_scale,
279+
lora_path=lora_path,
280+
sampling_mode=sampling_mode,
281+
transformer_model=transformer_model,
282+
status="pending"
279283
)
280284
except Exception as e:
281285
print(f"Error adding job via queue_manager: {e}")
@@ -313,7 +317,7 @@ async def get_job_status(job_id: str):
313317
return JobStatusResponse(job_id=job_id, status="processing", progress_info="Details temporarily unavailable")
314318

315319
# 2. Check if the job exists in the queue file (pending, failed, potentially completed but file not checked yet)
316-
job_in_file = queue_manager.get_job_by_id(job_id) # Use the function that reads file
320+
job_in_file = queue_manager.get_job_by_id(job_id)
317321
if job_in_file:
318322
# Return the status and progress details from the file
319323
return JobStatusResponse(
@@ -347,7 +351,7 @@ async def stream_job_status(job_id: str, request: Request):
347351
"""
348352
async def event_generator():
349353
last_data_sent = None
350-
# terminal_statuses = {"completed", "cancelled"} # Unused variable removed
354+
# terminal_statuses = {"completed", "cancelled"}
351355

352356
while True:
353357
# Check if client disconnected
@@ -388,14 +392,14 @@ async def event_generator():
388392
# Send final status event if it hasn't been sent already
389393
if current_data_json != last_data_sent:
390394
yield f"event: progress\ndata: {current_data_json}\n\n"
391-
last_data_sent = current_data_json # Ensure last_data_sent is updated even for the final message
395+
last_data_sent = current_data_json
392396
print(f"Sent final progress update for job {job_id}: Status {job.status}")
393397

394398
# Send a dedicated 'status' event to signal completion/failure/cancellation
395399
final_status_data = json.dumps({"status": job.status, "message": "Job finished."})
396400
yield f"event: status\ndata: {final_status_data}\n\n"
397401
print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.")
398-
break # Exit loop after sending final status
402+
break
399403
else:
400404
# Wait before checking again only if not terminal
401405
await asyncio.sleep(1) # Check every 1 second
@@ -404,7 +408,7 @@ async def event_generator():
404408
final_data = json.dumps({"status": job.status, "message": "Job finished."})
405409
yield f"event: status\ndata: {final_data}\n\n"
406410
print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.")
407-
break # Exit loop after sending final status
411+
break
408412

409413
# Wait before checking again
410414
await asyncio.sleep(1) # Check every 1 second
@@ -413,7 +417,7 @@ async def event_generator():
413417

414418

415419
@app.get("/result/{job_id}", response_model=ResultResponse)
416-
async def get_job_result(job_id: str, request: Request): # requestを追加してURLを構築
420+
async def get_job_result(job_id: str, request: Request):
417421
"""
418422
Returns the download URL for the completed video and the Base64 encoded thumbnail.
419423
"""
@@ -440,7 +444,7 @@ async def get_job_result(job_id: str, request: Request): # requestを追加し
440444
thumbnail_base64 = f"data:{mime_type};base64,{thumbnail_base64_data}"
441445
else:
442446
# MIMEタイプが不明な場合はデフォルトを使用(またはエラー処理)
443-
thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}" # デフォルトをJPEGに
447+
thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}"
444448
print(f"Job {job_id}: Encoded thumbnail from {job.thumbnail}")
445449
except Exception as e:
446450
print(f"Job {job_id}: Error reading or encoding thumbnail {job.thumbnail}: {e}")
@@ -479,7 +483,7 @@ async def get_input_image(job_id: str):
479483
Returns the input JPEG image file associated with a job, potentially including Exif metadata.
480484
"""
481485
job = queue_manager.get_job_by_id(job_id)
482-
filename_base = f"queue_image_{job_id}.jpg" # Changed extension to jpg
486+
filename_base = f"queue_image_{job_id}.jpg"
483487
input_image_path_in_temp = os.path.join(settings.TEMP_QUEUE_IMAGES_DIR, filename_base)
484488

485489
if not job:
@@ -563,16 +567,16 @@ async def get_worker_status():
563567

564568
@app.post("/cleanup_jobs", status_code=200)
565569
async def trigger_cleanup_jobs():
566-
""" # Correct indentation for docstring
570+
"""
567571
Manually triggers the cleanup of old completed, cancelled, or failed jobs
568572
based on the MAX_COMPLETED_JOBS setting.
569573
"""
570-
try: # Correct indentation for try block
574+
try:
571575
removed_count = queue_manager.cleanup_jobs_by_max_count()
572576
return {"message": f"Cleanup process completed. Removed {removed_count} old job entries."}
573577
except Exception as e:
574-
print(f"Error during manual job cleanup: {e}") # Correct indentation
575-
traceback.print_exc() # Correct indentation
578+
print(f"Error during manual job cleanup: {e}")
579+
traceback.print_exc()
576580
raise HTTPException(status_code=500, detail=f"Failed to perform job cleanup: {e}")
577581

578582

@@ -582,22 +586,22 @@ async def trigger_cleanup_jobs():
582586
async def list_loras():
583587
"""Lists available LoRA files from the configured directory."""
584588
lora_files = []
585-
allowed_extensions = {".safetensors", ".pt", ".bin"} # Common LoRA extensions
589+
allowed_extensions = {".safetensors", ".pt", ".bin"}
586590
try:
587591
if os.path.isdir(settings.LORA_DIR):
588592
for filename in os.listdir(settings.LORA_DIR):
589593
if os.path.isfile(os.path.join(settings.LORA_DIR, filename)):
590594
_, ext = os.path.splitext(filename)
591595
if ext.lower() in allowed_extensions:
592596
lora_files.append(filename)
593-
lora_files.sort() # Sort alphabetically
597+
lora_files.sort()
594598
else:
595599
print(f"Warning: LORA_DIR '{settings.LORA_DIR}' is not a valid directory.")
596600
except Exception as e:
597601
print(f"Error listing LoRA files: {e}")
598602
# Return empty list on error, or raise HTTPException
599603
# raise HTTPException(status_code=500, detail=f"Failed to list LoRA files: {e}")
600-
return LoraListResponse(loras=lora_files) # Correct indentation for return
604+
return LoraListResponse(loras=lora_files)
601605

602606

603607
# === Video Streaming Endpoints ===
@@ -632,7 +636,7 @@ async def event_generator():
632636
logging.error(f"Error in SSE generator: {e}")
633637
# Optionally send an error event to the client
634638
# yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n"
635-
break # Stop streaming on unexpected errors
639+
break
636640
finally:
637641
# Cleanup when client disconnects or loop breaks
638642
if client_queue in sse_clients:
@@ -689,4 +693,4 @@ async def list_videos():
689693
# Configure logging for the main execution context as well
690694
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
691695
print(f"Starting Uvicorn server on {settings.API_HOST}:{settings.API_PORT}")
692-
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) # Use string import for reload
696+
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True)

0 commit comments

Comments
 (0)