Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions PLAN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# 動画連続再生機能 実装計画

## 概要

ローカルの特定ディレクトリ (`outputs`) に順次追加される動画ファイル (`.mp4`) を検知し、FastAPIのSSE (Server-Sent Events) を通じてクライアント (React想定) に通知する。クライアントは通知されたファイル名を元に動画をリクエストし、連続再生を行う。既存のFramePack API (`api/api.py`) に機能を追加し、他の機能への影響を最小限に抑える。

## 計画詳細

1. **設定更新 (`api/settings.py`):**
* 監視対象ディレクトリパス `VIDEO_DIR` を定義する。デフォルトはプロジェクトルート下の `outputs` ディレクトリとする。環境変数 `VIDEO_DIR` が設定されていれば、その値を優先する。
* 動画ファイル配信用エンドポイントのベースURL `VIDEO_BASE_URL` を `/videos/` として定義する (主にクライアント側の参考情報)。

2. **ファイル監視ロジック (`api/video_watcher.py` - 新規ファイル):**
* `watchdog` ライブラリを使用する `VideoHandler` クラスを作成する。
* `on_created` イベントハンドラを実装し、`.mp4` ファイルが作成された場合のみ、FastAPI側のSSEクライアントキューリスト (`sse_clients`) にファイル名を追加する。
* 監視を開始/停止する関数 (`start_watcher`, `stop_watcher`) を作成する。
* `start_watcher(path, clients)`: 指定されたパスを監視し、通知先のクライアントキューリストを受け取る。`watchdog.observers.Observer` インスタンスを初期化・開始し、そのインスタンスを返す。
* `stop_watcher(observer)`: 受け取った `Observer` インスタンスを停止・結合する。

3. **FastAPIエンドポイント追加 (`api/api.py`):**
* **グローバル変数:**
* `sse_clients = []`: SSEクライアントごとの通知キュー (`asyncio.Queue` など) を保持するリスト。
* `observer = None`: `watchdog` の `Observer` インスタンスを保持する変数。
* **`/video_stream` (GET, SSE):**
* 新しいクライアント接続時に、専用の通知キューを作成し `sse_clients` に追加する。
* 非同期ジェネレータ関数を定義する。
* 無限ループでクライアントの接続状態をチェックする。
* キューから新しいファイル名を取得し、`data: {filename}\n\n` 形式で `yield` する。
* クライアント切断時には、対応するキューを `sse_clients` から削除し、ループを終了する。
* `StreamingResponse` で上記ジェネレータを返す (`media_type="text/event-stream"`)。
* **`/videos/{filename}` (GET):**
* `settings.VIDEO_DIR` とリクエストされた `filename` を結合して、動画ファイルのフルパスを構築する。
* `os.path.exists` でファイルの存在を確認する。
* 存在すれば `FileResponse` を使用して動画ファイル (`media_type="video/mp4"`) を返す。
* 存在しなければ `HTTPException(status_code=404, detail="File not found")` を発生させる。
* **`/videos` (GET):**
* `settings.VIDEO_DIR` 内のファイルを `os.listdir` で取得する。
* ファイル名が `.mp4` で終わるもののみをフィルタリングする。
* フィルタリングされたファイル名のリストをJSON形式で返す。

4. **ライフサイクル管理 (`api/api.py` の `lifespan`):**
* 既存の `lifespan` コンテキストマネージャを修正する。
* **Startup:**
* `video_watcher.start_watcher(settings.VIDEO_DIR, sse_clients)` を呼び出し、返された `Observer` インスタンスをグローバル変数 `observer` に格納する。
* **Shutdown:**
* グローバル変数 `observer` が `None` でなければ、`video_watcher.stop_watcher(observer)` を呼び出してファイル監視プロセスを安全に停止する。

5. **依存関係:**
* `watchdog` ライブラリが必要となるため、プロジェクトの依存関係ファイル (`requirements.txt` や `pyproject.toml` など) に `watchdog` を追加する。

## Mermaid図

