88import base64
99import mimetypes
1010import logging
11+ import enum
1112from contextlib import asynccontextmanager
1213from fastapi import FastAPI , HTTPException , BackgroundTasks , UploadFile , File , Form , Request
1314from fastapi .responses import FileResponse , StreamingResponse
@@ -177,6 +178,12 @@ class ResultResponse(BaseModel):
177178 thumbnail_base64 : Optional [str ] = None
178179
179180
181+ # --- Enum for Sampling Mode ---
182+ class SamplingMode (str , enum .Enum ):
183+ reverse = "reverse"
184+ forward = "forward"
185+
186+
180187# --- Background Worker ---
181188def background_worker_task ():
182189 global worker_running , currently_processing_job_id
@@ -232,7 +239,7 @@ async def generate_video(
232239 mp4_crf : float = Form (16.0 ),
233240 lora_scale : float = Form (1.0 ),
234241 lora_path : Optional [str ] = Form (None , description = "Path to the LoRA file to use for this request (overrides server default if provided)." ),
235- sampling_mode : str = Form (" reverse" , description = "Sampling loop direction ('reverse' or 'forward') ." ),
242+ sampling_mode : SamplingMode = Form (SamplingMode . reverse , description = "Sampling loop direction." ),
236243 image : UploadFile = File (...)
237244):
238245 """
@@ -260,16 +267,16 @@ async def generate_video(
260267 await image .close ()
261268
262269 # Determine the transformer model based on sampling_mode
263- if sampling_mode == "forward" :
270+ # Use sampling_mode.value to get the string value from the Enum
271+ if sampling_mode == SamplingMode .forward :
264272 actual_transformer_model = "f1"
265- elif sampling_mode == " reverse" :
273+ elif sampling_mode == SamplingMode . reverse :
266274 actual_transformer_model = "base"
267275 else :
268- # Handle unexpected sampling_mode, perhaps default to 'base' or raise error
269- print (f"Warning: Unexpected sampling_mode '{ sampling_mode } '. Defaulting transformer_model to 'base'." )
276+ # This 'else' block might be unreachable if using Enum correctly,
277+ # but kept for safety or future expansion. FastAPI handles invalid Enum values.
278+ print (f"Warning: Unexpected sampling_mode '{ sampling_mode .value } '. Defaulting transformer_model to 'base'." )
270279 actual_transformer_model = "base"
271- # Alternatively, raise HTTPException:
272- # raise HTTPException(status_code=400, detail=f"Invalid sampling_mode: {sampling_mode}. Must be 'reverse' or 'forward'.")
273280
274281 # Add job to the queue using queue_manager
275282 try :
@@ -288,8 +295,8 @@ async def generate_video(
288295 mp4_crf = mp4_crf ,
289296 lora_scale = lora_scale ,
290297 lora_path = lora_path ,
291- sampling_mode = sampling_mode ,
292- transformer_model = actual_transformer_model , # Use the determined model
298+ sampling_mode = sampling_mode . value ,
299+ transformer_model = actual_transformer_model ,
293300 status = "pending"
294301 )
295302 except Exception as e :
0 commit comments