Skip to content

Commit 796f8db

Browse files
committed
feat: サンプリングモードに基づいてトランスフォーマーモデルを動的に決定する機能を追加
1 parent 636f87e commit 796f8db

1 file changed

Lines changed: 13 additions & 2 deletions

File tree

api/api.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ async def generate_video(
233233
lora_scale: float = Form(1.0),
234234
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."),
235235
sampling_mode: str = Form("reverse", description="Sampling loop direction ('reverse' or 'forward')."),
236-
transformer_model: str = Form("base", description="Transformer model to use ('base' or 'f1')."),
237236
image: UploadFile = File(...)
238237
):
239238
"""
@@ -260,6 +259,18 @@ async def generate_video(
260259
finally:
261260
await image.close()
262261

262+
# Determine the transformer model based on sampling_mode
263+
if sampling_mode == "forward":
264+
actual_transformer_model = "f1"
265+
elif sampling_mode == "reverse":
266+
actual_transformer_model = "base"
267+
else:
268+
# Handle unexpected sampling_mode, perhaps default to 'base' or raise error
269+
print(f"Warning: Unexpected sampling_mode '{sampling_mode}'. Defaulting transformer_model to 'base'.")
270+
actual_transformer_model = "base"
271+
# Alternatively, raise HTTPException:
272+
# raise HTTPException(status_code=400, detail=f"Invalid sampling_mode: {sampling_mode}. Must be 'reverse' or 'forward'.")
273+
263274
# Add job to the queue using queue_manager
264275
try:
265276
job_id = queue_manager.add_to_queue(
@@ -278,7 +289,7 @@ async def generate_video(
278289
lora_scale=lora_scale,
279290
lora_path=lora_path,
280291
sampling_mode=sampling_mode,
281-
transformer_model=transformer_model,
292+
transformer_model=actual_transformer_model, # Use the determined model
282293
status="pending"
283294
)
284295
except Exception as e:

0 commit comments

Comments
 (0)