```mermaid
graph TD
subgraph FastAPI Backend (api/api.py)
A[Client connects to /video_stream] --> B{Create SSE queue (e.g., asyncio.Queue)};
B --> C[Add queue to global sse_clients list];
C --> D[Start SSE generation loop (async def)];
D -- New filename in queue --> E[yield f"data: {filename}\n\n"];
D -- Client disconnects --> F[Remove queue from sse_clients & break loop];

G[Client requests /videos/{filename}] --> H{Build file path using settings.VIDEO_DIR};
H -- os.path.exists is True --> I[Return FileResponse(path, media_type="video/mp4")];
H -- os.path.exists is False --> J[Raise HTTPException(404)];

K[Client requests /videos] --> L{os.listdir(settings.VIDEO_DIR)};
L --> M[Filter for .mp4 files, return JSON list];

N[lifespan startup] --> O[observer = video_watcher.start_watcher(VIDEO_DIR, sse_clients)];
P[lifespan shutdown] --> Q[if observer: video_watcher.stop_watcher(observer)];
end

subgraph File System Watcher (api/video_watcher.py - New File)
R[Watchdog Observer monitors VIDEO_DIR] -- New .mp4 created --> S[VideoHandler.on_created];
S --> T{Get filename};
T --> U[Add filename to all queues in sse_clients list];
V[start_watcher(path, clients)] --> W[Initialize Observer & Handler, observer.start(), return observer];
X[stop_watcher(observer)] --> Y[observer.stop(), observer.join()];
end

subgraph React Frontend (Out of scope)
Z[Page load requests /videos] --> AA[Get initial file list];
AA --> BB[Initialize playlist];
CC[Connects to /video_stream] --> DD[Receive filename via SSE];
DD --> EE[Add filename to playlist];
BB & EE --> FF[Select random video from playlist];
FF --> GG[Request /videos/{filename}];
GG --> HH[Receive video data & play];
end

FastAPI_Backend -- Manages --> File_System_Watcher;
```
117 changes: 115 additions & 2 deletions api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import base64 # 追加: Base64エンコード用
import mimetypes # 追加: MIMEタイプ判定用
import logging # 追加: Logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request # Request を追加
from fastapi.responses import FileResponse, StreamingResponse # JSONResponse を削除
Expand All @@ -15,12 +16,14 @@
from PIL import Image
import numpy as np
from typing import List, Optional # Import Optional (Dict removed as unused)
from watchdog.observers import Observer # 追加: Watchdog Observer

# Import modules created earlier (relative imports)
from . import settings
from . import models
from . import queue_manager
from . import worker
from . import video_watcher

# --- Global State ---
# Dictionary to hold loaded models
Expand All @@ -30,13 +33,16 @@
worker_thread = None
# Variable to store the ID of the currently processing job
currently_processing_job_id: str | None = None
# --- Video Watcher State ---
sse_clients: List[asyncio.Queue] = [] # List to hold client queues for SSE
observer: Observer | None = None # type: ignore # Watchdog observer instance


# --- Lifespan Context Manager ---
@asynccontextmanager # Use the imported decorator directly
async def lifespan(app: FastAPI):
# Startup logic
global loaded_models, worker_running, worker_thread
global loaded_models, worker_running, worker_thread, observer, sse_clients # Add observer and sse_clients
print("API starting up via lifespan...")
# Load models
try:
Expand All @@ -60,10 +66,32 @@ async def lifespan(app: FastAPI):
else:
print("Worker already running? Skipping start in lifespan.")

# Start video watcher
try:
print(f"Attempting to start video watcher for directory: {settings.VIDEO_DIR}")
# Pass the global sse_clients list to the watcher
observer = video_watcher.start_watcher(settings.VIDEO_DIR, sse_clients)
print("Video watcher started successfully via lifespan.")
except Exception as e:
print(f"FATAL: Failed to start video watcher on startup: {e}")
traceback.print_exc()
observer = None # Ensure observer is None if startup failed

yield

