Skip to content

Commit c623ee4

Browse files
authored
Add new prompt for animation and update requirements (#11)
* Add new prompt for animation and update requirements - Added a new prompt for a digitally rendered animation featuring a young woman in swimwear to quick_prompts.json. - Updated requirements.txt to include the 'watchdog' package for improved file monitoring. * Remove inappropriate prompt related to nudity from quick_prompts.json * quick_prompts.jsonの末尾に改行を追加 * GPUデバイスの選択ロジックをテストファイルに追加 * GPUデバイスの選択ロジックを例外処理でラップし、エラーハンドリングを追加 * GPUデバイスの選択ロジックを例外処理でラップし、エラーハンドリングを改善 * torchとtorchvisionをrequirements.txtに追加 * get_cuda_free_memory_gb関数のエラーハンドリングを改善し、メモリ統計のキーが存在しない場合に安全に0を返すように修正 * get_cuda_free_memory_gb関数の不要な空行を削除 * デバッグ用のprint文をコメントアウト * プレビュー機能を実装し、ジョブの進行状況に応じてBase64エンコードされたプレビュー画像を生成・管理する機能を追加
1 parent 898e2a7 commit c623ee4

11 files changed

Lines changed: 545 additions & 20 deletions

File tree

PLAN.md

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# 動画連続再生機能 実装計画
2+
3+
## 概要
4+
5+
ローカルの特定ディレクトリ (`outputs`) に順次追加される動画ファイル (`.mp4`) を検知し、FastAPIのSSE (Server-Sent Events) を通じてクライアント (React想定) に通知する。クライアントは通知されたファイル名を元に動画をリクエストし、連続再生を行う。既存のFramePack API (`api/api.py`) に機能を追加し、他の機能への影響を最小限に抑える。
6+
7+
## 計画詳細
8+
9+
1. **設定更新 (`api/settings.py`):**
10+
* 監視対象ディレクトリパス `VIDEO_DIR` を定義する。デフォルトはプロジェクトルート下の `outputs` ディレクトリとする。環境変数 `VIDEO_DIR` が設定されていれば、その値を優先する。
11+
* 動画ファイル配信用エンドポイントのベースURL `VIDEO_BASE_URL``/videos/` として定義する (主にクライアント側の参考情報)。
12+
13+
2. **ファイル監視ロジック (`api/video_watcher.py` - 新規ファイル):**
14+
* `watchdog` ライブラリを使用する `VideoHandler` クラスを作成する。
15+
* `on_created` イベントハンドラを実装し、`.mp4` ファイルが作成された場合のみ、FastAPI側のSSEクライアントキューリスト (`sse_clients`) にファイル名を追加する。
16+
* 監視を開始/停止する関数 (`start_watcher`, `stop_watcher`) を作成する。
17+
* `start_watcher(path, clients)`: 指定されたパスを監視し、通知先のクライアントキューリストを受け取る。`watchdog.observers.Observer` インスタンスを初期化・開始し、そのインスタンスを返す。
18+
* `stop_watcher(observer)`: 受け取った `Observer` インスタンスを停止・結合する。
19+
20+
3. **FastAPIエンドポイント追加 (`api/api.py`):**
21+
* **グローバル変数:**
22+
* `sse_clients = []`: SSEクライアントごとの通知キュー (`asyncio.Queue` など) を保持するリスト。
23+
* `observer = None`: `watchdog``Observer` インスタンスを保持する変数。
24+
* **`/video_stream` (GET, SSE):**
25+
* 新しいクライアント接続時に、専用の通知キューを作成し `sse_clients` に追加する。
26+
* 非同期ジェネレータ関数を定義する。
27+
* 無限ループでクライアントの接続状態をチェックする。
28+
* キューから新しいファイル名を取得し、`data: {filename}\n\n` 形式で `yield` する。
29+
* クライアント切断時には、対応するキューを `sse_clients` から削除し、ループを終了する。
30+
* `StreamingResponse` で上記ジェネレータを返す (`media_type="text/event-stream"`)。
31+
* **`/videos/{filename}` (GET):**
32+
* `settings.VIDEO_DIR` とリクエストされた `filename` を結合して、動画ファイルのフルパスを構築する。
33+
* `os.path.exists` でファイルの存在を確認する。
34+
* 存在すれば `FileResponse` を使用して動画ファイル (`media_type="video/mp4"`) を返す。
35+
* 存在しなければ `HTTPException(status_code=404, detail="File not found")` を発生させる。
36+
* **`/videos` (GET):**
37+
* `settings.VIDEO_DIR` 内のファイルを `os.listdir` で取得する。
38+
* ファイル名が `.mp4` で終わるもののみをフィルタリングする。
39+
* フィルタリングされたファイル名のリストをJSON形式で返す。
40+
41+
4. **ライフサイクル管理 (`api/api.py``lifespan`):**
42+
* 既存の `lifespan` コンテキストマネージャを修正する。
43+
* **Startup:**
44+
* `video_watcher.start_watcher(settings.VIDEO_DIR, sse_clients)` を呼び出し、返された `Observer` インスタンスをグローバル変数 `observer` に格納する。
45+
* **Shutdown:**
46+
* グローバル変数 `observer``None` でなければ、`video_watcher.stop_watcher(observer)` を呼び出してファイル監視プロセスを安全に停止する。
47+
48+
5. **依存関係:**
49+
* `watchdog` ライブラリが必要となるため、プロジェクトの依存関係ファイル (`requirements.txt``pyproject.toml` など) に `watchdog` を追加する。
50+
51+
## Mermaid図
52+
53+
```mermaid
54+
graph TD
55+
subgraph FastAPI Backend (api/api.py)
56+
A[Client connects to /video_stream] --> B{Create SSE queue (e.g., asyncio.Queue)};
57+
B --> C[Add queue to global sse_clients list];
58+
C --> D[Start SSE generation loop (async def)];
59+
D -- New filename in queue --> E[yield f"data: {filename}\n\n"];
60+
D -- Client disconnects --> F[Remove queue from sse_clients & break loop];
61+
62+
G[Client requests /videos/{filename}] --> H{Build file path using settings.VIDEO_DIR};
63+
H -- os.path.exists is True --> I[Return FileResponse(path, media_type="video/mp4")];
64+
H -- os.path.exists is False --> J[Raise HTTPException(404)];
65+
66+
K[Client requests /videos] --> L{os.listdir(settings.VIDEO_DIR)};
67+
L --> M[Filter for .mp4 files, return JSON list];
68+
69+
N[lifespan startup] --> O[observer = video_watcher.start_watcher(VIDEO_DIR, sse_clients)];
70+
P[lifespan shutdown] --> Q[if observer: video_watcher.stop_watcher(observer)];
71+
end
72+
73+
subgraph File System Watcher (api/video_watcher.py - New File)
74+
R[Watchdog Observer monitors VIDEO_DIR] -- New .mp4 created --> S[VideoHandler.on_created];
75+
S --> T{Get filename};
76+
T --> U[Add filename to all queues in sse_clients list];
77+
V[start_watcher(path, clients)] --> W[Initialize Observer & Handler, observer.start(), return observer];
78+
X[stop_watcher(observer)] --> Y[observer.stop(), observer.join()];
79+
end
80+
81+
subgraph React Frontend (Out of scope)
82+
Z[Page load requests /videos] --> AA[Get initial file list];
83+
AA --> BB[Initialize playlist];
84+
CC[Connects to /video_stream] --> DD[Receive filename via SSE];
85+
DD --> EE[Add filename to playlist];
86+
BB & EE --> FF[Select random video from playlist];
87+
FF --> GG[Request /videos/{filename}];
88+
GG --> HH[Receive video data & play];
89+
end
90+
91+
FastAPI_Backend -- Manages --> File_System_Watcher;
92+
```

api/api.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import base64 # 追加: Base64エンコード用
99
import mimetypes # 追加: MIMEタイプ判定用
10+
import logging # 追加: Logging
1011
from contextlib import asynccontextmanager
1112
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request # Request を追加
1213
from fastapi.responses import FileResponse, StreamingResponse # JSONResponse を削除
@@ -15,12 +16,14 @@
1516
from PIL import Image
1617
import numpy as np
1718
from typing import List, Optional # Import Optional (Dict removed as unused)
19+
from watchdog.observers import Observer # 追加: Watchdog Observer
1820

1921
# Import modules created earlier (relative imports)
2022
from . import settings
2123
from . import models
2224
from . import queue_manager
2325
from . import worker
26+
from . import video_watcher
2427

2528
# --- Global State ---
2629
# Dictionary to hold loaded models
@@ -30,13 +33,16 @@
3033
worker_thread = None
3134
# Variable to store the ID of the currently processing job
3235
currently_processing_job_id: str | None = None
36+
# --- Video Watcher State ---
37+
sse_clients: List[asyncio.Queue] = [] # List to hold client queues for SSE
38+
observer: Observer | None = None # type: ignore # Watchdog observer instance
3339

3440

3541
# --- Lifespan Context Manager ---
3642
@asynccontextmanager # Use the imported decorator directly
3743
async def lifespan(app: FastAPI):
3844
# Startup logic
39-
global loaded_models, worker_running, worker_thread
45+
global loaded_models, worker_running, worker_thread, observer, sse_clients # Add observer and sse_clients
4046
print("API starting up via lifespan...")
4147
# Load models
4248
try:
@@ -60,10 +66,32 @@ async def lifespan(app: FastAPI):
6066
else:
6167
print("Worker already running? Skipping start in lifespan.")
6268

69+
# Start video watcher
70+
try:
71+
print(f"Attempting to start video watcher for directory: {settings.VIDEO_DIR}")
72+
# Pass the global sse_clients list to the watcher
73+
observer = video_watcher.start_watcher(settings.VIDEO_DIR, sse_clients)
74+
print("Video watcher started successfully via lifespan.")
75+
except Exception as e:
76+
print(f"FATAL: Failed to start video watcher on startup: {e}")
77+
traceback.print_exc()
78+
observer = None # Ensure observer is None if startup failed
79+
6380
yield
6481

6582
# Shutdown logic
6683
print("API shutting down via lifespan...")
84+
85+
# Stop video watcher first
86+
if observer:
87+
try:
88+
print("Stopping video watcher...")
89+
video_watcher.stop_watcher(observer)
90+
print("Video watcher stopped.")
91+
except Exception as e:
92+
print(f"Error stopping video watcher: {e}")
93+
traceback.print_exc()
94+
6795
# Stop background worker
6896
if worker_running:
6997
worker_running = False
@@ -572,8 +600,93 @@ async def list_loras():
572600
return LoraListResponse(loras=lora_files) # Correct indentation for return
573601

574602

603+
# === Video Streaming Endpoints ===
604+
605+
@app.get("/video_stream")
606+
async def video_stream(request: Request):
607+
"""
608+
Streams new video filenames using Server-Sent Events (SSE).
609+
"""
610+
client_queue = asyncio.Queue()
611+
sse_clients.append(client_queue)
612+
logging.info(f"SSE client connected. Total clients: {len(sse_clients)}")
613+
614+
async def event_generator():
615+
try:
616+
while True:
617+
# Check connection status first
618+
if await request.is_disconnected():
619+
logging.info("SSE client disconnected.")
620+
break
621+
622+
try:
623+
# Wait for a new filename from the queue
624+
filename = await asyncio.wait_for(client_queue.get(), timeout=1.0)
625+
logging.info(f"Sending SSE data: {filename}")
626+
yield f"data: {filename}\n\n"
627+
client_queue.task_done()
628+
except asyncio.TimeoutError:
629+
# No new file, continue loop to check connection status
630+
continue
631+
except Exception as e:
632+
logging.error(f"Error in SSE generator: {e}")
633+
# Optionally send an error event to the client
634+
# yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n"
635+
break # Stop streaming on unexpected errors
636+
finally:
637+
# Cleanup when client disconnects or loop breaks
638+
if client_queue in sse_clients:
639+
sse_clients.remove(client_queue)
640+
logging.info(f"SSE client queue removed. Total clients: {len(sse_clients)}")
641+
642+
return StreamingResponse(event_generator(), media_type="text/event-stream")
643+
644+
645+
@app.get("/videos/{filename}")
646+
async def get_video(filename: str):
647+
"""
648+
Serves a specific video file from the VIDEO_DIR.
649+
"""
650+
# Basic security check: prevent directory traversal
651+
if ".." in filename or filename.startswith("/"):
652+
raise HTTPException(status_code=400, detail="Invalid filename.")
653+
654+
filepath = os.path.join(settings.VIDEO_DIR, filename)
655+
logging.info(f"Request for video file: {filepath}")
656+
657+
if not os.path.exists(filepath) or not os.path.isfile(filepath):
658+
logging.warning(f"Video file not found: {filepath}")
659+
raise HTTPException(status_code=404, detail="Video file not found")
660+
661+
# Check if the file is an mp4 file (optional but recommended)
662+
if not filename.lower().endswith(".mp4"):
663+
raise HTTPException(status_code=400, detail="Invalid file type, only MP4 is supported.")
664+
665+
return FileResponse(filepath, media_type="video/mp4", filename=filename)
666+
667+
668+
@app.get("/videos", response_model=List[str])
669+
async def list_videos():
670+
"""
671+
Lists all .mp4 files currently in the VIDEO_DIR.
672+
"""
673+
try:
674+
all_files = os.listdir(settings.VIDEO_DIR)
675+
mp4_files = sorted([f for f in all_files if f.lower().endswith(".mp4") and os.path.isfile(os.path.join(settings.VIDEO_DIR, f))])
676+
logging.info(f"Found {len(mp4_files)} MP4 files in {settings.VIDEO_DIR}")
677+
return mp4_files
678+
except FileNotFoundError:
679+
logging.error(f"VIDEO_DIR not found: {settings.VIDEO_DIR}")
680+
raise HTTPException(status_code=500, detail="Video directory not found on server.")
681+
except Exception as e:
682+
logging.error(f"Error listing videos in {settings.VIDEO_DIR}: {e}")
683+
raise HTTPException(status_code=500, detail="Error listing video files.")
684+
685+
575686
# --- Main execution (for running with uvicorn) ---
576687
if __name__ == "__main__":
577688
import uvicorn
689+
# Configure logging for the main execution context as well
690+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
578691
print(f"Starting Uvicorn server on {settings.API_HOST}:{settings.API_PORT}")
579-
uvicorn.run(app, host=settings.API_HOST, port=settings.API_PORT)
692+
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) # Use string import for reload

