@@ -233,7 +233,6 @@ async def generate_video(
233233 lora_scale : float = Form (1.0 ),
234234 lora_path : Optional [str ] = Form (None , description = "Path to the LoRA file to use for this request (overrides server default if provided)." ),
235235 sampling_mode : str = Form ("reverse" , description = "Sampling loop direction ('reverse' or 'forward')." ),
236- transformer_model : str = Form ("base" , description = "Transformer model to use ('base' or 'f1')." ),
237236 image : UploadFile = File (...)
238237):
239238 """
@@ -260,6 +259,18 @@ async def generate_video(
260259 finally :
261260 await image .close ()
262261
262+ # Determine the transformer model based on sampling_mode
263+ if sampling_mode == "forward" :
264+ actual_transformer_model = "f1"
265+ elif sampling_mode == "reverse" :
266+ actual_transformer_model = "base"
267+ else :
268+ # Handle unexpected sampling_mode, perhaps default to 'base' or raise error
269+ print (f"Warning: Unexpected sampling_mode '{ sampling_mode } '. Defaulting transformer_model to 'base'." )
270+ actual_transformer_model = "base"
271+ # Alternatively, raise HTTPException:
272+ # raise HTTPException(status_code=400, detail=f"Invalid sampling_mode: {sampling_mode}. Must be 'reverse' or 'forward'.")
273+
263274 # Add job to the queue using queue_manager
264275 try :
265276 job_id = queue_manager .add_to_queue (
@@ -278,7 +289,7 @@ async def generate_video(
278289 lora_scale = lora_scale ,
279290 lora_path = lora_path ,
280291 sampling_mode = sampling_mode ,
281- transformer_model = transformer_model ,
292+ transformer_model = actual_transformer_model , # Use the determined model
282293 status = "pending"
283294 )
284295 except Exception as e :
0 commit comments