# Shutdown logic
print("API shutting down via lifespan...")

# Stop video watcher first
if observer:
try:
print("Stopping video watcher...")
video_watcher.stop_watcher(observer)
print("Video watcher stopped.")
except Exception as e:
print(f"Error stopping video watcher: {e}")
traceback.print_exc()

# Stop background worker
if worker_running:
worker_running = False
Expand Down Expand Up @@ -572,8 +600,93 @@ async def list_loras():
return LoraListResponse(loras=lora_files) # Correct indentation for return


# === Video Streaming Endpoints ===

@app.get("/video_stream")
async def video_stream(request: Request):
"""
Streams new video filenames using Server-Sent Events (SSE).
"""
client_queue = asyncio.Queue()
sse_clients.append(client_queue)
logging.info(f"SSE client connected. Total clients: {len(sse_clients)}")

async def event_generator():
try:
while True:
# Check connection status first
if await request.is_disconnected():
logging.info("SSE client disconnected.")
break

try:
# Wait for a new filename from the queue
filename = await asyncio.wait_for(client_queue.get(), timeout=1.0)
logging.info(f"Sending SSE data: {filename}")
yield f"data: {filename}\n\n"
client_queue.task_done()
except asyncio.TimeoutError:
# No new file, continue loop to check connection status
continue
except Exception as e:
logging.error(f"Error in SSE generator: {e}")
# Optionally send an error event to the client
# yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n"
break # Stop streaming on unexpected errors
finally:
# Cleanup when client disconnects or loop breaks
if client_queue in sse_clients:
sse_clients.remove(client_queue)
logging.info(f"SSE client queue removed. Total clients: {len(sse_clients)}")

return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.get("/videos/{filename}")
async def get_video(filename: str):
"""
Serves a specific video file from the VIDEO_DIR.
"""
# Basic security check: prevent directory traversal
if ".." in filename or filename.startswith("/"):
raise HTTPException(status_code=400, detail="Invalid filename.")

filepath = os.path.join(settings.VIDEO_DIR, filename)
logging.info(f"Request for video file: {filepath}")

if not os.path.exists(filepath) or not os.path.isfile(filepath):
logging.warning(f"Video file not found: {filepath}")
raise HTTPException(status_code=404, detail="Video file not found")

# Check if the file is an mp4 file (optional but recommended)
if not filename.lower().endswith(".mp4"):
raise HTTPException(status_code=400, detail="Invalid file type, only MP4 is supported.")

return FileResponse(filepath, media_type="video/mp4", filename=filename)


@app.get("/videos", response_model=List[str])
async def list_videos():
"""
Lists all .mp4 files currently in the VIDEO_DIR.
"""
try:
all_files = os.listdir(settings.VIDEO_DIR)
mp4_files = sorted([f for f in all_files if f.lower().endswith(".mp4") and os.path.isfile(os.path.join(settings.VIDEO_DIR, f))])
logging.info(f"Found {len(mp4_files)} MP4 files in {settings.VIDEO_DIR}")
return mp4_files
except FileNotFoundError:
logging.error(f"VIDEO_DIR not found: {settings.VIDEO_DIR}")
raise HTTPException(status_code=500, detail="Video directory not found on server.")
except Exception as e:
logging.error(f"Error listing videos in {settings.VIDEO_DIR}: {e}")
raise HTTPException(status_code=500, detail="Error listing video files.")


# --- Main execution (for running with uvicorn) ---
if __name__ == "__main__":
import uvicorn
# Configure logging for the main execution context as well
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
print(f"Starting Uvicorn server on {settings.API_HOST}:{settings.API_PORT}")
uvicorn.run(app, host=settings.API_HOST, port=settings.API_PORT)
uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) # Use string import for reload
41 changes: 38 additions & 3 deletions api/queue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import uuid
import numpy as np
import logging
from dataclasses import dataclass, field # Import field
from typing import Optional
from datetime import datetime, timezone # Import datetime and timezone
from dataclasses import dataclass, field
from typing import Optional, Dict
from datetime import datetime, timezone
from PIL import Image
# from PIL.PngImagePlugin import PngInfo # No longer needed for JPEG saving

