diff --git a/lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py b/lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py index 5560c7cf2..eb95f7fab 100755 --- a/lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py +++ b/lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py @@ -218,6 +218,10 @@ def infer(self, text, image_list=None): drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in text] + token_lengths = [len(ids) for ids in self.tokenizer(txt, add_special_tokens=True)["input_ids"]] + max_token_len = max(token_lengths) + if max_token_len > self.tokenizer_max_length + drop_idx: + raise ValueError(f"Input text token length ({max_token_len - drop_idx}) exceeds the maximum allowed length ({self.tokenizer_max_length}). Please shorten the input text.") model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(AI_DEVICE) encoder_hidden_states = self.text_encoder( input_ids=model_inputs.input_ids, diff --git a/lightx2v/server/api/server.py b/lightx2v/server/api/server.py index ff1ac0ba5..7027b5a1a 100644 --- a/lightx2v/server/api/server.py +++ b/lightx2v/server/api/server.py @@ -145,7 +145,8 @@ async def _process_single_task(self, task_info: Any): except Exception as e: logger.exception(f"Task {task_id} processing failed: {str(e)}") - task_manager.fail_task(task_id, str(e)) + original_et = getattr(e, "original_error_type", "") or "" + task_manager.fail_task(task_id, str(e), error_type=original_et or None) finally: if lock_acquired: task_manager.release_processing_lock(task_id) diff --git a/lightx2v/server/api/tasks/image.py b/lightx2v/server/api/tasks/image.py index 5855aaac7..ad5d2e67a 100644 --- a/lightx2v/server/api/tasks/image.py +++ b/lightx2v/server/api/tasks/image.py @@ -34,7 +34,11 @@ async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_ raise
HTTPException(status_code=500, detail=f"Task completed but no in-memory image found: {task_id}") if status == TaskStatus.FAILED.value: - raise HTTPException(status_code=500, detail=task_status.get("error", "Task failed")) + error_type = task_status.get("error_type", "") + error_detail = task_status.get("error", "Task failed") + if error_type == "ValueError": + raise HTTPException(status_code=413, detail=error_detail) + raise HTTPException(status_code=500, detail=error_detail) if status == TaskStatus.CANCELLED.value: raise HTTPException(status_code=409, detail=task_status.get("error", "Task cancelled")) @@ -108,6 +112,8 @@ async def create_image_task(message: ImageTaskRequest): save_result_path=message.save_result_path, ) except RuntimeError as e: + if getattr(e, "original_error_type", "") == "ValueError": + raise HTTPException(status_code=413, detail=str(e)) raise HTTPException(status_code=503, detail=str(e)) except Exception as e: logger.error(f"Failed to create image task: {e}") @@ -168,6 +174,8 @@ async def create_image_task_sync( raise except RuntimeError as e: + if getattr(e, "original_error_type", "") == "ValueError": + raise HTTPException(status_code=413, detail=str(e)) raise HTTPException(status_code=503, detail=str(e)) except HTTPException: raise @@ -229,6 +237,8 @@ async def save_file_async(file: UploadFile, target_dir: Path) -> str: save_result_path=message.save_result_path, ) except RuntimeError as e: + if getattr(e, "original_error_type", "") == "ValueError": + raise HTTPException(status_code=413, detail=str(e)) raise HTTPException(status_code=503, detail=str(e)) except Exception as e: logger.error(f"Failed to create image form task: {e}") diff --git a/lightx2v/server/services/generation/base.py b/lightx2v/server/services/generation/base.py index ad5d93d62..471e2712d 100644 --- a/lightx2v/server/services/generation/base.py +++ b/lightx2v/server/services/generation/base.py @@ -190,7 +190,10 @@ async def generate_with_stop_event(self, message: Any, 
stop_event) -> Optional[A ) else: error_msg = result.get("error", "Inference failed") - raise RuntimeError(error_msg) + error_type = result.get("error_type", "") + exc = RuntimeError(error_msg) + exc.original_error_type = error_type + raise exc except Exception as e: logger.exception(f"Task {message.task_id} processing failed: {str(e)}") diff --git a/lightx2v/server/services/generation/image.py b/lightx2v/server/services/generation/image.py index abe82aa9f..885448693 100644 --- a/lightx2v/server/services/generation/image.py +++ b/lightx2v/server/services/generation/image.py @@ -78,7 +78,10 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A ) else: error_msg = result.get("error", "Inference failed") - raise RuntimeError(error_msg) + error_type = result.get("error_type", "") + exc = RuntimeError(error_msg) + exc.original_error_type = error_type + raise exc except Exception as e: logger.exception(f"Task {message.task_id} processing failed: {str(e)}") diff --git a/lightx2v/server/services/inference/worker.py b/lightx2v/server/services/inference/worker.py index 238ffd9a5..6f2688f78 100644 --- a/lightx2v/server/services/inference/worker.py +++ b/lightx2v/server/services/inference/worker.py @@ -68,6 +68,7 @@ def init(self, args) -> bool: async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]: has_error = False error_msg = "" + error_type = "" pipeline_return = None try: @@ -103,6 +104,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]: except Exception as e: has_error = True error_msg = str(e) + error_type = type(e).__name__ logger.exception(f"Rank {self.rank} inference failed: {error_msg}") if self.world_size > 1: @@ -114,6 +116,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]: "task_id": task_data.get("task_id", "unknown"), "status": "failed", "error": error_msg, + "error_type": error_type, "message": f"Inference failed: {error_msg}", } else: diff --git 
a/lightx2v/server/task_manager.py b/lightx2v/server/task_manager.py index 93b8a4e40..5c53ea096 100644 --- a/lightx2v/server/task_manager.py +++ b/lightx2v/server/task_manager.py @@ -27,6 +27,7 @@ class TaskInfo: start_time: datetime = field(default_factory=datetime.now) end_time: Optional[datetime] = None error: Optional[str] = None + error_type: Optional[str] = None save_result_path: Optional[str] = None result_png: Optional[bytes] = None stop_event: threading.Event = field(default_factory=threading.Event) @@ -97,7 +98,7 @@ def complete_task(self, task_id: str, save_result_path: Optional[str] = None, re self.completed_tasks += 1 self._emit_queue_metrics_unlocked() - def fail_task(self, task_id: str, error: str): + def fail_task(self, task_id: str, error: str, error_type: Optional[str] = None): with self._lock: if task_id not in self._tasks: logger.warning(f"Task {task_id} not found for failure") @@ -107,6 +108,7 @@ def fail_task(self, task_id: str, error: str): task.status = TaskStatus.FAILED task.end_time = datetime.now() task.error = error + task.error_type = error_type self.failed_tasks += 1 self._emit_queue_metrics_unlocked() @@ -154,7 +156,15 @@ def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: if not task: return None - return {"task_id": task.task_id, "status": task.status.value, "start_time": task.start_time, "end_time": task.end_time, "error": task.error, "save_result_path": task.save_result_path} + return { + "task_id": task.task_id, + "status": task.status.value, + "start_time": task.start_time, + "end_time": task.end_time, + "error": task.error, + "error_type": task.error_type or "", + "save_result_path": task.save_result_path, + } def get_all_tasks(self): with self._lock: