diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 00000000..d35c6b2b --- /dev/null +++ b/PLAN.md @@ -0,0 +1,92 @@ +# 動画連続再生機能 実装計画 + +## 概要 + +ローカルの特定ディレクトリ (`outputs`) に順次追加される動画ファイル (`.mp4`) を検知し、FastAPIのSSE (Server-Sent Events) を通じてクライアント (React想定) に通知する。クライアントは通知されたファイル名を元に動画をリクエストし、連続再生を行う。既存のFramePack API (`api/api.py`) に機能を追加し、他の機能への影響を最小限に抑える。 + +## 計画詳細 + +1. **設定更新 (`api/settings.py`):** + * 監視対象ディレクトリパス `VIDEO_DIR` を定義する。デフォルトはプロジェクトルート下の `outputs` ディレクトリとする。環境変数 `VIDEO_DIR` が設定されていれば、その値を優先する。 + * 動画ファイル配信用エンドポイントのベースURL `VIDEO_BASE_URL` を `/videos/` として定義する (主にクライアント側の参考情報)。 + +2. **ファイル監視ロジック (`api/video_watcher.py` - 新規ファイル):** + * `watchdog` ライブラリを使用する `VideoHandler` クラスを作成する。 + * `on_created` イベントハンドラを実装し、`.mp4` ファイルが作成された場合のみ、FastAPI側のSSEクライアントキューリスト (`sse_clients`) にファイル名を追加する。 + * 監視を開始/停止する関数 (`start_watcher`, `stop_watcher`) を作成する。 + * `start_watcher(path, clients)`: 指定されたパスを監視し、通知先のクライアントキューリストを受け取る。`watchdog.observers.Observer` インスタンスを初期化・開始し、そのインスタンスを返す。 + * `stop_watcher(observer)`: 受け取った `Observer` インスタンスを停止・結合する。 + +3. **FastAPIエンドポイント追加 (`api/api.py`):** + * **グローバル変数:** + * `sse_clients = []`: SSEクライアントごとの通知キュー (`asyncio.Queue` など) を保持するリスト。 + * `observer = None`: `watchdog` の `Observer` インスタンスを保持する変数。 + * **`/video_stream` (GET, SSE):** + * 新しいクライアント接続時に、専用の通知キューを作成し `sse_clients` に追加する。 + * 非同期ジェネレータ関数を定義する。 + * 無限ループでクライアントの接続状態をチェックする。 + * キューから新しいファイル名を取得し、`data: {filename}\n\n` 形式で `yield` する。 + * クライアント切断時には、対応するキューを `sse_clients` から削除し、ループを終了する。 + * `StreamingResponse` で上記ジェネレータを返す (`media_type="text/event-stream"`)。 + * **`/videos/{filename}` (GET):** + * `settings.VIDEO_DIR` とリクエストされた `filename` を結合して、動画ファイルのフルパスを構築する。 + * `os.path.exists` でファイルの存在を確認する。 + * 存在すれば `FileResponse` を使用して動画ファイル (`media_type="video/mp4"`) を返す。 + * 存在しなければ `HTTPException(status_code=404, detail="File not found")` を発生させる。 + * **`/videos` (GET):** + * `settings.VIDEO_DIR` 内のファイルを `os.listdir` で取得する。 + * ファイル名が `.mp4` で終わるもののみをフィルタリングする。 + * フィルタリングされたファイル名のリストをJSON形式で返す。 + +4. **ライフサイクル管理 (`api/api.py` の `lifespan`):** + * 既存の `lifespan` コンテキストマネージャを修正する。 + * **Startup:** + * `video_watcher.start_watcher(settings.VIDEO_DIR, sse_clients)` を呼び出し、返された `Observer` インスタンスをグローバル変数 `observer` に格納する。 + * **Shutdown:** + * グローバル変数 `observer` が `None` でなければ、`video_watcher.stop_watcher(observer)` を呼び出してファイル監視プロセスを安全に停止する。 + +5. **依存関係:** + * `watchdog` ライブラリが必要となるため、プロジェクトの依存関係ファイル (`requirements.txt` や `pyproject.toml` など) に `watchdog` を追加する。 + +## Mermaid図 + +```mermaid +graph TD + subgraph FastAPI Backend (api/api.py) + A[Client connects to /video_stream] --> B{Create SSE queue (e.g., asyncio.Queue)}; + B --> C[Add queue to global sse_clients list]; + C --> D[Start SSE generation loop (async def)]; + D -- New filename in queue --> E[yield f"data: {filename}\n\n"]; + D -- Client disconnects --> F[Remove queue from sse_clients & break loop]; + + G[Client requests /videos/{filename}] --> H{Build file path using settings.VIDEO_DIR}; + H -- os.path.exists is True --> I[Return FileResponse(path, media_type="video/mp4")]; + H -- os.path.exists is False --> J[Raise HTTPException(404)]; + + K[Client requests /videos] --> L{os.listdir(settings.VIDEO_DIR)}; + L --> M[Filter for .mp4 files, return JSON list]; + + N[lifespan startup] --> O[observer = video_watcher.start_watcher(VIDEO_DIR, sse_clients)]; + P[lifespan shutdown] --> Q[if observer: video_watcher.stop_watcher(observer)]; + end + + subgraph File System Watcher (api/video_watcher.py - New File) + R[Watchdog Observer monitors VIDEO_DIR] -- New .mp4 created --> S[VideoHandler.on_created]; + S --> T{Get filename}; + T --> U[Add filename to all queues in sse_clients list]; + V[start_watcher(path, clients)] --> W[Initialize Observer & Handler, observer.start(), return observer]; + X[stop_watcher(observer)] --> Y[observer.stop(), observer.join()]; + end + + subgraph React Frontend (Out of scope) + Z[Page load requests /videos] --> AA[Get initial file list]; + AA --> BB[Initialize playlist]; + CC[Connects to /video_stream] --> DD[Receive filename via SSE]; + DD --> EE[Add filename to playlist]; + BB & EE --> FF[Select random video from playlist]; + FF --> GG[Request /videos/{filename}]; + GG --> HH[Receive video data & play]; + end + + FastAPI_Backend -- Manages --> File_System_Watcher; +``` diff --git a/api/api.py b/api/api.py index 6339febb..ffa72b9a 100644 --- a/api/api.py +++ b/api/api.py @@ -7,6 +7,7 @@ import json import base64 # 追加: Base64エンコード用 import mimetypes # 追加: MIMEタイプ判定用 +import logging # 追加: Logging from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request # Request を追加 from fastapi.responses import FileResponse, StreamingResponse # JSONResponse を削除 @@ -15,12 +16,14 @@ from PIL import Image import numpy as np from typing import List, Optional # Import Optional (Dict removed as unused) +from watchdog.observers import Observer # 追加: Watchdog Observer # Import modules created earlier (relative imports) from . import settings from . import models from . import queue_manager from . import worker +from . import video_watcher # --- Global State --- # Dictionary to hold loaded models @@ -30,13 +33,16 @@ worker_thread = None # Variable to store the ID of the currently processing job currently_processing_job_id: str | None = None +# --- Video Watcher State --- +sse_clients: List[asyncio.Queue] = [] # List to hold client queues for SSE +observer: Observer | None = None # type: ignore # Watchdog observer instance # --- Lifespan Context Manager --- @asynccontextmanager # Use the imported decorator directly async def lifespan(app: FastAPI): # Startup logic - global loaded_models, worker_running, worker_thread + global loaded_models, worker_running, worker_thread, observer, sse_clients # Add observer and sse_clients print("API starting up via lifespan...") # Load models try: @@ -60,10 +66,32 @@ async def lifespan(app: FastAPI): else: print("Worker already running? Skipping start in lifespan.") + # Start video watcher + try: + print(f"Attempting to start video watcher for directory: {settings.VIDEO_DIR}") + # Pass the global sse_clients list to the watcher + observer = video_watcher.start_watcher(settings.VIDEO_DIR, sse_clients) + print("Video watcher started successfully via lifespan.") + except Exception as e: + print(f"FATAL: Failed to start video watcher on startup: {e}") + traceback.print_exc() + observer = None # Ensure observer is None if startup failed + yield # Shutdown logic print("API shutting down via lifespan...") + + # Stop video watcher first + if observer: + try: + print("Stopping video watcher...") + video_watcher.stop_watcher(observer) + print("Video watcher stopped.") + except Exception as e: + print(f"Error stopping video watcher: {e}") + traceback.print_exc() + # Stop background worker if worker_running: worker_running = False @@ -572,8 +600,93 @@ async def list_loras(): return LoraListResponse(loras=lora_files) # Correct indentation for return +# === Video Streaming Endpoints === + +@app.get("/video_stream") +async def video_stream(request: Request): + """ + Streams new video filenames using Server-Sent Events (SSE). + """ + client_queue = asyncio.Queue() + sse_clients.append(client_queue) + logging.info(f"SSE client connected. Total clients: {len(sse_clients)}") + + async def event_generator(): + try: + while True: + # Check connection status first + if await request.is_disconnected(): + logging.info("SSE client disconnected.") + break + + try: + # Wait for a new filename from the queue + filename = await asyncio.wait_for(client_queue.get(), timeout=1.0) + logging.info(f"Sending SSE data: {filename}") + yield f"data: {filename}\n\n" + client_queue.task_done() + except asyncio.TimeoutError: + # No new file, continue loop to check connection status + continue + except Exception as e: + logging.error(f"Error in SSE generator: {e}") + # Optionally send an error event to the client + # yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n" + break # Stop streaming on unexpected errors + finally: + # Cleanup when client disconnects or loop breaks + if client_queue in sse_clients: + sse_clients.remove(client_queue) + logging.info(f"SSE client queue removed. Total clients: {len(sse_clients)}") + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + +@app.get("/videos/{filename}") +async def get_video(filename: str): + """ + Serves a specific video file from the VIDEO_DIR. + """ + # Basic security check: prevent directory traversal + if ".." in filename or filename.startswith("/"): + raise HTTPException(status_code=400, detail="Invalid filename.") + + filepath = os.path.join(settings.VIDEO_DIR, filename) + logging.info(f"Request for video file: {filepath}") + + if not os.path.exists(filepath) or not os.path.isfile(filepath): + logging.warning(f"Video file not found: {filepath}") + raise HTTPException(status_code=404, detail="Video file not found") + + # Check if the file is an mp4 file (optional but recommended) + if not filename.lower().endswith(".mp4"): + raise HTTPException(status_code=400, detail="Invalid file type, only MP4 is supported.") + + return FileResponse(filepath, media_type="video/mp4", filename=filename) + + +@app.get("/videos", response_model=List[str]) +async def list_videos(): + """ + Lists all .mp4 files currently in the VIDEO_DIR. + """ + try: + all_files = os.listdir(settings.VIDEO_DIR) + mp4_files = sorted([f for f in all_files if f.lower().endswith(".mp4") and os.path.isfile(os.path.join(settings.VIDEO_DIR, f))]) + logging.info(f"Found {len(mp4_files)} MP4 files in {settings.VIDEO_DIR}") + return mp4_files + except FileNotFoundError: + logging.error(f"VIDEO_DIR not found: {settings.VIDEO_DIR}") + raise HTTPException(status_code=500, detail="Video directory not found on server.") + except Exception as e: + logging.error(f"Error listing videos in {settings.VIDEO_DIR}: {e}") + raise HTTPException(status_code=500, detail="Error listing video files.") + + # --- Main execution (for running with uvicorn) --- if __name__ == "__main__": import uvicorn + # Configure logging for the main execution context as well + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') print(f"Starting Uvicorn server on {settings.API_HOST}:{settings.API_PORT}") - uvicorn.run(app, host=settings.API_HOST, port=settings.API_PORT) + uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) # Use string import for reload diff --git a/api/queue_manager.py b/api/queue_manager.py index 16422e4e..d8063b94 100644 --- a/api/queue_manager.py +++ b/api/queue_manager.py @@ -4,9 +4,9 @@ import uuid import numpy as np import logging -from dataclasses import dataclass, field # Import field -from typing import Optional -from datetime import datetime, timezone # Import datetime and timezone +from dataclasses import dataclass, field +from typing import Optional, Dict +from datetime import datetime, timezone from PIL import Image # from PIL.PngImagePlugin import PngInfo # No longer needed for JPEG saving @@ -137,6 +137,9 @@ def from_dict(cls, data): # Initialize job queue as a list job_queue = [] +# Dictionary to hold the latest preview image (Base64) for processing jobs (in-memory only) +current_previews: Dict[str, str] = {} + def save_queue(): global job_queue @@ -416,6 +419,11 @@ def update_job_status(job_id: str, status: str, thumbnail: str = None): else: print(f"Job with ID {job_id} not found in memory or file for status update.") + # Clear preview if the job reached a terminal state + is_terminal = status == "completed" or status == "cancelled" or status.startswith("failed") + if is_terminal: + clear_current_preview(job_id) + return job_updated @@ -423,6 +431,31 @@ def update_job_status(job_id: str, status: str, thumbnail: str = None): logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +# --- Preview Data Management (In-Memory) --- + +def update_current_preview(job_id: str, preview_base64: str): + """Stores the latest preview image Base64 string for a job.""" + global current_previews + current_previews[job_id] = preview_base64 + # logging.debug(f"Updated preview for job {job_id}") # Optional: Debug logging + + +def get_current_preview(job_id: str) -> Optional[str]: + """Retrieves the latest preview image Base64 string for a job.""" + global current_previews + return current_previews.get(job_id) + + +def clear_current_preview(job_id: str): + """Removes the preview image Base64 string for a job.""" + global current_previews + if job_id in current_previews: + del current_previews[job_id] + logging.info(f"Cleared preview for job {job_id}") + + +# --- Job Progress Update --- + def update_job_progress(job_id: str, progress: float, step: int, total: int, info: str): """Updates the progress fields of a job in the global queue and saves the file.""" logging.info(f"Attempting to update progress for job {job_id}: progress={progress}, step={step}, total={total}, info='{info}'") @@ -649,6 +682,8 @@ def cleanup_jobs_by_max_count(max_completed_jobs: int = settings.MAX_COMPLETED_J files_to_delete.append(job.image_path) if job.thumbnail: files_to_delete.append(job.thumbnail) + # Also clear any lingering preview data for the removed job + clear_current_preview(job.job_id) removed_count += 1 # Combine kept terminal jobs and active jobs for the new queue diff --git a/api/settings.py b/api/settings.py index 2ebb5b2d..cf57557f 100644 --- a/api/settings.py +++ b/api/settings.py @@ -27,6 +27,8 @@ QUEUE_FILE_PATH = os.path.join(PROJECT_ROOT, 'job_queue.json') HF_HOME_DIR = os.path.join(PROJECT_ROOT, 'hf_download') LORA_DIR = os.environ.get("LORA_DIR", os.path.join(PROJECT_ROOT, 'loras')) +# Video Watcher Settings +VIDEO_DIR = os.environ.get("VIDEO_DIR", OUTPUTS_DIR) # Use OUTPUTS_DIR as default if env var not set # Ensure directories exist (These should ideally be created outside the API module if they don't exist) @@ -34,6 +36,8 @@ os.makedirs(TEMP_QUEUE_IMAGES_DIR, exist_ok=True) os.makedirs(HF_HOME_DIR, exist_ok=True) os.makedirs(LORA_DIR, exist_ok=True) +# Ensure VIDEO_DIR exists (especially if it's different from OUTPUTS_DIR) +os.makedirs(VIDEO_DIR, exist_ok=True) # Set Hugging Face home directory environment variable os.environ['HF_HOME'] = HF_HOME_DIR @@ -55,6 +59,7 @@ print(f" LoRA Dir: {LORA_DIR}") print(f" Worker Check Interval: {WORKER_CHECK_INTERVAL}") print(f" Allowed Origins: {ALLOWED_ORIGINS}") +print(f" Video Dir (for watcher): {VIDEO_DIR}") # --- Job Cleanup Settings --- # Maximum number of completed, cancelled, or failed jobs to keep in the queue file. diff --git a/api/video_watcher.py b/api/video_watcher.py new file mode 100644 index 00000000..68f6f7e2 --- /dev/null +++ b/api/video_watcher.py @@ -0,0 +1,93 @@ +import os +import asyncio +import logging +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler +from typing import List + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + +class VideoHandler(FileSystemEventHandler): + """Handles file system events for video files.""" + + def __init__(self, sse_clients: List[asyncio.Queue]): + """ + Initializes the handler with a list of SSE client queues. + + Args: + sse_clients: A list of asyncio.Queue instances for notifying clients. + """ + self.sse_clients = sse_clients + logging.info(f"VideoHandler initialized with {len(sse_clients)} client queues initially.") + + def on_created(self, event): + """ + Called when a file or directory is created. + + Args: + event: The event object representing the file system event. + """ + if event.is_directory: + logging.debug(f"Ignoring directory creation: {event.src_path}") + return + + if event.src_path.lower().endswith(".mp4"): + filename = os.path.basename(event.src_path) + logging.info(f"New MP4 file detected: {filename}") + # Notify all connected SSE clients + # Use a copy of the list to avoid issues if a client disconnects during iteration + clients_to_notify = list(self.sse_clients) + logging.info(f"Notifying {len(clients_to_notify)} clients about new file: {filename}") + for queue in clients_to_notify: + try: + # Use put_nowait as this handler runs in a separate thread + # managed by watchdog, not in the main asyncio event loop. + queue.put_nowait(filename) + logging.debug(f"Added '{filename}' to a client queue.") + except asyncio.QueueFull: + logging.warning(f"Client queue is full. Could not add '{filename}'.") + except Exception as e: + logging.error(f"Error adding '{filename}' to client queue: {e}") + else: + logging.debug(f"Ignoring non-MP4 file creation: {event.src_path}") + + +def start_watcher(path: str, sse_clients: List[asyncio.Queue]) -> Observer: + """ + Starts the file system watcher. + + Args: + path: The directory path to watch. + sse_clients: The list of SSE client queues to notify. + + Returns: + The Observer instance that was started. + """ + if not os.path.isdir(path): + logging.error(f"Watch directory does not exist or is not a directory: {path}") + # Consider raising an exception or returning None + raise ValueError(f"Invalid watch directory: {path}") + + event_handler = VideoHandler(sse_clients) + observer = Observer() + observer.schedule(event_handler, path, recursive=False) + observer.start() + logging.info(f"Started watching directory: {path}") + return observer + + +def stop_watcher(observer: Observer): + """ + Stops the file system watcher. + + Args: + observer: The Observer instance to stop. + """ + if observer and observer.is_alive(): + observer.stop() + observer.join() # Wait for the thread to finish + logging.info("Stopped file system watcher.") + else: + logging.info("File system watcher was not running or already stopped.") \ No newline at end of file diff --git a/api/worker.py b/api/worker.py index 669c60b4..d10b3fab 100644 --- a/api/worker.py +++ b/api/worker.py @@ -6,6 +6,9 @@ from PIL import Image # Removed ImageDraw, ImageFont # from PIL.PngImagePlugin import PngInfo # No longer needed for JPEG saving import traceback +import base64 # Add base64 +import io # Add io +import einops # Add einops # Assuming models and tokenizers are loaded elsewhere and passed or accessed globally/via context # This will be refined when creating models.py and integrating @@ -16,6 +19,7 @@ encode_prompt_conds, vae_decode, vae_encode, + vae_decode_fake, # Add vae_decode_fake ) from diffusers_helper.utils import ( @@ -416,6 +420,7 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = # K-Diffusion Sampling Callback def callback(d): + # --- Progress Update --- current_cb_step = d['i'] + 1 # 1-based step for the current section total_cb_steps = steps # Total steps for this section @@ -437,7 +442,38 @@ def callback(d): info=hint ) - # Check for cancellation signal within callback + # --- Preview Generation --- + try: + # Only generate preview every N steps or on the last step to reduce overhead + if current_cb_step % 2 == 0 or current_cb_step == total_cb_steps: + preview_latent = d['denoised'] + preview_tensor = vae_decode_fake(preview_latent) # Use vae_decode_fake + + # Convert tensor to PIL Image + preview_np = (preview_tensor * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8) + # Rearrange: b c t h w -> (b h) (t w) c (assuming single batch) + # vae_decode_fake likely returns B C T H W, where T might be > 1 + preview_np_rearranged = einops.rearrange(preview_np, 'b c t h w -> (b h) (t w) c') + + preview_image = Image.fromarray(preview_np_rearranged) + + # Save image to buffer as JPEG + buffer = io.BytesIO() + preview_image.save(buffer, format="JPEG", quality=75) # Adjust quality as needed + buffer.seek(0) + + # Encode to Base64 + preview_base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8') + preview_base64_string = f"data:image/jpeg;base64,{preview_base64_data}" + + # Update preview in queue_manager (in-memory) + queue_manager.update_current_preview(job_id, preview_base64_string) + # logging.debug(f"Job {job_id}: Updated preview at step {current_cb_step}") # Optional debug log + except Exception as preview_e: + # Log error but don't necessarily stop the whole process + logging.warning(f"Job {job_id}: Error generating preview at step {current_cb_step}: {preview_e}") + + # --- Cancellation Check --- current_job_status_inner = queue_manager.get_job_by_id(job_id) # Use the function that reads the file if current_job_status_inner and current_job_status_inner.status == "cancelled": # Use current_cb_step which is defined in this scope @@ -561,6 +597,8 @@ def callback(d): unload_complete_models( text_encoder, text_encoder_2, image_encoder, vae, transformer ) + # Clear preview data from memory + queue_manager.clear_current_preview(job_id) print(f"Worker finished for job {job_id}") diff --git a/api_preview_plan.md b/api_preview_plan.md new file mode 100644 index 00000000..52fab76d --- /dev/null +++ b/api_preview_plan.md @@ -0,0 +1,125 @@ +# API プレビュー機能 実装計画 + +## 1. 目的 + +`api/` ディレクトリ内の FastAPI アプリケーションにおいて、`demo_gradio.py` と同様の動画生成中のプレビュー機能を実装する。これにより、API クライアントは生成プロセスの途中経過をリアルタイムで確認できるようになる。 + +## 2. 現状分析と課題 + +* **`demo_gradio.py` のプレビュー:** + * `worker` 内の `sample_hunyuan` コールバックで `vae_decode_fake` を使用し、プレビュー画像を生成。 + * `AsyncStream` を介して Gradio UI にプレビュー画像を送信。 +* **`api/worker.py` の現状:** + * `sample_hunyuan` のコールバックは進捗テキスト更新とキャンセルチェックのみ。プレビュー生成・送信は未実装。 + * `worker` からクライアントへのリアルタイムデータ送信手段が直接はない。 +* **API での実現課題:** + * `api/worker.py` のコールバックでプレビュー画像を生成する必要がある。 + * 生成したプレビュー画像を API クライアントにリアルタイムで送信する仕組みが必要。 + * `vae_decode_fake` 関数の利用可否確認 (→ 確認済み、利用可能)。 + +## 3. 計画 + +### 3.1. 情報収集 (完了) + +* `diffusers_helper/hunyuan.py` を確認し、`vae_decode_fake` 関数が存在することを確認した。 + +### 3.2. API 設計 + +* 既存の Server-Sent Events (SSE) エンドポイント `/stream/status/{job_id}` (`api.py` 内) を拡張し、プレビュー画像データも送信するようにする。 +* SSE イベントのデータ構造に、オプションとして Base64 エンコードされたプレビュー画像 (`preview_image_base64`) を追加する。 + + ```json + { + "job_id": "...", + "status": "processing", + "progress": 25.5, + "progress_step": 5, + "progress_total": 20, + "progress_info": "Sampling...", + "preview_image_base64": "data:image/jpeg;base64,..." // Optional + } + ``` + +### 3.3. 実装方針 + +* **`api/worker.py` の修正:** + * `callback` 関数内で、`sample_hunyuan` から渡される中間潜在変数 (`d['denoised']`) を取得する。 + * `vae_decode_fake` を使用してプレビュー画像を生成する。 + * 生成した画像を JPEG 形式にエンコードし、Base64 文字列に変換する (`data:image/jpeg;base64,...` 形式)。 + * 変換した Base64 文字列を `queue_manager` の新しい関数 (例: `update_current_preview`) を呼び出してメモリ上のストアに一時保存する。 +* **`api/queue_manager.py` の修正:** + * プレビュー情報 (Base64 文字列) を一時的に保持するためのグローバルな辞書 (例: `current_previews = {}`) を追加する。 + * `worker.py` からプレビュー情報を受け取り、`current_previews` を更新する関数 (例: `update_current_preview(job_id, preview_base64)`) を追加する。 + * SSE ハンドラからプレビュー情報を取得する関数 (例: `get_current_preview(job_id)`) を追加する。 + * ジョブ完了時または失敗時に `current_previews` から該当ジョブのエントリを削除する処理を追加する (例: `clear_current_preview(job_id)`)。 + * **注意:** このプレビュー情報は揮発性であり、JSON キューファイル (`job_queue.json`) には保存しない。 +* **`api/api.py` の修正:** + * `/stream/status/{job_id}` の SSE `event_generator` 関数を修正する。 + * ジョブが `processing` 状態の場合、`queue_manager` の新しい関数 (例: `get_current_preview`) を呼び出して最新のプレビュー画像 Base64 文字列を取得する。 + * 取得した Base64 文字列を SSE イベントデータの `preview_image_base64` フィールドに含めてクライアントに送信する。 + +### 3.4. 処理フロー (Mermaid図) + +```mermaid +sequenceDiagram + participant Client + participant FastAPI (api.py) + participant QueueManager (queue_manager.py) + participant Worker (worker.py) + participant Models (models.py / diffusers_helper) + + Client->>FastAPI: POST /generate (画像, プロンプト) + FastAPI->>QueueManager: add_to_queue() + QueueManager-->>FastAPI: job_id + FastAPI-->>Client: {job_id: ...} + + Client->>FastAPI: GET /stream/status/{job_id} (SSE接続) + FastAPI->>FastAPI: event_generator() 開始 + + loop Worker Thread + Worker->>QueueManager: get_next_job() + QueueManager-->>Worker: job (or None) + opt job is not None + Worker->>QueueManager: update_job_status(job_id, "processing") # ファイル更新 + Worker->>Models: モデルロード/準備 ... + Worker->>Models: sample_hunyuan(..., callback=callback_func) + loop Sampling Steps + Models->>Worker: callback_func(d) 呼び出し + Worker->>Models: vae_decode_fake(d['denoised']) # プレビュー生成 + Models-->>Worker: preview_image_tensor + Worker->>Worker: 画像をJPEG Base64に変換 + Worker->>QueueManager: update_current_preview(job_id, preview_base64) # メモリ更新 + Worker->>QueueManager: update_job_progress(...) # ファイル更新 (進捗のみ) + end + Models-->>Worker: generated_latents + Worker->>Models: vae_decode() # 最終デコード + Models-->>Worker: final_pixels + Worker->>Worker: save_bcthw_as_mp4() + Worker->>QueueManager: update_job_status(job_id, "completed") # ファイル更新 + Worker->>QueueManager: clear_current_preview(job_id) # メモリクリア + end + end + + loop SSE Connection (event_generator) + FastAPI->>QueueManager: get_job_by_id(job_id) (ファイルから進捗取得) + alt job is processing + FastAPI->>QueueManager: get_current_preview(job_id) (メモリからプレビュー取得) + QueueManager-->>FastAPI: preview_base64 (or None) + end + FastAPI->>Client: event: progress, data: {..., preview_image_base64: ...} # 進捗とプレビュー送信 + alt job is terminal + FastAPI->>Client: event: status, data: {...} # 最終ステータス送信 + break + end + FastAPI->>FastAPI: asyncio.sleep(1) + end +``` + +## 4. 懸念点 + +* **データ受け渡し:** `worker` スレッドと SSE ハンドラ (FastAPI の非同期コンテキスト) 間でのプレビューデータ受け渡し (`queue_manager` のメモリ上の辞書) が、スレッドセーフティやパフォーマンスの観点から問題ないか。高頻度更新時の競合やメモリ使用量に注意が必要。 +* **パフォーマンス:** プレビュー画像の生成 (`vae_decode_fake`)、JPEG エンコード、Base64 エンコードが `worker` のコールバック内で実行されるため、全体の生成時間に影響を与える可能性がある。 + +## 5. 次のステップ + +* この計画に基づき、`code` モードに切り替えて実装を開始する。 diff --git a/diffusers_helper/memory.py b/diffusers_helper/memory.py index 3380c538..49ba80c8 100644 --- a/diffusers_helper/memory.py +++ b/diffusers_helper/memory.py @@ -3,9 +3,16 @@ import torch - cpu = torch.device('cpu') -gpu = torch.device(f'cuda:{torch.cuda.current_device()}') + +try: + if torch.cuda.is_available(): + gpu = torch.device(f'cuda:{torch.cuda.current_device()}') + else: + gpu = torch.device('cpu') +except Exception: + gpu = torch.device('cpu') + gpu_complete_modules = [] @@ -72,13 +79,22 @@ def get_cuda_free_memory_gb(device=None): if device is None: device = gpu - memory_stats = torch.cuda.memory_stats(device) - bytes_active = memory_stats['active_bytes.all.current'] - bytes_reserved = memory_stats['reserved_bytes.all.current'] - bytes_free_cuda, _ = torch.cuda.mem_get_info(device) - bytes_inactive_reserved = bytes_reserved - bytes_active - bytes_total_available = bytes_free_cuda + bytes_inactive_reserved - return bytes_total_available / (1024 ** 3) + if not torch.cuda.is_available(): + return 0 # GPUが無い場合は0を返す + + try: + memory_stats = torch.cuda.memory_stats(device) + bytes_active = memory_stats.get('active_bytes.all.current') + if bytes_active is None: + # キーが存在しない場合、安全のため0を返す + return 0 + bytes_reserved = memory_stats['reserved_bytes.all.current'] + bytes_free_cuda, _ = torch.cuda.mem_get_info(device) + bytes_inactive_reserved = bytes_reserved - bytes_active + bytes_total_available = bytes_free_cuda + bytes_inactive_reserved + return bytes_total_available / (1024 ** 3) + except Exception: + return 0 def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0): diff --git a/diffusers_helper/models/hunyuan_video_packed.py b/diffusers_helper/models/hunyuan_video_packed.py index a724e361..96c10132 100644 --- a/diffusers_helper/models/hunyuan_video_packed.py +++ b/diffusers_helper/models/hunyuan_video_packed.py @@ -591,13 +591,13 @@ def forward( # 3. Modulation and residual connection hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) - print(f"DEBUG: Before proj_out - hidden_states device: {hidden_states.device}") - print(f"DEBUG: Before proj_out - gate device: {gate.device}") + # print(f"DEBUG: Before proj_out - hidden_states device: {hidden_states.device}") + # print(f"DEBUG: Before proj_out - gate device: {gate.device}") # Check proj_out layer's device (assuming it's a nn.Module with parameters) proj_out_device = next(self.proj_out.parameters()).device if list(self.proj_out.parameters()) else "No parameters" - print(f"DEBUG: Before proj_out - self.proj_out device: {proj_out_device}") + # print(f"DEBUG: Before proj_out - self.proj_out device: {proj_out_device}") proj_out_result = self.proj_out(hidden_states) - print(f"DEBUG: After proj_out - proj_out_result device: {proj_out_result.device}") + # print(f"DEBUG: After proj_out - proj_out_result device: {proj_out_result.device}") hidden_states = gate * proj_out_result # Error likely occurs here hidden_states = hidden_states + residual diff --git a/requirements.txt b/requirements.txt index e7dd8ecb..002f7153 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,8 @@ numpy==1.26.2 scipy==1.12.0 requests==2.31.0 torchsde==0.2.6 +torch +torchvision einops opencv-contrib-python @@ -17,7 +19,8 @@ peft fastapi uvicorn[standard] +watchdog # Testing pytest -pytest-mock +pytest-mock \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 6fb13f75..e3c8af64 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,6 +13,11 @@ # Assuming your FastAPI app instance is named 'app' in 'api/api.py' from api.api import app +if torch.cuda.is_available(): + gpu = torch.device(f'cuda:{torch.cuda.current_device()}') +else: + gpu = torch.device('cpu') + # Create a TestClient instance client = TestClient(app)