Expand Down Expand Up @@ -137,6 +137,9 @@ def from_dict(cls, data):
# Initialize job queue as a list
job_queue = []

# Dictionary to hold the latest preview image (Base64) for processing jobs (in-memory only)
current_previews: Dict[str, str] = {}


def save_queue():
global job_queue
Expand Down Expand Up @@ -416,13 +419,43 @@ def update_job_status(job_id: str, status: str, thumbnail: str = None):
else:
print(f"Job with ID {job_id} not found in memory or file for status update.")

# Clear preview if the job reached a terminal state
is_terminal = status == "completed" or status == "cancelled" or status.startswith("failed")
if is_terminal:
clear_current_preview(job_id)

return job_updated


# Configure logging (moved import to top)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


# --- Preview Data Management (In-Memory) ---

def update_current_preview(job_id: str, preview_base64: str):
"""Stores the latest preview image Base64 string for a job."""
global current_previews
current_previews[job_id] = preview_base64
# logging.debug(f"Updated preview for job {job_id}") # Optional: Debug logging


def get_current_preview(job_id: str) -> Optional[str]:
"""Retrieves the latest preview image Base64 string for a job."""
global current_previews
return current_previews.get(job_id)


def clear_current_preview(job_id: str):
"""Removes the preview image Base64 string for a job."""
global current_previews
if job_id in current_previews:
del current_previews[job_id]
logging.info(f"Cleared preview for job {job_id}")


# --- Job Progress Update ---

def update_job_progress(job_id: str, progress: float, step: int, total: int, info: str):
"""Updates the progress fields of a job in the global queue and saves the file."""
logging.info(f"Attempting to update progress for job {job_id}: progress={progress}, step={step}, total={total}, info='{info}'")
Expand Down Expand Up @@ -649,6 +682,8 @@ def cleanup_jobs_by_max_count(max_completed_jobs: int = settings.MAX_COMPLETED_J
files_to_delete.append(job.image_path)
if job.thumbnail:
files_to_delete.append(job.thumbnail)
# Also clear any lingering preview data for the removed job
clear_current_preview(job.job_id)
removed_count += 1

# Combine kept terminal jobs and active jobs for the new queue
Expand Down
5 changes: 5 additions & 0 deletions api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@
QUEUE_FILE_PATH = os.path.join(PROJECT_ROOT, 'job_queue.json')
HF_HOME_DIR = os.path.join(PROJECT_ROOT, 'hf_download')
LORA_DIR = os.environ.get("LORA_DIR", os.path.join(PROJECT_ROOT, 'loras'))
# Video Watcher Settings
VIDEO_DIR = os.environ.get("VIDEO_DIR", OUTPUTS_DIR) # Use OUTPUTS_DIR as default if env var not set


# Ensure directories exist (These should ideally be created outside the API module if they don't exist)
os.makedirs(OUTPUTS_DIR, exist_ok=True)
os.makedirs(TEMP_QUEUE_IMAGES_DIR, exist_ok=True)
os.makedirs(HF_HOME_DIR, exist_ok=True)
os.makedirs(LORA_DIR, exist_ok=True)
# Ensure VIDEO_DIR exists (especially if it's different from OUTPUTS_DIR)
os.makedirs(VIDEO_DIR, exist_ok=True)

# Set Hugging Face home directory environment variable
os.environ['HF_HOME'] = HF_HOME_DIR
Expand All @@ -55,6 +59,7 @@
print(f" LoRA Dir: {LORA_DIR}")
print(f" Worker Check Interval: {WORKER_CHECK_INTERVAL}")
print(f" Allowed Origins: {ALLOWED_ORIGINS}")
print(f" Video Dir (for watcher): {VIDEO_DIR}")

# --- Job Cleanup Settings ---
# Maximum number of completed, cancelled, or failed jobs to keep in the queue file.
Expand Down
Loading