55import traceback
66import asyncio
77import json
8- import base64 # 追加: Base64エンコード用
9- import mimetypes # 追加: MIMEタイプ判定用
10- import logging # 追加: Logging
8+ import base64
9+ import mimetypes
10+ import logging
1111from contextlib import asynccontextmanager
12- from fastapi import FastAPI , HTTPException , BackgroundTasks , UploadFile , File , Form , Request # Request を追加
13- from fastapi .responses import FileResponse , StreamingResponse # JSONResponse を削除
12+ from fastapi import FastAPI , HTTPException , BackgroundTasks , UploadFile , File , Form , Request
13+ from fastapi .responses import FileResponse , StreamingResponse
1414from fastapi .middleware .cors import CORSMiddleware
1515from pydantic import BaseModel , Field
1616from PIL import Image
1717import numpy as np
18- from typing import List , Optional # Import Optional (Dict removed as unused)
19- from watchdog .observers import Observer # 追加: Watchdog Observer
18+ from typing import List , Optional
19+ from watchdog .observers import Observer
2020
2121# Import modules created earlier (relative imports)
2222from . import settings
3939
4040
4141# --- Lifespan Context Manager ---
42- @asynccontextmanager # Use the imported decorator directly
42+ @asynccontextmanager
4343async def lifespan (app : FastAPI ):
4444 # Startup logic
45- global loaded_models , worker_running , worker_thread , observer , sse_clients # Add observer and sse_clients
45+ global loaded_models , worker_running , worker_thread , observer , sse_clients
4646 print ("API starting up via lifespan..." )
4747 # Load models
4848 try :
4949 # Consider running blocking IO in a threadpool executor in async context
50- # e.g., await asyncio.to_thread(models.load_models) # lora_path removed
50+ # e.g., await asyncio.to_thread(models.load_models)
5151 # For simplicity now, keeping the direct call but be aware of potential blocking
5252 loaded_models = models .load_models ()
5353 print ("Models loaded successfully via lifespan." )
@@ -75,7 +75,7 @@ async def lifespan(app: FastAPI):
7575 except Exception as e :
7676 print (f"FATAL: Failed to start video watcher on startup: { e } " )
7777 traceback .print_exc ()
78- observer = None # Ensure observer is None if startup failed
78+ observer = None
7979
8080 yield
8181
@@ -106,7 +106,7 @@ async def lifespan(app: FastAPI):
106106 # Cleanup resources
107107 print ("Attempting to unload models..." )
108108 try :
109- models .unload_models (loaded_models ) # Call the function from models module
109+ models .unload_models (loaded_models )
110110 print ("Models unloaded successfully (or placeholder executed)." )
111111 except Exception as unload_e :
112112 print (f"Error during model unloading: { unload_e } " )
@@ -121,18 +121,18 @@ async def lifespan(app: FastAPI):
121121# --- CORS Middleware Configuration ---
122122app .add_middleware (
123123 CORSMiddleware ,
124- allow_origins = settings .ALLOWED_ORIGINS , # Use the loaded origins from settings
125- allow_credentials = True , # Allow credentials (cookies, authorization headers, etc.)
126- allow_methods = ["*" ], # Allow all methods (GET, POST, etc.)
127- allow_headers = ["*" ], # Allow all headers (Content-Type, Authorization, etc.)
124+ allow_origins = settings .ALLOWED_ORIGINS ,
125+ allow_credentials = True ,
126+ allow_methods = ["*" ],
127+ allow_headers = ["*" ],
128128)
129129# --- End CORS Middleware Configuration ---
130130
131131
132132# --- Pydantic Models for API Requests/Responses ---
133133class GenerateRequest (BaseModel ):
134134 prompt : str = Field (..., description = "Text prompt for video generation." )
135- # image: str = Field(..., description="Base64 encoded input image.") # Changed to use UploadFile
135+ # image: str = Field(..., description="Base64 encoded input image.")
136136 video_length : float = Field (5.0 , description = "Length of the video in seconds." , gt = 0 )
137137 seed : int = Field (- 1 , description = "Seed for generation. -1 for random." )
138138 use_teacache : bool = Field (False , description = "Enable TEACache optimization." )
@@ -184,15 +184,15 @@ def background_worker_task():
184184 while worker_running :
185185 next_job = queue_manager .get_next_job ()
186186 if next_job :
187- currently_processing_job_id = next_job .job_id # Set current job ID
187+ currently_processing_job_id = next_job .job_id
188188 print (f"Worker picked up job: { currently_processing_job_id } " )
189189 try :
190190 # Ensure models are loaded before processing
191191 if not loaded_models :
192192 print ("Error: Models not loaded. Cannot process job." )
193193 queue_manager .update_job_status (currently_processing_job_id , "failed - models not loaded" )
194- currently_processing_job_id = None # Clear current job ID on error
195- continue # Skip to next loop iteration
194+ currently_processing_job_id = None
195+ continue
196196
197197 worker .worker (next_job , loaded_models )
198198 except Exception as e :
@@ -220,18 +220,20 @@ def background_worker_task():
220220@app .post ("/generate" , response_model = GenerateResponse )
221221async def generate_video (
222222 background_tasks : BackgroundTasks ,
223- prompt : str = Form ("A character doing some simple body movements." ), # Set default prompt
223+ prompt : str = Form ("A character doing some simple body movements." ),
224224 video_length : float = Form (5.0 ),
225225 seed : int = Form (- 1 ),
226- use_teacache : bool = Form (True ), # Default to True (matching demo_gradio.py)
227- gpu_memory_preservation : float = Form (6.0 ), # Default to 6.0 GB (matching demo_gradio.py)
226+ use_teacache : bool = Form (True ),
227+ gpu_memory_preservation : float = Form (6.0 ),
228228 steps : int = Form (25 ),
229229 cfg : float = Form (1.0 ),
230230 gs : float = Form (10.0 ),
231231 rs : float = Form (0.0 ),
232232 mp4_crf : float = Form (16.0 ),
233- lora_scale : float = Form (1.0 ), # 追加: LoRA強度パラメータ
234- lora_path : Optional [str ] = Form (None , description = "Path to the LoRA file to use for this request (overrides server default if provided)." ), # 追加: LoRAファイルパス
233+ lora_scale : float = Form (1.0 ),
234+ lora_path : Optional [str ] = Form (None , description = "Path to the LoRA file to use for this request (overrides server default if provided)." ),
235+ sampling_mode : str = Form ("reverse" , description = "Sampling loop direction ('reverse' or 'forward')." ),
236+ transformer_model : str = Form ("base" , description = "Transformer model to use ('base' or 'f1')." ),
235237 image : UploadFile = File (...)
236238):
237239 """
@@ -263,7 +265,7 @@ async def generate_video(
263265 job_id = queue_manager .add_to_queue (
264266 prompt = prompt ,
265267 image = image_np ,
266- original_exif = original_exif , # Pass extracted Exif data
268+ original_exif = original_exif ,
267269 video_length = video_length ,
268270 seed = seed ,
269271 use_teacache = use_teacache ,
@@ -273,9 +275,11 @@ async def generate_video(
273275 gs = gs ,
274276 rs = rs ,
275277 mp4_crf = mp4_crf ,
276- lora_scale = lora_scale , # 追加: lora_scale を渡す
277- lora_path = lora_path , # 追加: lora_path を渡す
278- status = "pending" # Explicitly set initial status
278+ lora_scale = lora_scale ,
279+ lora_path = lora_path ,
280+ sampling_mode = sampling_mode ,
281+ transformer_model = transformer_model ,
282+ status = "pending"
279283 )
280284 except Exception as e :
281285 print (f"Error adding job via queue_manager: { e } " )
@@ -313,7 +317,7 @@ async def get_job_status(job_id: str):
313317 return JobStatusResponse (job_id = job_id , status = "processing" , progress_info = "Details temporarily unavailable" )
314318
315319 # 2. Check if the job exists in the queue file (pending, failed, potentially completed but file not checked yet)
316- job_in_file = queue_manager .get_job_by_id (job_id ) # Use the function that reads file
320+ job_in_file = queue_manager .get_job_by_id (job_id )
317321 if job_in_file :
318322 # Return the status and progress details from the file
319323 return JobStatusResponse (
@@ -347,7 +351,7 @@ async def stream_job_status(job_id: str, request: Request):
347351 """
348352 async def event_generator ():
349353 last_data_sent = None
350- # terminal_statuses = {"completed", "cancelled"} # Unused variable removed
354+ # terminal_statuses = {"completed", "cancelled"}
351355
352356 while True :
353357 # Check if client disconnected
@@ -388,14 +392,14 @@ async def event_generator():
388392 # Send final status event if it hasn't been sent already
389393 if current_data_json != last_data_sent :
390394 yield f"event: progress\n data: { current_data_json } \n \n "
391- last_data_sent = current_data_json # Ensure last_data_sent is updated even for the final message
395+ last_data_sent = current_data_json
392396 print (f"Sent final progress update for job { job_id } : Status { job .status } " )
393397
394398 # Send a dedicated 'status' event to signal completion/failure/cancellation
395399 final_status_data = json .dumps ({"status" : job .status , "message" : "Job finished." })
396400 yield f"event: status\n data: { final_status_data } \n \n "
397401 print (f"Job { job_id } reached terminal state: { job .status } . Closing stream." )
398- break # Exit loop after sending final status
402+ break
399403 else :
400404 # Wait before checking again only if not terminal
401405 await asyncio .sleep (1 ) # Check every 1 second
@@ -404,7 +408,7 @@ async def event_generator():
404408 final_data = json .dumps ({"status" : job .status , "message" : "Job finished." })
405409 yield f"event: status\n data: { final_data } \n \n "
406410 print (f"Job { job_id } reached terminal state: { job .status } . Closing stream." )
407- break # Exit loop after sending final status
411+ break
408412
409413 # Wait before checking again
410414 await asyncio .sleep (1 ) # Check every 1 second
@@ -413,7 +417,7 @@ async def event_generator():
413417
414418
415419@app .get ("/result/{job_id}" , response_model = ResultResponse )
416- async def get_job_result (job_id : str , request : Request ): # requestを追加してURLを構築
420+ async def get_job_result (job_id : str , request : Request ):
417421 """
418422 Returns the download URL for the completed video and the Base64 encoded thumbnail.
419423 """
@@ -440,7 +444,7 @@ async def get_job_result(job_id: str, request: Request): # requestを追加し
440444 thumbnail_base64 = f"data:{ mime_type } ;base64,{ thumbnail_base64_data } "
441445 else :
442446 # MIMEタイプが不明な場合はデフォルトを使用(またはエラー処理)
443- thumbnail_base64 = f"data:image/jpeg;base64,{ thumbnail_base64_data } " # デフォルトをJPEGに
447+ thumbnail_base64 = f"data:image/jpeg;base64,{ thumbnail_base64_data } "
444448 print (f"Job { job_id } : Encoded thumbnail from { job .thumbnail } " )
445449 except Exception as e :
446450 print (f"Job { job_id } : Error reading or encoding thumbnail { job .thumbnail } : { e } " )
@@ -479,7 +483,7 @@ async def get_input_image(job_id: str):
479483 Returns the input JPEG image file associated with a job, potentially including Exif metadata.
480484 """
481485 job = queue_manager .get_job_by_id (job_id )
482- filename_base = f"queue_image_{ job_id } .jpg" # Changed extension to jpg
486+ filename_base = f"queue_image_{ job_id } .jpg"
483487 input_image_path_in_temp = os .path .join (settings .TEMP_QUEUE_IMAGES_DIR , filename_base )
484488
485489 if not job :
@@ -563,16 +567,16 @@ async def get_worker_status():
563567
564568@app .post ("/cleanup_jobs" , status_code = 200 )
565569async def trigger_cleanup_jobs ():
566- """ # Correct indentation for docstring
570+ """
567571 Manually triggers the cleanup of old completed, cancelled, or failed jobs
568572 based on the MAX_COMPLETED_JOBS setting.
569573 """
570- try : # Correct indentation for try block
574+ try :
571575 removed_count = queue_manager .cleanup_jobs_by_max_count ()
572576 return {"message" : f"Cleanup process completed. Removed { removed_count } old job entries." }
573577 except Exception as e :
574- print (f"Error during manual job cleanup: { e } " ) # Correct indentation
575- traceback .print_exc () # Correct indentation
578+ print (f"Error during manual job cleanup: { e } " )
579+ traceback .print_exc ()
576580 raise HTTPException (status_code = 500 , detail = f"Failed to perform job cleanup: { e } " )
577581
578582
@@ -582,22 +586,22 @@ async def trigger_cleanup_jobs():
582586async def list_loras ():
583587 """Lists available LoRA files from the configured directory."""
584588 lora_files = []
585- allowed_extensions = {".safetensors" , ".pt" , ".bin" } # Common LoRA extensions
589+ allowed_extensions = {".safetensors" , ".pt" , ".bin" }
586590 try :
587591 if os .path .isdir (settings .LORA_DIR ):
588592 for filename in os .listdir (settings .LORA_DIR ):
589593 if os .path .isfile (os .path .join (settings .LORA_DIR , filename )):
590594 _ , ext = os .path .splitext (filename )
591595 if ext .lower () in allowed_extensions :
592596 lora_files .append (filename )
593- lora_files .sort () # Sort alphabetically
597+ lora_files .sort ()
594598 else :
595599 print (f"Warning: LORA_DIR '{ settings .LORA_DIR } ' is not a valid directory." )
596600 except Exception as e :
597601 print (f"Error listing LoRA files: { e } " )
598602 # Return empty list on error, or raise HTTPException
599603 # raise HTTPException(status_code=500, detail=f"Failed to list LoRA files: {e}")
600- return LoraListResponse (loras = lora_files ) # Correct indentation for return
604+ return LoraListResponse (loras = lora_files )
601605
602606
603607# === Video Streaming Endpoints ===
@@ -632,7 +636,7 @@ async def event_generator():
632636 logging .error (f"Error in SSE generator: { e } " )
633637 # Optionally send an error event to the client
634638 # yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n"
635- break # Stop streaming on unexpected errors
639+ break
636640 finally :
637641 # Cleanup when client disconnects or loop breaks
638642 if client_queue in sse_clients :
@@ -689,4 +693,4 @@ async def list_videos():
689693 # Configure logging for the main execution context as well
690694 logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' )
691695 print (f"Starting Uvicorn server on { settings .API_HOST } :{ settings .API_PORT } " )
692- uvicorn .run ("api.api:app" , host = settings .API_HOST , port = settings .API_PORT , reload = True ) # Use string import for reload
696+ uvicorn .run ("api.api:app" , host = settings .API_HOST , port = settings .API_PORT , reload = True )
0 commit comments