api/queue_manager.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import uuid
55
import numpy as np
66
import logging
7-
from dataclasses import dataclass, field # Import field
8-
from typing import Optional
9-
from datetime import datetime, timezone # Import datetime and timezone
7+
from dataclasses import dataclass, field
8+
from typing import Optional, Dict
9+
from datetime import datetime, timezone
1010
from PIL import Image
1111
# from PIL.PngImagePlugin import PngInfo # No longer needed for JPEG saving
1212

@@ -137,6 +137,9 @@ def from_dict(cls, data):
137137
# Initialize job queue as a list
138138
job_queue = []
139139

140+
# Dictionary to hold the latest preview image (Base64) for processing jobs (in-memory only)
141+
current_previews: Dict[str, str] = {}
142+
140143

141144
def save_queue():
142145
global job_queue
@@ -416,13 +419,43 @@ def update_job_status(job_id: str, status: str, thumbnail: str = None):
416419
else:
417420
print(f"Job with ID {job_id} not found in memory or file for status update.")
418421

422+
# Clear preview if the job reached a terminal state
423+
is_terminal = status == "completed" or status == "cancelled" or status.startswith("failed")
424+
if is_terminal:
425+
clear_current_preview(job_id)
426+
419427
return job_updated
420428

421429

422430
# Configure logging (moved import to top)
423431
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
424432

