diff --git a/api/api.py b/api/api.py index ffa72b9a..23e09228 100644 --- a/api/api.py +++ b/api/api.py @@ -5,18 +5,19 @@ import traceback import asyncio import json -import base64 # 追加: Base64エンコード用 -import mimetypes # 追加: MIMEタイプ判定用 -import logging # 追加: Logging +import base64 +import mimetypes +import logging +import enum from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request # Request を追加 -from fastapi.responses import FileResponse, StreamingResponse # JSONResponse を削除 +from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form, Request +from fastapi.responses import FileResponse, StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from PIL import Image import numpy as np -from typing import List, Optional # Import Optional (Dict removed as unused) -from watchdog.observers import Observer # 追加: Watchdog Observer +from typing import List, Optional +from watchdog.observers import Observer # Import modules created earlier (relative imports) from . import settings @@ -39,15 +40,15 @@ # --- Lifespan Context Manager --- -@asynccontextmanager # Use the imported decorator directly +@asynccontextmanager async def lifespan(app: FastAPI): # Startup logic - global loaded_models, worker_running, worker_thread, observer, sse_clients # Add observer and sse_clients + global loaded_models, worker_running, worker_thread, observer, sse_clients print("API starting up via lifespan...") # Load models try: # Consider running blocking IO in a threadpool executor in async context - # e.g., await asyncio.to_thread(models.load_models) # lora_path removed + # e.g., await asyncio.to_thread(models.load_models) # For simplicity now, keeping the direct call but be aware of potential blocking loaded_models = models.load_models() print("Models loaded successfully via lifespan.") @@ -75,7 +76,7 @@ async def lifespan(app: FastAPI): except Exception as e: print(f"FATAL: Failed to start video watcher on startup: {e}") traceback.print_exc() - observer = None # Ensure observer is None if startup failed + observer = None yield @@ -106,7 +107,7 @@ async def lifespan(app: FastAPI): # Cleanup resources print("Attempting to unload models...") try: - models.unload_models(loaded_models) # Call the function from models module + models.unload_models(loaded_models) print("Models unloaded successfully (or placeholder executed).") except Exception as unload_e: print(f"Error during model unloading: {unload_e}") @@ -121,10 +122,10 @@ async def lifespan(app: FastAPI): # --- CORS Middleware Configuration --- app.add_middleware( CORSMiddleware, - allow_origins=settings.ALLOWED_ORIGINS, # Use the loaded origins from settings - allow_credentials=True, # Allow credentials (cookies, authorization headers, etc.) - allow_methods=["*"], # Allow all methods (GET, POST, etc.) - allow_headers=["*"], # Allow all headers (Content-Type, Authorization, etc.) + allow_origins=settings.ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) # --- End CORS Middleware Configuration --- @@ -132,7 +133,7 @@ async def lifespan(app: FastAPI): # --- Pydantic Models for API Requests/Responses --- class GenerateRequest(BaseModel): prompt: str = Field(..., description="Text prompt for video generation.") - # image: str = Field(..., description="Base64 encoded input image.") # Changed to use UploadFile + # image: str = Field(..., description="Base64 encoded input image.") video_length: float = Field(5.0, description="Length of the video in seconds.", gt=0) seed: int = Field(-1, description="Seed for generation. -1 for random.") use_teacache: bool = Field(False, description="Enable TEACache optimization.") @@ -177,6 +178,12 @@ class ResultResponse(BaseModel): thumbnail_base64: Optional[str] = None +# --- Enum for Sampling Mode --- +class SamplingMode(str, enum.Enum): + reverse = "reverse" + forward = "forward" + + # --- Background Worker --- def background_worker_task(): global worker_running, currently_processing_job_id @@ -184,15 +191,15 @@ def background_worker_task(): while worker_running: next_job = queue_manager.get_next_job() if next_job: - currently_processing_job_id = next_job.job_id # Set current job ID + currently_processing_job_id = next_job.job_id print(f"Worker picked up job: {currently_processing_job_id}") try: # Ensure models are loaded before processing if not loaded_models: print("Error: Models not loaded. Cannot process job.") queue_manager.update_job_status(currently_processing_job_id, "failed - models not loaded") - currently_processing_job_id = None # Clear current job ID on error - continue # Skip to next loop iteration + currently_processing_job_id = None + continue worker.worker(next_job, loaded_models) except Exception as e: @@ -220,18 +227,19 @@ def background_worker_task(): @app.post("/generate", response_model=GenerateResponse) async def generate_video( background_tasks: BackgroundTasks, - prompt: str = Form("A character doing some simple body movements."), # Set default prompt + prompt: str = Form("A character doing some simple body movements."), video_length: float = Form(5.0), seed: int = Form(-1), - use_teacache: bool = Form(True), # Default to True (matching demo_gradio.py) - gpu_memory_preservation: float = Form(6.0), # Default to 6.0 GB (matching demo_gradio.py) + use_teacache: bool = Form(True), + gpu_memory_preservation: float = Form(6.0), steps: int = Form(25), cfg: float = Form(1.0), gs: float = Form(10.0), rs: float = Form(0.0), mp4_crf: float = Form(16.0), - lora_scale: float = Form(1.0), # 追加: LoRA強度パラメータ - lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."), # 追加: LoRAファイルパス + lora_scale: float = Form(1.0), + lora_path: Optional[str] = Form(None, description="Path to the LoRA file to use for this request (overrides server default if provided)."), + sampling_mode: SamplingMode = Form(SamplingMode.reverse, description="Sampling loop direction."), image: UploadFile = File(...) ): """ @@ -258,12 +266,24 @@ async def generate_video( finally: await image.close() + # Determine the transformer model based on sampling_mode + # Use sampling_mode.value to get the string value from the Enum + if sampling_mode == SamplingMode.forward: + actual_transformer_model = "f1" + elif sampling_mode == SamplingMode.reverse: + actual_transformer_model = "base" + else: + # This 'else' block might be unreachable if using Enum correctly, + # but kept for safety or future expansion. FastAPI handles invalid Enum values. + print(f"Warning: Unexpected sampling_mode '{sampling_mode.value}'. Defaulting transformer_model to 'base'.") + actual_transformer_model = "base" + # Add job to the queue using queue_manager try: job_id = queue_manager.add_to_queue( prompt=prompt, image=image_np, - original_exif=original_exif, # Pass extracted Exif data + original_exif=original_exif, video_length=video_length, seed=seed, use_teacache=use_teacache, @@ -273,9 +293,11 @@ async def generate_video( gs=gs, rs=rs, mp4_crf=mp4_crf, - lora_scale=lora_scale, # 追加: lora_scale を渡す - lora_path=lora_path, # 追加: lora_path を渡す - status="pending" # Explicitly set initial status + lora_scale=lora_scale, + lora_path=lora_path, + sampling_mode=sampling_mode.value, + transformer_model=actual_transformer_model, + status="pending" ) except Exception as e: print(f"Error adding job via queue_manager: {e}") @@ -313,7 +335,7 @@ async def get_job_status(job_id: str): return JobStatusResponse(job_id=job_id, status="processing", progress_info="Details temporarily unavailable") # 2. Check if the job exists in the queue file (pending, failed, potentially completed but file not checked yet) - job_in_file = queue_manager.get_job_by_id(job_id) # Use the function that reads file + job_in_file = queue_manager.get_job_by_id(job_id) if job_in_file: # Return the status and progress details from the file return JobStatusResponse( @@ -347,7 +369,7 @@ async def stream_job_status(job_id: str, request: Request): """ async def event_generator(): last_data_sent = None - # terminal_statuses = {"completed", "cancelled"} # Unused variable removed + # terminal_statuses = {"completed", "cancelled"} while True: # Check if client disconnected @@ -388,14 +410,14 @@ async def event_generator(): # Send final status event if it hasn't been sent already if current_data_json != last_data_sent: yield f"event: progress\ndata: {current_data_json}\n\n" - last_data_sent = current_data_json # Ensure last_data_sent is updated even for the final message + last_data_sent = current_data_json print(f"Sent final progress update for job {job_id}: Status {job.status}") # Send a dedicated 'status' event to signal completion/failure/cancellation final_status_data = json.dumps({"status": job.status, "message": "Job finished."}) yield f"event: status\ndata: {final_status_data}\n\n" print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.") - break # Exit loop after sending final status + break else: # Wait before checking again only if not terminal await asyncio.sleep(1) # Check every 1 second @@ -404,7 +426,7 @@ async def event_generator(): final_data = json.dumps({"status": job.status, "message": "Job finished."}) yield f"event: status\ndata: {final_data}\n\n" print(f"Job {job_id} reached terminal state: {job.status}. Closing stream.") - break # Exit loop after sending final status + break # Wait before checking again await asyncio.sleep(1) # Check every 1 second @@ -413,7 +435,7 @@ async def event_generator(): @app.get("/result/{job_id}", response_model=ResultResponse) -async def get_job_result(job_id: str, request: Request): # requestを追加してURLを構築 +async def get_job_result(job_id: str, request: Request): """ Returns the download URL for the completed video and the Base64 encoded thumbnail. """ @@ -440,7 +462,7 @@ async def get_job_result(job_id: str, request: Request): # requestを追加し thumbnail_base64 = f"data:{mime_type};base64,{thumbnail_base64_data}" else: # MIMEタイプが不明な場合はデフォルトを使用(またはエラー処理) - thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}" # デフォルトをJPEGに + thumbnail_base64 = f"data:image/jpeg;base64,{thumbnail_base64_data}" print(f"Job {job_id}: Encoded thumbnail from {job.thumbnail}") except Exception as e: print(f"Job {job_id}: Error reading or encoding thumbnail {job.thumbnail}: {e}") @@ -479,7 +501,7 @@ async def get_input_image(job_id: str): Returns the input JPEG image file associated with a job, potentially including Exif metadata. """ job = queue_manager.get_job_by_id(job_id) - filename_base = f"queue_image_{job_id}.jpg" # Changed extension to jpg + filename_base = f"queue_image_{job_id}.jpg" input_image_path_in_temp = os.path.join(settings.TEMP_QUEUE_IMAGES_DIR, filename_base) if not job: @@ -563,16 +585,16 @@ async def get_worker_status(): @app.post("/cleanup_jobs", status_code=200) async def trigger_cleanup_jobs(): - """ # Correct indentation for docstring + """ Manually triggers the cleanup of old completed, cancelled, or failed jobs based on the MAX_COMPLETED_JOBS setting. """ - try: # Correct indentation for try block + try: removed_count = queue_manager.cleanup_jobs_by_max_count() return {"message": f"Cleanup process completed. Removed {removed_count} old job entries."} except Exception as e: - print(f"Error during manual job cleanup: {e}") # Correct indentation - traceback.print_exc() # Correct indentation + print(f"Error during manual job cleanup: {e}") + traceback.print_exc() raise HTTPException(status_code=500, detail=f"Failed to perform job cleanup: {e}") @@ -582,7 +604,7 @@ async def trigger_cleanup_jobs(): async def list_loras(): """Lists available LoRA files from the configured directory.""" lora_files = [] - allowed_extensions = {".safetensors", ".pt", ".bin"} # Common LoRA extensions + allowed_extensions = {".safetensors", ".pt", ".bin"} try: if os.path.isdir(settings.LORA_DIR): for filename in os.listdir(settings.LORA_DIR): @@ -590,14 +612,14 @@ async def list_loras(): _, ext = os.path.splitext(filename) if ext.lower() in allowed_extensions: lora_files.append(filename) - lora_files.sort() # Sort alphabetically + lora_files.sort() else: print(f"Warning: LORA_DIR '{settings.LORA_DIR}' is not a valid directory.") except Exception as e: print(f"Error listing LoRA files: {e}") # Return empty list on error, or raise HTTPException # raise HTTPException(status_code=500, detail=f"Failed to list LoRA files: {e}") - return LoraListResponse(loras=lora_files) # Correct indentation for return + return LoraListResponse(loras=lora_files) # === Video Streaming Endpoints === @@ -632,7 +654,7 @@ async def event_generator(): logging.error(f"Error in SSE generator: {e}") # Optionally send an error event to the client # yield f"event: error\ndata: {json.dumps({'message': 'Internal server error'})}\n\n" - break # Stop streaming on unexpected errors + break finally: # Cleanup when client disconnects or loop breaks if client_queue in sse_clients: @@ -689,4 +711,4 @@ async def list_videos(): # Configure logging for the main execution context as well logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') print(f"Starting Uvicorn server on {settings.API_HOST}:{settings.API_PORT}") - uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) # Use string import for reload + uvicorn.run("api.api:app", host=settings.API_HOST, port=settings.API_PORT, reload=True) diff --git a/api/models.py b/api/models.py index 25ae91d6..ee36a9c4 100644 --- a/api/models.py +++ b/api/models.py @@ -2,12 +2,12 @@ from diffusers import AutoencoderKLHunyuanVideo from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer, SiglipImageProcessor, SiglipVisionModel -from diffusers_helper.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller # Removed cpu, load_model_as_complete +from diffusers_helper.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked # Determine VRAM mode free_mem_gb = get_cuda_free_memory_gb(gpu) -high_vram = free_mem_gb > 60 # Threshold might need adjustment +high_vram = free_mem_gb > 60 print(f'Models: Free VRAM {free_mem_gb} GB') print(f'Models: High-VRAM Mode: {high_vram}') @@ -15,6 +15,7 @@ HUNYUAN_VIDEO_BASE = "hunyuanvideo-community/HunyuanVideo" FLUX_REDUX_BASE = "lllyasviel/flux_redux_bfl" FRAMEPACK_BASE = 'lllyasviel/FramePackI2V_HY' +FRAMEPACK_F1 = 'lllyasviel/FramePack_F1_I2V_HY_20250503' def load_models(): @@ -44,14 +45,17 @@ def load_models(): # Load Transformer print("Loading transformer...") - transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(FRAMEPACK_BASE, torch_dtype=torch.bfloat16).cpu() + transformer_base = HunyuanVideoTransformer3DModelPacked.from_pretrained(FRAMEPACK_BASE, torch_dtype=torch.bfloat16).cpu() + print("Loading F1 transformer...") + transformer_f1 = HunyuanVideoTransformer3DModelPacked.from_pretrained(FRAMEPACK_F1, torch_dtype=torch.bfloat16).cpu() # Set models to evaluation mode vae.eval() text_encoder.eval() text_encoder_2.eval() image_encoder.eval() - transformer.eval() + transformer_base.eval() + transformer_f1.eval() # Apply VRAM optimizations if needed if not high_vram: @@ -60,11 +64,13 @@ def load_models(): vae.enable_tiling() # Configure transformer output quality - transformer.high_quality_fp32_output_for_inference = True - print('transformer.high_quality_fp32_output_for_inference = True') + transformer_base.high_quality_fp32_output_for_inference = True + transformer_f1.high_quality_fp32_output_for_inference = True + print('transformer.high_quality_fp32_output_for_inference = True (for both models)') # Set model dtypes - transformer.to(dtype=torch.bfloat16) + transformer_base.to(dtype=torch.bfloat16) + transformer_f1.to(dtype=torch.bfloat16) vae.to(dtype=torch.float16) image_encoder.to(dtype=torch.float16) text_encoder.to(dtype=torch.float16) @@ -75,25 +81,27 @@ def load_models(): text_encoder.requires_grad_(False) text_encoder_2.requires_grad_(False) image_encoder.requires_grad_(False) - transformer.requires_grad_(False) + transformer_base.requires_grad_(False) + transformer_f1.requires_grad_(False) # LoRA loading moved to worker function # Move models to appropriate device based on VRAM if not high_vram: print("Installing DynamicSwap for low VRAM mode...") - # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but potentially faster - DynamicSwapInstaller.install_model(transformer, device=gpu) + DynamicSwapInstaller.install_model(transformer_base, device=gpu) + DynamicSwapInstaller.install_model(transformer_f1, device=gpu) DynamicSwapInstaller.install_model(text_encoder, device=gpu) # Note: VAE, text_encoder_2, image_encoder will be loaded/offloaded as needed by the worker in low VRAM mode - print("DynamicSwap installed for transformer and text_encoder.") + print("DynamicSwap installed for transformers and text_encoder.") else: print("Moving all models to GPU for high VRAM mode...") text_encoder.to(gpu) text_encoder_2.to(gpu) image_encoder.to(gpu) vae.to(gpu) - transformer.to(gpu) + transformer_base.to(gpu) + transformer_f1.to(gpu) print("All models moved to GPU.") print("Model loading complete.") @@ -106,8 +114,9 @@ def load_models(): "tokenizer_2": tokenizer_2, "feature_extractor": feature_extractor, "image_encoder": image_encoder, - "transformer": transformer, - "high_vram": high_vram # Include vram mode info + "transformer_base": transformer_base, + "transformer_f1": transformer_f1, + "high_vram": high_vram } @@ -116,7 +125,7 @@ def load_models(): # You might need to handle HF_HOME environment variable here if running standalone # os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), '../hf_download'))) - # parser = argparse.ArgumentParser() # LoRA path no longer needed here + # parser = argparse.ArgumentParser() # parser.add_argument("--lora", type=str, default=None, help="Lora path for testing") # args = parser.parse_args() @@ -124,9 +133,10 @@ def load_models(): loaded_models = load_models() print(f"Models loaded: {list(loaded_models.keys())}") print(f"High VRAM mode detected: {loaded_models['high_vram']}") - # Add more checks if needed, e.g., checking model devices + # Add more checks if needed if loaded_models['high_vram']: - print(f"Transformer device: {loaded_models['transformer'].device}") + print(f"Base Transformer device: {loaded_models['transformer_base'].device}") + print(f"F1 Transformer device: {loaded_models['transformer_f1'].device}") else: print("Running in low VRAM mode, models might be on CPU or dynamically swapped.") @@ -139,11 +149,10 @@ def unload_models(models_dict): Explicitly unloads models and releases resources, especially GPU memory. """ print("Unloading models...") - model_keys = ["vae", "text_encoder", "text_encoder_2", "image_encoder", "transformer"] + model_keys = ["vae", "text_encoder", "text_encoder_2", "image_encoder", "transformer_base", "transformer_f1"] for key in model_keys: if key in models_dict: try: - # Directly delete the reference from the dictionary del models_dict[key] print(f"Removed reference to model: {key}") except Exception as e: @@ -164,7 +173,7 @@ def unload_models(models_dict): for key in other_keys: if key in models_dict: try: - del models_dict[key] # Remove reference from dict + del models_dict[key] print(f"Removed reference to: {key}") except Exception as e: print(f"Error removing reference to {key}: {e}") diff --git a/api/queue_manager.py b/api/queue_manager.py index d8063b94..f29791f9 100644 --- a/api/queue_manager.py +++ b/api/queue_manager.py @@ -44,6 +44,9 @@ class QueuedJob: progress_info: str = "" lora_scale: float = 1.0 # 追加: LoRA強度 lora_path: Optional[str] = None + # Add sampling mode and transformer model selection + sampling_mode: str = "reverse" # "reverse" or "forward" + transformer_model: str = "base" # "base" or "f1" # Add updated_at timestamp, default to current UTC time updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) # Add field for original Exif data (bytes) - will not be saved in JSON @@ -76,6 +79,8 @@ def to_dict(self): 'progress_info': self.progress_info, 'lora_scale': self.lora_scale, # 追加 'lora_path': self.lora_path, + 'sampling_mode': self.sampling_mode, # Add sampling_mode + 'transformer_model': self.transformer_model, # Add transformer_model 'updated_at': updated_at_iso, # Add updated_at } except Exception as e: @@ -126,6 +131,8 @@ def from_dict(cls, data): progress_info=data.get('progress_info', ''), lora_scale=data.get('lora_scale', 1.0), # 追加 lora_path=data.get('lora_path', None), + sampling_mode=data.get('sampling_mode', 'reverse'), # Add sampling_mode with default + transformer_model=data.get('transformer_model', 'base'), # Add transformer_model with default updated_at=updated_at_dt # Add updated_at ) except Exception as e: @@ -216,7 +223,7 @@ def save_image_to_temp(image: np.ndarray, job_id: str, prompt: str, seed: int, e return "" -def add_to_queue(prompt, image, original_exif: Optional[bytes], video_length, seed, use_teacache, gpu_memory_preservation, steps, cfg, gs, rs, status="pending", mp4_crf=16, lora_scale: float = 1.0, lora_path: Optional[str] = None): +def add_to_queue(prompt, image, original_exif: Optional[bytes], video_length, seed, use_teacache, gpu_memory_preservation, steps, cfg, gs, rs, status="pending", mp4_crf=16, lora_scale: float = 1.0, lora_path: Optional[str] = None, sampling_mode: str = "reverse", transformer_model: str = "base"): global job_queue try: # Generate a unique hex ID for the job @@ -245,6 +252,8 @@ def add_to_queue(prompt, image, original_exif: Optional[bytes], video_length, se mp4_crf=mp4_crf, lora_scale=lora_scale, lora_path=lora_path, + sampling_mode=sampling_mode, # Add sampling_mode + transformer_model=transformer_model, # Add transformer_model original_exif=original_exif # Store exif in job object (won't be saved to JSON) ) job_queue.append(job) diff --git a/api/worker.py b/api/worker.py index d10b3fab..d104fec3 100644 --- a/api/worker.py +++ b/api/worker.py @@ -1,14 +1,14 @@ import os import torch import numpy as np -from pathlib import Path # 追加 -import logging # 追加 -from PIL import Image # Removed ImageDraw, ImageFont -# from PIL.PngImagePlugin import PngInfo # No longer needed for JPEG saving +from pathlib import Path +import logging +from PIL import Image +# from PIL.PngImagePlugin import PngInfo import traceback -import base64 # Add base64 -import io # Add io -import einops # Add einops +import base64 +import io +import einops # Assuming models and tokenizers are loaded elsewhere and passed or accessed globally/via context # This will be refined when creating models.py and integrating @@ -19,7 +19,7 @@ encode_prompt_conds, vae_decode, vae_encode, - vae_decode_fake, # Add vae_decode_fake + vae_decode_fake, ) from diffusers_helper.utils import ( @@ -44,17 +44,17 @@ # Assuming queue_manager is available for status updates (relative import) from . import queue_manager -from .queue_manager import update_job_progress # Import the specific function -from . import settings # Import settings to get OUTPUTS_DIR -from diffusers_helper.load_lora import load_lora # 追加 +from .queue_manager import update_job_progress +from . import settings +from diffusers_helper.load_lora import load_lora # Define output folder using settings outputs_folder = settings.OUTPUTS_DIR -# os.makedirs(outputs_folder, exist_ok=True) # Directory creation handled in settings.py +# os.makedirs(outputs_folder, exist_ok=True) # Determine VRAM mode - consider moving to settings.py or detecting dynamically free_mem_gb = get_cuda_free_memory_gb(gpu) -high_vram = free_mem_gb > 60 # Threshold might need adjustment +high_vram = free_mem_gb > 60 print(f"Worker: Free VRAM {free_mem_gb} GB") print(f"Worker: High-VRAM Mode: {high_vram}") @@ -67,13 +67,13 @@ def worker(job: queue_manager.QueuedJob, models: dict): job (QueuedJob): The job object containing parameters. models (dict): A dictionary containing the loaded models and tokenizers. Expected keys: 'vae', 'text_encoder', 'text_encoder_2', - 'image_encoder', 'transformer', 'tokenizer', - 'tokenizer_2', 'feature_extractor'. + 'image_encoder', 'transformer_base', 'transformer_f1', + 'tokenizer', 'tokenizer_2', 'feature_extractor'. """ input_image_path = job.image_path prompt = job.prompt - # n_prompt = job.n_prompt # Assuming negative prompt might be added to QueuedJob - n_prompt = "" # Default negative prompt + # n_prompt = job.n_prompt + n_prompt = "" seed = job.seed total_second_length = job.video_length latent_window_size = 9 @@ -87,9 +87,13 @@ def worker(job: queue_manager.QueuedJob, models: dict): job_id = job.job_id lora_scale = job.lora_scale lora_path = job.lora_path - original_exif = job.original_exif # Get Exif data from job object + original_exif = job.original_exif + # --- Get sampling mode and transformer model from job --- + sampling_mode = job.sampling_mode + transformer_model_name = job.transformer_model + print(f"Job {job_id}: Sampling Mode='{sampling_mode}', Transformer Model='{transformer_model_name}'") - thumbnail_path = None # Initialize thumbnail_path + thumbnail_path = None # Update job status to processing, including the thumbnail path (will be updated again if thumbnail generated) # We update here initially in case thumbnail generation fails later @@ -100,7 +104,14 @@ def worker(job: queue_manager.QueuedJob, models: dict): text_encoder = models["text_encoder"] text_encoder_2 = models["text_encoder_2"] image_encoder = models["image_encoder"] - transformer = models["transformer"] + # --- Select the correct transformer model --- + if transformer_model_name == "f1": + transformer = models["transformer_f1"] + print(f"Job {job_id}: Using F1 Transformer model.") + else: + transformer = models["transformer_base"] + print(f"Job {job_id}: Using Base Transformer model.") + # --- End Transformer Selection --- tokenizer = models["tokenizer"] tokenizer_2 = models["tokenizer_2"] feature_extractor = models["feature_extractor"] @@ -115,39 +126,42 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = job_id=job_id, progress=percentage, step=current_step, - total=total_steps if total_steps > 0 else steps, # Use overall steps if section total is 0 + total=total_steps if total_steps > 0 else steps, info=step_info ) except Exception as e: print(f"Error updating job progress for {job_id}: {e}") # Decide if this error should halt the process or just be logged - update_progress("Starting ...", 0, 0, steps) # Initial progress + update_progress("Starting ...", 0, 0, steps) + + # Initialize history_pixels before the main try block + history_pixels = None try: - # Load input image + # Load input image try-except try: - pil_input_image = Image.open(input_image_path) # Keep this line to load the image - # logging.info(f"[Job {job_id}] Exif in input_image after open: {pil_input_image.info.get('exif') is not None}") # DEBUG: Removed + pil_input_image = Image.open(input_image_path) + # logging.info(f"[Job {job_id}] Exif in input_image after open: {pil_input_image.info.get('exif') is not None}") # --- Thumbnail Generation (Moved Here) --- try: # Generate thumbnail from the loaded input image (pil_input_image) - thumb_size = (128, 128) # Define thumbnail size (adjust as needed) + thumb_size = (128, 128) thumb_img = pil_input_image.copy() thumb_img.thumbnail(thumb_size, Image.Resampling.LANCZOS) thumbnail_filename = f"thumb_{job_id}.jpg" # Use previously initialized thumbnail_path variable thumbnail_path = os.path.join(settings.TEMP_QUEUE_IMAGES_DIR, thumbnail_filename) - thumb_img.save(thumbnail_path, "JPEG", quality=85) # Save as JPEG + thumb_img.save(thumbnail_path, "JPEG", quality=85) print(f"Job {job_id}: Thumbnail saved to {thumbnail_path}") # Update job status again with the actual thumbnail path queue_manager.update_job_status(job_id, "processing", thumbnail=thumbnail_path) except Exception as thumb_e: print(f"Job {job_id}: Warning - Failed to generate thumbnail: {thumb_e}") - thumbnail_path = None # Ensure path is None if generation fails (already set initially) + thumbnail_path = None - # input_image = np.array(pil_input_image) # Moved numpy conversion later + # input_image = np.array(pil_input_image) except FileNotFoundError: print(f"Error: Input image not found at {input_image_path}") @@ -163,7 +177,7 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = # Clean GPU if not high_vram if not high_vram: - update_progress("Cleaning GPU memory...", 1, 0, steps) # Small progress increment + update_progress("Cleaning GPU memory...", 1, 0, steps) unload_complete_models( text_encoder, text_encoder_2, image_encoder, vae, transformer ) @@ -210,7 +224,7 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = # Fallback to zeros if negative encoding fails but cfg != 1 if ( llama_vec is None or clip_l_pooler is None - ): # Should not happen due to earlier check, but safety first + ): print( f"Error: Cannot create negative embeddings because positive embeddings are None for job {job_id}" ) @@ -234,7 +248,6 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = ) # --- LoRA Loading (moved here, after text encoding, before image processing) --- - # --- LoRA Loading --- # Construct full path if lora_path is provided (assumed to be filename from /loras endpoint) full_lora_path = None if lora_path: @@ -247,7 +260,7 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = if full_lora_path and os.path.exists(full_lora_path): print(f"Job {job_id}: Loading LoRA from: {full_lora_path} with scale {lora_scale}") - update_progress(f"Loading LoRA '{lora_path}' (scale={lora_scale})...", 11, 0, steps) # Progress update with filename + update_progress(f"Loading LoRA '{lora_path}' (scale={lora_scale})...", 11, 0, steps) try: # load_lora expects directory and filename separately lora_dir, lora_name = os.path.split(full_lora_path) @@ -256,19 +269,18 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = print(f"Job {job_id}: LoRA loaded successfully.") except Exception as e: print(f"Job {job_id}: Error loading LoRA: {e}") - # LoRAロード失敗時の処理 (例: ログ出力して続行、ジョブを失敗させるなど) # queue_manager.update_job_status(job_id, f"failed - LoRA load error: {e}") # return - elif lora_path: # Only print warning if lora_path was provided but file not found + elif lora_path: print(f"Job {job_id}: Warning - LoRA path '{lora_path}' specified but file not found at '{full_lora_path}'.") else: print(f"Job {job_id}: No LoRA path specified, skipping LoRA loading.") # --- End LoRA Loading --- # Processing input image (Convert to numpy array here if needed for processing) - input_image = np.array(pil_input_image) # Convert PIL image to numpy array now - update_progress("Image processing ...", 12, 0, steps) # Progress update (percentage adjusted) - input_image = np.squeeze(input_image) # Ensure 3D + input_image = np.array(pil_input_image) + update_progress("Image processing ...", 12, 0, steps) + input_image = np.squeeze(input_image) if input_image.ndim != 3 or input_image.shape[2] != 3: print(f"Error: Invalid image shape {input_image.shape} for job {job_id}") queue_manager.update_job_status(job_id, "failed - invalid image shape") @@ -277,7 +289,7 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = H, W, C = input_image.shape height, width = find_nearest_bucket( H, W, resolution=640 - ) # Assuming default resolution + ) input_image_np = resize_and_center_crop( input_image, target_width=width, target_height=height ) @@ -291,7 +303,7 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = # Prepare save arguments for JPEG with Exif save_kwargs = { "format": "JPEG", - "quality": 70, # Lower quality for smaller file size + "quality": 70, } # Add exif data if it exists (retrieved from the job object) if original_exif: @@ -302,7 +314,7 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = # Save processed input image as JPEG with or without Exif processed_pil_image.save(processed_input_image_path, **save_kwargs) - # logging.info(f"[Job {job_id}] Saved processed input image to {processed_input_image_path} (JPEG)") # DEBUG: Removed + # logging.info(f"[Job {job_id}] Saved processed input image to {processed_input_image_path} (JPEG)") input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1 input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None] @@ -333,250 +345,390 @@ def update_progress(step_info: str, percentage: float = 0.0, current_step: int = transformer.dtype ) - # Sampling + # Sampling Logic based on sampling_mode update_progress("Start sampling ...", 25, 0, steps) rnd = torch.Generator("cpu").manual_seed(seed) - num_frames = latent_window_size * 4 - 3 - - history_latents = torch.zeros( - size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32 - ).cpu() - history_pixels = None - total_generated_latent_frames = 0 - latent_paddings = reversed(range(total_latent_sections)) - if total_latent_sections > 4: - latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0] - else: - latent_paddings = list(latent_paddings) + # ============================================================== + # === Forward Sampling Mode (like demo_gradio_f1.py) === + # ============================================================== + if sampling_mode == "forward": + print(f"Job {job_id}: Entering FORWARD sampling mode.") + history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu() + # history_pixels initialized before try block + + history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2) + total_generated_latent_frames = 1 + + sampling_step_count = total_latent_sections + current_sampling_step = 0 + + for section_index in range(total_latent_sections): + current_sampling_step += 1 + section_progress_start = 25 + (current_sampling_step - 1) * (70 / sampling_step_count) + section_progress_end = 25 + current_sampling_step * (70 / sampling_step_count) + + # Check for cancellation signal + current_job_status_section_start = queue_manager.get_job_by_id(job_id) + if current_job_status_section_start and current_job_status_section_start.status == "cancelled": + print(f"Job {job_id} cancellation detected at start of forward section {current_sampling_step}.") + # Update status and exit worker function + queue_manager.update_job_status(job_id, "cancelled") + return - sampling_step_count = len(latent_paddings) - current_sampling_step = 0 + print(f'Job {job_id}: Forward section_index = {section_index}, total_latent_sections = {total_latent_sections}') + update_progress( + f"Sampling forward section {current_sampling_step}/{sampling_step_count}", + section_progress_start, 0, steps + ) - for latent_padding in latent_paddings: - current_sampling_step += 1 - section_progress_start = 25 + (current_sampling_step - 1) * ( - 70 / sampling_step_count - ) - section_progress_end = 25 + current_sampling_step * ( - 70 / sampling_step_count - ) + if not high_vram: + unload_complete_models() + move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation) + + if use_teacache: + transformer.initialize_teacache(enable_teacache=True, num_steps=steps) + else: + transformer.initialize_teacache(enable_teacache=False) + + # K-Diffusion Sampling Callback (Forward Mode) + def callback_forward(d): + current_cb_step = d['i'] + 1 + total_cb_steps = steps + section_progress_fraction = current_cb_step / total_cb_steps + overall_sampling_progress = section_progress_fraction * (70 / sampling_step_count) + overall_percentage = section_progress_start + overall_sampling_progress + + hint = f'Sampling forward section {current_sampling_step}/{sampling_step_count} - Step {current_cb_step}/{total_cb_steps}' + print(f"Job {job_id} Progress: {hint} ({overall_percentage:.1f}%)") + update_job_progress(job_id=job_id, progress=overall_percentage, step=current_cb_step, total=total_cb_steps, info=hint) + + # Preview Generation (same as reverse mode callback) + try: + if current_cb_step % 2 == 0 or current_cb_step == total_cb_steps: + preview_latent = d['denoised'] + preview_tensor = vae_decode_fake(preview_latent) + preview_np = (preview_tensor * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8) + preview_np_rearranged = einops.rearrange(preview_np, 'b c t h w -> (b h) (t w) c') + preview_image = Image.fromarray(preview_np_rearranged) + buffer = io.BytesIO() + preview_image.save(buffer, format="JPEG", quality=75) + buffer.seek(0) + preview_base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8') + preview_base64_string = f"data:image/jpeg;base64,{preview_base64_data}" + queue_manager.update_current_preview(job_id, preview_base64_string) + except Exception as preview_e: + logging.warning(f"Job {job_id}: Error generating preview at forward step {current_cb_step}: {preview_e}") + + # Cancellation Check + current_job_status_inner = queue_manager.get_job_by_id(job_id) + if current_job_status_inner and current_job_status_inner.status == "cancelled": + print(f"Job {job_id} cancelled during forward sampling step {current_cb_step}.") + raise InterruptedError("Job cancelled") + + # Prepare arguments for sample_hunyuan (Forward Mode) + indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0) + clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1) + + clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2) + clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2) - # Check for cancellation signal at the beginning of each section - current_job_status_section_start = queue_manager.get_job_by_id(job_id) - if current_job_status_section_start and current_job_status_section_start.status == "cancelled": - print(f"Job {job_id} cancellation detected at start of section {current_sampling_step}.") - return + try: + generated_latents = sample_hunyuan( + transformer=transformer, + sampler='unipc', + width=width, + height=height, + frames=latent_window_size * 4 - 3, + real_guidance_scale=cfg, + distilled_guidance_scale=gs, + guidance_rescale=rs, + num_inference_steps=steps, + generator=rnd, + prompt_embeds=llama_vec, + prompt_embeds_mask=llama_attention_mask, + prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, + negative_prompt_embeds_mask=llama_attention_mask_n, + negative_prompt_poolers=clip_l_pooler_n, + device=gpu, + dtype=torch.bfloat16, + image_embeddings=image_encoder_last_hidden_state, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + callback=callback_forward, + ) + except InterruptedError: + print(f"Job {job_id}: Cancellation detected during forward sampling.") + # Update status and exit worker function + queue_manager.update_job_status(job_id, "cancelled") + return - is_last_section = latent_padding == 0 - latent_padding_size = latent_padding * latent_window_size + # Update history (Forward Mode) + total_generated_latent_frames += int(generated_latents.shape[2]) + history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2) - print( - f"Job {job_id}: latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}" - ) - # Update progress for the start of the section - update_progress( - f"Sampling section {current_sampling_step}/{sampling_step_count}", - section_progress_start, - 0, # Step count resets for the section's callback - steps # Total steps for this section remains the overall steps parameter - ) + # Decode and append frames (Forward Mode) + update_progress( + f"VAE decoding forward section {current_sampling_step}/{sampling_step_count}", + section_progress_end - 1, 0, steps + ) + if not high_vram: + unload_complete_models() + load_model_as_complete(vae, target_device=gpu) + + real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :] + + if history_pixels is None: + # This should only happen on the very first section, but the logic is slightly different + # demo_gradio_f1 starts history_pixels after the first sampling loop completes. + # We decode the *entire* history up to this point. + history_pixels = vae_decode(real_history_latents, vae).cpu() + else: + # Decode the newly generated section and append + section_latent_frames = latent_window_size * 2 + overlapped_frames = latent_window_size * 4 - 3 + + # Decode only the relevant part of the history for the current section + # We need the last `section_latent_frames` from the *generated* latents part + current_pixels = vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu() + history_pixels = soft_append_bcthw(history_pixels, current_pixels, overlapped_frames) + + print(f'Job {job_id}: Decoded forward. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}') + + # Save intermediate result (optional, can be useful for debugging) + # intermediate_filename = os.path.join(outputs_folder, f'{job_id}_forward_part_{current_sampling_step}.mp4') + # save_bcthw_as_mp4(history_pixels, intermediate_filename, crf=mp4_crf, fps=30) + + # ============================================================== + # === Reverse Sampling Mode (Original Logic) === + # ============================================================== + else: + print(f"Job {job_id}: Entering REVERSE sampling mode (default).") + # Correct initialization for reverse mode + history_latents = torch.zeros( + size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32 + ).cpu() + # history_pixels initialized before try block + total_generated_latent_frames = 0 + + latent_paddings = reversed(range(total_latent_sections)) + if total_latent_sections > 4: + latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0] + else: + latent_paddings = list(latent_paddings) + + sampling_step_count = len(latent_paddings) + current_sampling_step = 0 + + for latent_padding in latent_paddings: + current_sampling_step += 1 + section_progress_start = 25 + (current_sampling_step - 1) * (70 / sampling_step_count) + section_progress_end = 25 + current_sampling_step * (70 / sampling_step_count) + + # Check for cancellation signal at the beginning of each section + current_job_status_section_start = queue_manager.get_job_by_id(job_id) + if current_job_status_section_start and current_job_status_section_start.status == "cancelled": + print(f"Job {job_id} cancellation detected at start of reverse section {current_sampling_step}.") + # Update status and exit worker function + queue_manager.update_job_status(job_id, "cancelled") + return - indices = torch.arange( - 0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16]) - ).unsqueeze(0) - ( - clean_latent_indices_pre, - blank_indices, - latent_indices, - clean_latent_indices_post, - clean_latent_2x_indices, - clean_latent_4x_indices, - ) = indices.split( - [1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1 - ) - clean_latent_indices = torch.cat( - [clean_latent_indices_pre, clean_latent_indices_post], dim=1 - ) + is_last_section = latent_padding == 0 + latent_padding_size = latent_padding * latent_window_size - clean_latents_pre = start_latent.to(history_latents) - clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[ - :, :, : 1 + 2 + 16, :, : - ].split([1, 2, 16], dim=2) - clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) - - if not high_vram: - unload_complete_models() - move_model_to_device_with_memory_preservation( - transformer, - target_device=gpu, - preserved_memory_gb=gpu_memory_preservation, + print( + f"Job {job_id}: latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}" + ) + # Update progress for the start of the section + update_progress( + f"Sampling reverse section {current_sampling_step}/{sampling_step_count}", + section_progress_start, + 0, + steps ) - if use_teacache: - transformer.initialize_teacache(enable_teacache=True, num_steps=steps) - else: - transformer.initialize_teacache(enable_teacache=False) - - # K-Diffusion Sampling Callback - def callback(d): - # --- Progress Update --- - current_cb_step = d['i'] + 1 # 1-based step for the current section - total_cb_steps = steps # Total steps for this section - - # Calculate overall progress percentage - # Sampling spans from 25% to 95% (70% total) - section_progress_fraction = current_cb_step / total_cb_steps - overall_sampling_progress = section_progress_fraction * (70 / sampling_step_count) - overall_percentage = section_progress_start + overall_sampling_progress - - hint = f'Sampling section {current_sampling_step}/{sampling_step_count} - Step {current_cb_step}/{total_cb_steps}' - print(f"Job {job_id} Progress: {hint} ({overall_percentage:.1f}%)") - - # Update progress via queue_manager - update_job_progress( - job_id=job_id, - progress=overall_percentage, - step=current_cb_step, - total=total_cb_steps, - info=hint + indices = torch.arange( + 0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16]) + ).unsqueeze(0) + ( + clean_latent_indices_pre, + blank_indices, + latent_indices, + clean_latent_indices_post, + clean_latent_2x_indices, + clean_latent_4x_indices, + ) = indices.split( + [1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1 ) + clean_latent_indices = torch.cat( + [clean_latent_indices_pre, clean_latent_indices_post], dim=1 + ) + + clean_latents_pre = start_latent.to(history_latents) + clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[ + :, :, : 1 + 2 + 16, :, : + ].split([1, 2, 16], dim=2) + clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) + + if not high_vram: + unload_complete_models() + move_model_to_device_with_memory_preservation( + transformer, + target_device=gpu, + preserved_memory_gb=gpu_memory_preservation, + ) + + if use_teacache: + transformer.initialize_teacache(enable_teacache=True, num_steps=steps) + else: + transformer.initialize_teacache(enable_teacache=False) + + # K-Diffusion Sampling Callback (Reverse Mode) + def callback(d): + # --- Progress Update --- + current_cb_step = d['i'] + 1 + total_cb_steps = steps + + # Calculate overall progress percentage + section_progress_fraction = current_cb_step / total_cb_steps + overall_sampling_progress = section_progress_fraction * (70 / sampling_step_count) + overall_percentage = section_progress_start + overall_sampling_progress + + hint = f'Sampling reverse section {current_sampling_step}/{sampling_step_count} - Step {current_cb_step}/{total_cb_steps}' + print(f"Job {job_id} Progress: {hint} ({overall_percentage:.1f}%)") + + # Update progress via queue_manager + update_job_progress( + job_id=job_id, + progress=overall_percentage, + step=current_cb_step, + total=total_cb_steps, + info=hint + ) + + # --- Preview Generation --- + try: + if current_cb_step % 2 == 0 or current_cb_step == total_cb_steps: + preview_latent = d['denoised'] + preview_tensor = vae_decode_fake(preview_latent) # Use vae_decode_fake + + preview_np = (preview_tensor * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8) + preview_np_rearranged = einops.rearrange(preview_np, 'b c t h w -> (b h) (t w) c') + + preview_image = Image.fromarray(preview_np_rearranged) + + buffer = io.BytesIO() + preview_image.save(buffer, format="JPEG", quality=75) + buffer.seek(0) + + preview_base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8') + preview_base64_string = f"data:image/jpeg;base64,{preview_base64_data}" + + queue_manager.update_current_preview(job_id, preview_base64_string) + except Exception as preview_e: + logging.warning(f"Job {job_id}: Error generating preview at reverse step {current_cb_step}: {preview_e}") + + # --- Cancellation Check --- + current_job_status_inner = queue_manager.get_job_by_id(job_id) + if current_job_status_inner and current_job_status_inner.status == "cancelled": + print(f"Job {job_id} cancelled during reverse sampling step {current_cb_step}.") + raise InterruptedError("Job cancelled") - # --- Preview Generation --- try: - # Only generate preview every N steps or on the last step to reduce overhead - if current_cb_step % 2 == 0 or current_cb_step == total_cb_steps: - preview_latent = d['denoised'] - preview_tensor = vae_decode_fake(preview_latent) # Use vae_decode_fake - - # Convert tensor to PIL Image - preview_np = (preview_tensor * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8) - # Rearrange: b c t h w -> (b h) (t w) c (assuming single batch) - # vae_decode_fake likely returns B C T H W, where T might be > 1 - preview_np_rearranged = einops.rearrange(preview_np, 'b c t h w -> (b h) (t w) c') - - preview_image = Image.fromarray(preview_np_rearranged) - - # Save image to buffer as JPEG - buffer = io.BytesIO() - preview_image.save(buffer, format="JPEG", quality=75) # Adjust quality as needed - buffer.seek(0) - - # Encode to Base64 - preview_base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8') - preview_base64_string = f"data:image/jpeg;base64,{preview_base64_data}" - - # Update preview in queue_manager (in-memory) - queue_manager.update_current_preview(job_id, preview_base64_string) - # logging.debug(f"Job {job_id}: Updated preview at step {current_cb_step}") # Optional debug log - except Exception as preview_e: - # Log error but don't necessarily stop the whole process - logging.warning(f"Job {job_id}: Error generating preview at step {current_cb_step}: {preview_e}") - - # --- Cancellation Check --- - current_job_status_inner = queue_manager.get_job_by_id(job_id) # Use the function that reads the file - if current_job_status_inner and current_job_status_inner.status == "cancelled": - # Use current_cb_step which is defined in this scope - print(f"Job {job_id} cancelled during sampling step {current_cb_step}.") - raise InterruptedError( - "Job cancelled" - ) # Raise exception to stop sampling + num_frames = latent_window_size * 4 - 3 + generated_latents = sample_hunyuan( + transformer=transformer, + sampler='unipc', + width=width, + height=height, + frames=num_frames, + real_guidance_scale=cfg, + distilled_guidance_scale=gs, + guidance_rescale=rs, + num_inference_steps=steps, + generator=rnd, + prompt_embeds=llama_vec, + prompt_embeds_mask=llama_attention_mask, + prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, + negative_prompt_embeds_mask=llama_attention_mask_n, + negative_prompt_poolers=clip_l_pooler_n, + device=gpu, + dtype=torch.bfloat16, + image_embeddings=image_encoder_last_hidden_state, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + callback=callback, + ) + except InterruptedError: + print(f"Job {job_id}: Cancellation detected during reverse sampling.") + # Update status and exit worker function + queue_manager.update_job_status(job_id, "cancelled") + return - try: - # args_to_check = { ... } - # none_args = [name for name, val in args_to_check.items() if val is None] - # if none_args: - # error_msg = f"Error: The following arguments are None before calling sample_hunyuan: {', '.join(none_args)} for job {job_id}" - # print(error_msg) - # queue_manager.update_job_status(job_id, "failed - internal error (None arg)") - # return # Stop processing this section - - generated_latents = sample_hunyuan( - transformer=transformer, - sampler='unipc', - width=width, - height=height, - frames=num_frames, - real_guidance_scale=cfg, - distilled_guidance_scale=gs, - guidance_rescale=rs, - num_inference_steps=steps, - generator=rnd, - prompt_embeds=llama_vec, - prompt_embeds_mask=llama_attention_mask, - prompt_poolers=clip_l_pooler, - negative_prompt_embeds=llama_vec_n, - negative_prompt_embeds_mask=llama_attention_mask_n, - negative_prompt_poolers=clip_l_pooler_n, - device=gpu, - dtype=torch.bfloat16, - image_embeddings=image_encoder_last_hidden_state, - latent_indices=latent_indices, - clean_latents=clean_latents, - clean_latent_indices=clean_latent_indices, - clean_latents_2x=clean_latents_2x, - clean_latent_2x_indices=clean_latent_2x_indices, - clean_latents_4x=clean_latents_4x, - clean_latent_4x_indices=clean_latent_4x_indices, - callback=callback, - # seed=seed + current_sampling_step, - # positive_image_encoder_hidden_states=image_encoder_last_hidden_state, - # negative_image_encoder_hidden_states=torch.zeros_like(image_encoder_last_hidden_state), + # Update history (Reverse Mode) + if is_last_section: + generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2) + + total_generated_latent_frames += int(generated_latents.shape[2]) + history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2) + + # Decode and append frames (Reverse Mode) + update_progress( + f"VAE decoding reverse section {current_sampling_step}/{sampling_step_count}", + section_progress_end - 1, 0, steps ) - except InterruptedError: - # Job was cancelled during sampling via callback - return # Exit worker function - - # Update history - # Update total generated frames *before* updating history (moved from L446) - total_generated_latent_frames += int(generated_latents.shape[2]) - # Update history by concatenating (like demo_gradio.py L557) - history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2) - - # Decode and append frames - update_progress( - f"VAE decoding section {current_sampling_step}/{sampling_step_count}", - section_progress_end - 1, # Approximate percentage - 0, # Reset step count for this phase - steps # Use overall steps as total for this phase marker - ) - if not high_vram: - unload_complete_models() - load_model_as_complete(vae, target_device=gpu) + if not high_vram: + unload_complete_models() + load_model_as_complete(vae, target_device=gpu) - # Use the full history for decoding and appending (like demo_gradio.py L562-571) - real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :] + real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :] - if history_pixels is None: - # Decode the entire history for the first section - history_pixels = vae_decode(real_history_latents, vae).cpu() - else: - # Calculate frames for current section and overlap - # Note: demo_gradio.py L567 seems to have a potential off-by-one or logic mismatch - # compared to L553/L555. Using the logic from demo_gradio.py L567 & L570 for now. - section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) - overlapped_frames = latent_window_size * 4 - 3 # Same calculation as demo_gradio.py L568 + if history_pixels is None: + history_pixels = vae_decode(real_history_latents, vae).cpu() + else: + section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) + overlapped_frames = latent_window_size * 4 - 3 + + current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu() + history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames) - # Decode only the relevant part of the history for the current section - current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu() - # Append using the calculated overlap (like demo_gradio.py L571) - history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames) + print(f'Job {job_id}: Decoded reverse. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}') - # total_generated_latent_frames update moved before history_latents update (see L419 block) + # Save intermediate result (optional) + # intermediate_filename = os.path.join(outputs_folder, f'{job_id}_reverse_part_{current_sampling_step}.mp4') + # save_bcthw_as_mp4(history_pixels, intermediate_filename, crf=mp4_crf, fps=30) - # Save intermediate result (optional) - # intermediate_filename = os.path.join(outputs_folder, f'{job_id}_part_{current_sampling_step}.mp4') - # save_bcthw_as_mp4(history_pixels, intermediate_filename, crf=mp4_crf, frame_rate=30) + if is_last_section: + break - # Final save + # ============================================================== + # === Final Save (Common to both modes) === + # ============================================================== update_progress("Saving final video...", 98, 0, steps) final_filename = os.path.join(outputs_folder, f"{job_id}.mp4") - save_bcthw_as_mp4(history_pixels, final_filename, crf=mp4_crf, fps=30) # Use fps instead of frame_rate - - # Update job status and final progress - update_progress("Finished", 100, steps, steps) # Mark 100% progress - queue_manager.update_job_status(job_id, "completed") - print(f"Job {job_id} completed successfully. Output: {final_filename}") + if history_pixels is not None: + save_bcthw_as_mp4(history_pixels, final_filename, crf=mp4_crf, fps=30) + # Update job status and final progress + update_progress("Finished", 100, steps, steps) # Mark 100% progress + queue_manager.update_job_status(job_id, "completed") + print(f"Job {job_id} completed successfully. Output: {final_filename}") + else: + # This should not happen if sampling ran, but handle defensively + print(f"Error: history_pixels is None after sampling for job {job_id}. Cannot save video.") + queue_manager.update_job_status(job_id, "failed - internal error (no pixels)") except Exception as e: print(f"Error processing job {job_id}: {str(e)}") @@ -605,6 +757,5 @@ def callback(d): # Example usage (for testing purposes, would be called by a background task runner) if __name__ == "__main__": print("Worker module loaded. Contains the 'worker' function.") - # Add test code here if needed, e.g., creating a dummy job and models - # and calling worker(dummy_job, dummy_models) + # Add test code here if needed pass diff --git a/plan/api_sampling_mode_plan.md b/plan/api_sampling_mode_plan.md new file mode 100644 index 00000000..2ccbee42 --- /dev/null +++ b/plan/api_sampling_mode_plan.md @@ -0,0 +1,82 @@ +# APIでのビデオ生成モード選択計画 + +`api/` ディレクトリ内のコードを変更し、`demo_gradio_f1.py` と同様の生成方法(順方向サンプリングと対応するモデル)も選択できるようにする計画。 + +**目標:** API 経由でビデオ生成をリクエストする際に、従来の逆方向サンプリング (`base`) と、`f1` スタイルの順方向サンプリング (`forward`) を選択可能にする。 + +**具体的な変更点:** + +1. **モデルのロード (`api/models.py`):** + * `load_models` 関数で、既存の `transformer` (`lllyasviel/FramePackI2V_HY`) に加えて、`f1` スタイルの `transformer_f1` (`lllyasviel/FramePack_F1_I2V_HY_20250503`) もロードするように変更します。 + * ロードされたモデルは、区別できるキー(例: `'transformer_base'`, `'transformer_f1'`)で辞書に格納します。 + +2. **ジョブキュー (`api/queue_manager.py`):** + * `QueuedJob` データクラスに、以下のフィールドを追加します。 + * `sampling_mode: str` (値: `"reverse"` または `"forward"`, デフォルト: `"reverse"`) + * `transformer_model: str` (値: `"base"` または `"f1"`, デフォルト: `"base"`) + * `to_dict`, `from_dict` メソッドを更新し、新しいフィールドを含めます。 + * `add_to_queue` 関数の引数に `sampling_mode` と `transformer_model` を追加し、`QueuedJob` オブジェクト生成時にこれらの値を設定するようにします。 + +3. **API エンドポイント (`api/api.py`):** + * `/generate` エンドポイントの `Form` パラメータに以下を追加します。 + * `sampling_mode: str = Form("reverse", description="Sampling loop direction ('reverse' or 'forward').")` + * `transformer_model: str = Form("base", description="Transformer model to use ('base' or 'f1').")` + * `queue_manager.add_to_queue` を呼び出す際に、これらの新しいパラメータを渡します。 + +4. **ワーカーロジック (`api/worker.py`):** + * `worker` 関数の冒頭で、`job` オブジェクトから `sampling_mode` と `transformer_model` を取得します。 + * `transformer_model` の値に基づいて、`models` 辞書から適切な Transformer モデルを選択して使用します。 + * サンプリングループの部分を `if job.sampling_mode == "forward":` と `else:` で分岐させます。 + * **`forward` の場合:** `demo_gradio_f1.py` の L188-L287 のロジック(`history_latents` の初期化・更新、`sample_hunyuan` への引数準備、`vae_decode`, `soft_append_bcthw` の呼び出し)を実装します。 + * **`else` (`reverse`) の場合:** 現在の `api/worker.py` の L341-L575 のロジックを維持します。 + * `callback` 関数内のプログレス計算も、選択されたモードに応じて適切に表示されるように調整します(特にステップの進捗を示すテキスト)。 + +**処理フロー図 (Mermaid):** + +```mermaid +graph TD + A[API Request /generate] -- Job Params (prompt, image, sampling_mode, transformer_model...) --> B(api.py: generate_video); + B -- image_np, params --> C(queue_manager.py: add_to_queue); + C -- Creates QueuedJob (with mode/model) --> D(job_queue.json); + E(background_worker_task) -- Checks queue --> F(queue_manager.py: get_next_job); + F -- Reads job_queue.json --> G{Job Found?}; + G -- Yes --> H(Returns QueuedJob); + G -- No --> E; + H -- QueuedJob, loaded_models --> I(worker.py: worker); + I -- Gets transformer_model --> J{Select Transformer}; + J -- transformer_model == 'f1' --> K[Use Transformer F1]; + J -- else --> L[Use Transformer Base]; + I -- Gets sampling_mode --> M{Select Sampling Loop}; + M -- sampling_mode == 'forward' --> N[Forward Sampling Loop (f1 style)]; + M -- else --> O[Reverse Sampling Loop (base style)]; + K & N -- Use F1 Model & Forward Loop --> P(Prepare Args for F1); + L & O -- Use Base Model & Reverse Loop --> Q(Prepare Args for Base); + P -- Calls sample_hunyuan --> R(Generate Latents); + Q -- Calls sample_hunyuan --> R; + R -- Latents --> S(VAE Decode & Append); + S -- Pixels --> T(Save MP4); + T -- MP4 Path --> U(Update Job Status: completed); + I -- Updates Progress --> V(queue_manager.py: update_job_progress); + I -- Updates Status --> W(queue_manager.py: update_job_status); + + subgraph Model Loading [api/models.py] + direction LR + ML1[Load Base Transformer] + ML2[Load F1 Transformer] + ML3[Load Other Models (VAE, Encoders...)] + end + + subgraph Job Definition [api/queue_manager.py] + direction LR + JD1[QueuedJob Class] + JD1 -- Add --> JD2[sampling_mode] + JD1 -- Add --> JD3[transformer_model] + end + + subgraph API Endpoint [api/api.py] + direction LR + AE1[/generate Endpoint] + AE1 -- Add Form Param --> AE2[sampling_mode] + AE1 -- Add Form Param --> AE3[transformer_model] + end +``` diff --git a/tests/test_api.py b/tests/test_api.py index e3c8af64..67cdd8a1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,7 +8,7 @@ # import asyncio # Not directly used when only AsyncMock is needed from unittest.mock import mock_open from PIL import Image -from api import queue_manager, settings # settings をインポート +from api import queue_manager, settings # Assuming your FastAPI app instance is named 'app' in 'api/api.py' from api.api import app @@ -139,6 +139,10 @@ def test_generate_job_success(mocker, lora_path_param, lora_scale_param, expecte assert response_json["message"] == "Video generation job added to queue." # Verify that add_to_queue was called with the correct arguments + # Determine expected transformer model based on sampling mode (default is reverse -> base) + # In a real scenario with parameterized sampling_mode, this would need adjustment. + expected_transformer_model = "base" + mock_add_to_queue.assert_called_once_with( prompt=data["prompt"], image=mocker.ANY, @@ -154,6 +158,8 @@ def test_generate_job_success(mocker, lora_path_param, lora_scale_param, expecte mp4_crf=data["mp4_crf"], lora_scale=expected_lora_scale_in_queue, lora_path=expected_lora_path_in_queue, + sampling_mode="reverse", + transformer_model=expected_transformer_model, status="pending" )