55import traceback
66import asyncio
77import json
8- import base64 # 追加: Base64エンコード用
9- import mimetypes # 追加: MIMEタイプ判定用
10- import logging # 追加: Logging
8+ import base64
9+ import mimetypes
10+ import logging
11+ import enum
1112from contextlib import asynccontextmanager
12- from fastapi import FastAPI , HTTPException , BackgroundTasks , UploadFile , File , Form , Request # Request を追加
13- from fastapi .responses import FileResponse , StreamingResponse # JSONResponse を削除
13+ from fastapi import FastAPI , HTTPException , BackgroundTasks , UploadFile , File , Form , Request
14+ from fastapi .responses import FileResponse , StreamingResponse
1415from fastapi .middleware .cors import CORSMiddleware
1516from pydantic import BaseModel , Field
1617from PIL import Image
1718import numpy as np
18- from typing import List , Optional # Import Optional (Dict removed as unused)
19- from watchdog .observers import Observer # 追加: Watchdog Observer
19+ from typing import List , Optional
20+ from watchdog .observers import Observer
2021
2122# Import modules created earlier (relative imports)
2223from . import settings
3940
4041
4142# --- Lifespan Context Manager ---
42- @asynccontextmanager # Use the imported decorator directly
43+ @asynccontextmanager
4344async def lifespan (app : FastAPI ):
4445 # Startup logic
45- global loaded_models , worker_running , worker_thread , observer , sse_clients # Add observer and sse_clients
46+ global loaded_models , worker_running , worker_thread , observer , sse_clients
4647 print ("API starting up via lifespan..." )
4748 # Load models
4849 try :
4950 # Consider running blocking IO in a threadpool executor in async context
50- # e.g., await asyncio.to_thread(models.load_models) # lora_path removed
51+ # e.g., await asyncio.to_thread(models.load_models)
5152 # For simplicity now, keeping the direct call but be aware of potential blocking
5253 loaded_models = models .load_models ()
5354 print ("Models loaded successfully via lifespan." )
@@ -75,7 +76,7 @@ async def lifespan(app: FastAPI):
7576 except Exception as e :
7677 print (f"FATAL: Failed to start video watcher on startup: { e } " )
7778 traceback .print_exc ()
78- observer = None # Ensure observer is None if startup failed
79+ observer = None
7980
8081 yield
8182
@@ -106,7 +107,7 @@ async def lifespan(app: FastAPI):
106107 # Cleanup resources
107108 print ("Attempting to unload models..." )
108109 try :
109- models .unload_models (loaded_models ) # Call the function from models module
110+ models .unload_models (loaded_models )
110111 print ("Models unloaded successfully (or placeholder executed)." )
111112 except Exception as unload_e :
112113 print (f"Error during model unloading: { unload_e } " )
@@ -121,18 +122,18 @@ async def lifespan(app: FastAPI):
121122# --- CORS Middleware Configuration ---
122123app .add_middleware (
123124 CORSMiddleware ,
124- allow_origins = settings .ALLOWED_ORIGINS , # Use the loaded origins from settings
125- allow_credentials = True , # Allow credentials (cookies, authorization headers, etc.)
126- allow_methods = ["*" ], # Allow all methods (GET, POST, etc.)
127- allow_headers = ["*" ], # Allow all headers (Content-Type, Authorization, etc.)
125+ allow_origins = settings .ALLOWED_ORIGINS ,
126+ allow_credentials = True ,
127+ allow_methods = ["*" ],
128+ allow_headers = ["*" ],
128129)
129130# --- End CORS Middleware Configuration ---
130131
131132
132133# --- Pydantic Models for API Requests/Responses ---
133134class GenerateRequest (BaseModel ):
134135 prompt : str = Field (..., description = "Text prompt for video generation." )
135- # image: str = Field(..., description="Base64 encoded input image.") # Changed to use UploadFile
136+ # image: str = Field(..., description="Base64 encoded input image.")
136137 video_length : float = Field (5.0 , description = "Length of the video in seconds." , gt = 0 )
137138 seed : int = Field (- 1 , description = "Seed for generation. -1 for random." )
138139 use_teacache : bool = Field (False , description = "Enable TEACache optimization." )
@@ -177,22 +178,28 @@ class ResultResponse(BaseModel):
177178 thumbnail_base64 : Optional [str ] = None
178179
179180
181+ # --- Enum for Sampling Mode ---
182+ class SamplingMode (str , enum .Enum ):
183+ reverse = "reverse"
184+ forward = "forward"
185+
186+
180187# --- Background Worker ---
181188def background_worker_task ():
182189 global worker_running , currently_processing_job_id
183190 print ("Background worker started." )
184191 while worker_running :
185192 next_job = queue_manager .get_next_job ()
186193 if next_job :
187- currently_processing_job_id = next_job .job_id # Set current job ID
194+ currently_processing_job_id = next_job .job_id
188195 print (f"Worker picked up job: { currently_processing_job_id } " )
189196 try :
190197 # Ensure models are loaded before processing
191198 if not loaded_models :
192199 print ("Error: Models not loaded. Cannot process job." )
193200 queue_manager .update_job_status (currently_processing_job_id , "failed - models not loaded" )
194- currently_processing_job_id = None # Clear current job ID on error
195- continue # Skip to next loop iteration
201+ currently_processing_job_id = None
202+ continue
196203
197204 worker .worker (next_job , loaded_models )
198205 except Exception as e :
@@ -220,18 +227,19 @@ def background_worker_task():
220227@app .post ("/generate" , response_model = GenerateResponse )
221228async def generate_video (
222229 background_tasks : BackgroundTasks ,
223- prompt : str = Form ("A character doing some simple body movements." ), # Set default prompt
230+ prompt : str = Form ("A character doing some simple body movements." ),
224231 video_length : float = Form (5.0 ),
225232 seed : int = Form (- 1 ),
226- use_teacache : bool = Form (True ), # Default to True (matching demo_gradio.py)
227- gpu_memory_preservation : float = Form (6.0 ), # Default to 6.0 GB (matching demo_gradio.py)
233+ use_teacache : bool = Form (True ),
234+ gpu_memory_preservation : float = Form (6.0 ),
228235 steps : int = Form (25 ),
229236 cfg : float = Form (1.0 ),
230237 gs : float = Form (10.0 ),
231238 rs : float = Form (0.0 ),
232239 mp4_crf : float = Form (16.0 ),
233- lora_scale : float = Form (1.0 ), # 追加: LoRA強度パラメータ
234- lora_path : Optional [str ] = Form (None , description = "Path to the LoRA file to use for this request (overrides server default if provided)." ), # 追加: LoRAファイルパス
240+ lora_scale : float = Form (1.0 ),
241+ lora_path : Optional [str ] = Form (None , description = "Path to the LoRA file to use for this request (overrides server default if provided)." ),
242+ sampling_mode : SamplingMode = Form (SamplingMode .reverse , description = "Sampling loop direction." ),
235243 image : UploadFile = File (...)
236244):
237245 """
@@ -258,12 +266,24 @@ async def generate_video(
258266 finally :
259267 await image .close ()
260268
269+ # Determine the transformer model based on sampling_mode
270+ # Use sampling_mode.value to get the string value from the Enum
271+ if sampling_mode == SamplingMode .forward :
272+ actual_transformer_model = "f1"
273+ elif sampling_mode == SamplingMode .reverse :
274+ actual_transformer_model = "base"
275+ else :
276+ # This 'else' block might be unreachable if using Enum correctly,
277+ # but kept for safety or future expansion. FastAPI handles invalid Enum values.
278+ print (f"Warning: Unexpected sampling_mode '{ sampling_mode .value } '. Defaulting transformer_model to 'base'." )
279+ actual_transformer_model = "base"
280+
261281 # Add job to the queue using queue_manager
262282 try :
263283 job_id = queue_manager .add_to_queue (
264284 prompt = prompt ,
265285 image = image_np ,
266- original_exif = original_exif , # Pass extracted Exif data
286+ original_exif = original_exif ,
267287 video_length = video_length ,
268288 seed = seed ,
269289 use_teacache = use_teacache ,
@@ -273,9 +293,11 @@ async def generate_video(
273293 gs = gs ,
274294 rs = rs ,
275295 mp4_crf = mp4_crf ,
276- lora_scale = lora_scale , # 追加: lora_scale を渡す
277- lora_path = lora_path , # 追加: lora_path を渡す
278- status = "pending" # Explicitly set initial status
296+ lora_scale = lora_scale ,
297+ lora_path = lora_path ,
298+ sampling_mode = sampling_mode .value ,
299+ transformer_model = actual_transformer_model ,
300+ status = "pending"
279301 )
280302 except Exception as e :
281303 print (f"Error adding job via queue_manager: { e } " )
@@ -313,7 +335,7 @@ async def get_job_status(job_id: str):
313335 return JobStatusResponse (job_id = job_id , status = "processing" , progress_info = "Details temporarily unavailable" )
314336
315337 # 2. Check if the job exists in the queue file (pending, failed, potentially completed but file not checked yet)
316- job_in_file = queue_manager .get_job_by_id (job_id ) # Use the function that reads file
338+ job_in_file = queue_manager .get_job_by_id (job_id )
317339 if job_in_file :
318340 # Return the status and progress details from the file
319341 return JobStatusResponse (
@@ -347,7 +369,7 @@ async def stream_job_status(job_id: str, request: Request):
347369 """
348370 async def event_generator ():
349371 last_data_sent = None
350- # terminal_statuses = {"completed", "cancelled"} # Unused variable removed
372+ # terminal_statuses = {"completed", "cancelled"}
351373
352374 while True :
353375 # Check if client disconnected
@@ -388,14 +410,14 @@ async def event_generator():
388410 # Send final status event if it hasn't been sent already
389411 if current_data_json != last_data_sent :
390412 yield f"event: progress\n data: { current_data_json } \n \n "
391- last_data_sent = current_data_json # Ensure last_data_sent is updated even for the final message
413+ last_data_sent = current_data_json
392414 print (f"Sent final progress update for job { job_id } : Status { job .status } " )
393415
394416 # Send a dedicated 'status' event to signal completion/failure/cancellation
395417 final_status_data = json .dumps ({"status" : job .status , "message" : "Job finished." })
396418 yield f"event: status\n data: { final_status_data } \n \n "
397419 print (f"Job { job_id } reached terminal state: { job .status } . Closing stream." )
398- break # Exit loop after sending final status
420+ break
399421 else :
400422 # Wait before checking again only if not terminal
401423 await asyncio .sleep (1 ) # Check every 1 second
@@ -404,7 +426,7 @@ async def event_generator():
404426 final_data = json .dumps ({"status" : job .status , "message" : "Job finished." })
405427 yield f"event: status\n data: { final_data } \n \n "
406428 print (f"Job { job_id } reached terminal state: { job .status } . Closing stream." )
407- break # Exit loop after sending final status
429+ break
408430
409431 # Wait before checking again
410432 await asyncio .sleep (1 ) # Check every 1 second
@@ -413,7 +435,7 @@ async def event_generator():
413435
414436
415437@app .get ("/result/{job_id}" , response_model = ResultResponse )
416- async def get_job_result (job_id : str , request : Request ): # requestを追加してURLを構築
438+ async def get_job_result (job_id : str , request : Request ):
417439 """
418440 Returns the download URL for the completed video and the Base64 encoded thumbnail.
419441 """
@@ -440,7 +462,7 @@ async def get_job_result(job_id: str, request: Request): # requestを追加し
440462 thumbnail_base64 = f"data:{ mime_type } ;base64,{ thumbnail_base64_data } "
441463 else :
442464 # MIMEタイプが不明な場合はデフォルトを使用(またはエラー処理)
443- thumbnail_base64 = f"data:image/jpeg;base64,{ thumbnail_base64_data } " # デフォルトをJPEGに
465+ thumbnail_base64 = f"data:image/jpeg;base64,{ thumbnail_base64_data } "
444466 print (f"Job { job_id } : Encoded thumbnail from { job .thumbnail } " )
445467 except Exception as e :
446468 print (f"Job { job_id } : Error reading or encoding thumbnail { job .thumbnail } : { e } " )
@@ -479,7 +501,7 @@ async def get_input_image(job_id: str):
479501 Returns the input JPEG image file associated with a job, potentially including Exif metadata.
480502 """
481503 job = queue_manager .get_job_by_id (job_id )
482- filename_base = f"queue_image_{ job_id } .jpg" # Changed extension to jpg
504+ filename_base = f"queue_image_{ job_id } .jpg"
483505 input_image_path_in_temp = os .path .join (settings .TEMP_QUEUE_IMAGES_DIR , filename_base )
484506
485507 if not job :
@@ -563,16 +585,16 @@ async def get_worker_status():
563585
564586@app .post ("/cleanup_jobs" , status_code = 200 )
565587async def trigger_cleanup_jobs ():
566- """ # Correct indentation for docstring
588+ """
567589 Manually triggers the cleanup of old completed, cancelled, or failed jobs
568590 based on the MAX_COMPLETED_JOBS setting.
569591 """
570- try : # Correct indentation for try block
592+ try :
571593 removed_count = queue_manager .cleanup_jobs_by_max_count ()
572594 return {"message" : f"Cleanup process completed. Removed { removed_count } old job entries." }
573595 except Exception as e :
574- print (f"Error during manual job cleanup: { e } " ) # Correct indentation
575- traceback .print_exc () # Correct indentation
596+ print (f"Error during manual job cleanup: { e } " )
597+ traceback .print_exc ()
576598 raise HTTPException (status_code = 500 , detail = f"Failed to perform job cleanup: { e } " )
577599
578600
@@ -582,22 +604,22 @@ async def trigger_cleanup_jobs():
582604async def list_loras ():
583605 """Lists available LoRA files from the configured directory."""
584606 lora_files = []
585- allowed_extensions = {".safetensors" , ".pt" , ".bin" } # Common LoRA extensions
607+ allowed_extensions = {".safetensors" , ".pt" , ".bin" }
586608 try :
587609 if os .path .isdir (settings .LORA_DIR ):
588610 for filename in os .listdir (settings .LORA_DIR ):
589611 if os .path .isfile (os .path .join (settings .LORA_DIR , filename )):
590612 _ , ext = os .path .splitext (filename )
591613 if ext .lower () in allowed_extensions :
592614 lora_files .append (filename )
593- lora_files .sort () # Sort alphabetically
615+ lora_files .sort ()
594616 else :
595617 print (f"Warning: LORA_DIR '{ settings .LORA_DIR } ' is not a valid directory." )
596618 except Exception as e :
597619 print (f"Error listing LoRA files: { e } " )
598620 # Return empty list on error, or raise HTTPException
599621 # raise HTTPException(status_code=500, detail=f"Failed to list LoRA files: {e}")
600- return LoraListResponse (loras = lora_files ) # Correct indentation for return
622+ return LoraListResponse (loras = lora_files )
601623
602624
603625# === Video Streaming Endpoints ===
@@ -632,7 +654,7 @@ async def event_generator():
632654 logging .error (f"Error in SSE generator: { e } " )
633655 # Optionally send an error event to the client
634656 # yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n"
635- break # Stop streaming on unexpected errors
657+ break
636658 finally :
637659 # Cleanup when client disconnects or loop breaks
638660 if client_queue in sse_clients :
@@ -689,4 +711,4 @@ async def list_videos():
689711 # Configure logging for the main execution context as well
690712 logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' )
691713 print (f"Starting Uvicorn server on { settings .API_HOST } :{ settings .API_PORT } " )
692- uvicorn .run ("api.api:app" , host = settings .API_HOST , port = settings .API_PORT , reload = True ) # Use string import for reload
714+ uvicorn .run ("api.api:app" , host = settings .API_HOST , port = settings .API_PORT , reload = True )
0 commit comments