425433

434+
# --- Preview Data Management (In-Memory) ---
435+
436+
def update_current_preview(job_id: str, preview_base64: str):
437+
"""Stores the latest preview image Base64 string for a job."""
438+
global current_previews
439+
current_previews[job_id] = preview_base64
440+
# logging.debug(f"Updated preview for job {job_id}") # Optional: Debug logging
441+
442+
443+
def get_current_preview(job_id: str) -> Optional[str]:
444+
"""Retrieves the latest preview image Base64 string for a job."""
445+
global current_previews
446+
return current_previews.get(job_id)
447+
448+
449+
def clear_current_preview(job_id: str):
450+
"""Removes the preview image Base64 string for a job."""
451+
global current_previews
452+
if job_id in current_previews:
453+
del current_previews[job_id]
454+
logging.info(f"Cleared preview for job {job_id}")
455+
456+
457+
# --- Job Progress Update ---
458+
426459
def update_job_progress(job_id: str, progress: float, step: int, total: int, info: str):
427460
"""Updates the progress fields of a job in the global queue and saves the file."""
428461
logging.info(f"Attempting to update progress for job {job_id}: progress={progress}, step={step}, total={total}, info='{info}'")
@@ -649,6 +682,8 @@ def cleanup_jobs_by_max_count(max_completed_jobs: int = settings.MAX_COMPLETED_J
649682
files_to_delete.append(job.image_path)
650683
if job.thumbnail:
651684
files_to_delete.append(job.thumbnail)
685+
# Also clear any lingering preview data for the removed job
686+
clear_current_preview(job.job_id)
652687
removed_count += 1
653688

654689
# Combine kept terminal jobs and active jobs for the new queue

api/settings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@
2727
QUEUE_FILE_PATH = os.path.join(PROJECT_ROOT, 'job_queue.json')
2828
HF_HOME_DIR = os.path.join(PROJECT_ROOT, 'hf_download')
2929
LORA_DIR = os.environ.get("LORA_DIR", os.path.join(PROJECT_ROOT, 'loras'))
30+
# Video Watcher Settings
31+
VIDEO_DIR = os.environ.get("VIDEO_DIR", OUTPUTS_DIR) # Use OUTPUTS_DIR as default if env var not set
3032

3133

3234
# Ensure directories exist (These should ideally be created outside the API module if they don't exist)
3335
os.makedirs(OUTPUTS_DIR, exist_ok=True)
3436
os.makedirs(TEMP_QUEUE_IMAGES_DIR, exist_ok=True)
3537
os.makedirs(HF_HOME_DIR, exist_ok=True)
3638
os.makedirs(LORA_DIR, exist_ok=True)
39+
# Ensure VIDEO_DIR exists (especially if it's different from OUTPUTS_DIR)
40+
os.makedirs(VIDEO_DIR, exist_ok=True)
3741

3842
# Set Hugging Face home directory environment variable
3943
os.environ['HF_HOME'] = HF_HOME_DIR
@@ -55,6 +59,7 @@
5559
print(f" LoRA Dir: {LORA_DIR}")
5660
print(f" Worker Check Interval: {WORKER_CHECK_INTERVAL}")
5761
print(f" Allowed Origins: {ALLOWED_ORIGINS}")
62+
print(f" Video Dir (for watcher): {VIDEO_DIR}")
5863

5964
# --- Job Cleanup Settings ---
6065
# Maximum number of completed, cancelled, or failed jobs to keep in the queue file.

0 commit comments

Comments
 (0)