Skip to content

Commit 03a2632

Browse files
committed
feat: サンプリングモードの列挙型を追加し、ビデオ生成エンドポイントでの使用を更新
1 parent 796f8db commit 03a2632

1 file changed

Lines changed: 16 additions & 9 deletions

File tree

api/api.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import base64
99
import mimetypes
1010
import logging
11+
import enum
1112
from contextlib import asynccontextmanager
1213
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
1314
from fastapi.responses import FileResponse, StreamingResponse
@@ -177,6 +178,12 @@ class ResultResponse(BaseModel):
177178
thumbnail_base64: Optional[str] = None
178179

179180

181+
# --- Enum for Sampling Mode ---
182+
class SamplingMode(str, enum.Enum):
183+
reverse = "reverse"
184+
forward = "forward"
185+
186+
180187
# --- Background Worker ---
181188
def background_worker_task():
182189
global worker_running, currently_processing_job_id
@@ -232,7 +239,7 @@ async def generate_video(
232239
mp4_crf: float = Form(16.0),
233240
lora_scale: float = Form(1.0),
234241
lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."),
235-
sampling_mode: str = Form("reverse", description="Sampling loop direction ('reverse' or 'forward')."),
242+
sampling_mode: SamplingMode = Form(SamplingMode.reverse, description="Sampling loop direction."),
236243
image: UploadFile = File(...)
237244
):
238245
"""
@@ -260,16 +267,16 @@ async def generate_video(
260267
await image.close()
261268

262269
# Determine the transformer model based on sampling_mode
263-
if sampling_mode == "forward":
270+
# Use sampling_mode.value to get the string value from the Enum
271+
if sampling_mode == SamplingMode.forward:
264272
actual_transformer_model = "f1"
265-
elif sampling_mode == "reverse":
273+
elif sampling_mode == SamplingMode.reverse:
266274
actual_transformer_model = "base"
267275
else:
268-
# Handle unexpected sampling_mode, perhaps default to 'base' or raise error
269-
print(f"Warning: Unexpected sampling_mode '{sampling_mode}'. Defaulting transformer_model to 'base'.")
276+
# This 'else' block might be unreachable if using Enum correctly,
277+
# but kept for safety or future expansion. FastAPI handles invalid Enum values.
278+
print(f"Warning: Unexpected sampling_mode '{sampling_mode.value}'. Defaulting transformer_model to 'base'.")
270279
actual_transformer_model = "base"
271-
# Alternatively, raise HTTPException:
272-
# raise HTTPException(status_code=400, detail=f"Invalid sampling_mode: {sampling_mode}. Must be 'reverse' or 'forward'.")
273280

274281
# Add job to the queue using queue_manager
275282
try:
@@ -288,8 +295,8 @@ async def generate_video(
288295
mp4_crf=mp4_crf,
289296
lora_scale=lora_scale,
290297
lora_path=lora_path,
291-
sampling_mode=sampling_mode,
292-
transformer_model=actual_transformer_model, # Use the determined model
298+
sampling_mode=sampling_mode.value,
299+
transformer_model=actual_transformer_model,
293300
status="pending"
294301
)
295302
except Exception as e:

0 commit comments

Comments
 (0)