Skip to content

Commit d203934

Browse files
update qwen-image text token length error (#1042)
1 parent 2c6fc47 commit d203934

7 files changed

Lines changed: 40 additions & 6 deletions

File tree

lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ def infer(self, text, image_list=None):
218218
drop_idx = self.prompt_template_encode_start_idx
219219
txt = [template.format(e) for e in text]
220220

221+
token_lengths = [len(ids) for ids in self.tokenizer(txt, add_special_tokens=True)["input_ids"]]
222+
max_token_len = max(token_lengths)
223+
if max_token_len > self.tokenizer_max_length + drop_idx:
224+
raise ValueError(f"Input text token length ({max_token_len - drop_idx}) exceeds ({self.tokenizer_max_length}). Please shorten the input text.")
221225
model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(AI_DEVICE)
222226
encoder_hidden_states = self.text_encoder(
223227
input_ids=model_inputs.input_ids,

lightx2v/server/api/server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ async def _process_single_task(self, task_info: Any):
145145

146146
except Exception as e:
147147
logger.exception(f"Task {task_id} processing failed: {str(e)}")
148-
task_manager.fail_task(task_id, str(e))
148+
original_et = getattr(e, "original_error_type", "") or ""
149+
task_manager.fail_task(task_id, str(e), error_type=original_et or None)
149150
finally:
150151
if lock_acquired:
151152
task_manager.release_processing_lock(task_id)

lightx2v/server/api/tasks/image.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ async def _wait_task_and_stream_result(task_id: str, timeout_seconds: int, poll_
3434
raise HTTPException(status_code=500, detail=f"Task completed but no in-memory image found: {task_id}")
3535

3636
if status == TaskStatus.FAILED.value:
37-
raise HTTPException(status_code=500, detail=task_status.get("error", "Task failed"))
37+
error_type = task_status.get("error_type", "")
38+
error_detail = task_status.get("error", "Task failed")
39+
if error_type == "ValueError":
40+
raise HTTPException(status_code=413, detail=error_detail)
41+
raise HTTPException(status_code=500, detail=error_detail)
3842

3943
if status == TaskStatus.CANCELLED.value:
4044
raise HTTPException(status_code=409, detail=task_status.get("error", "Task cancelled"))
@@ -108,6 +112,8 @@ async def create_image_task(message: ImageTaskRequest):
108112
save_result_path=message.save_result_path,
109113
)
110114
except RuntimeError as e:
115+
if getattr(e, "original_error_type", "") == "ValueError":
116+
raise HTTPException(status_code=413, detail=str(e))
111117
raise HTTPException(status_code=503, detail=str(e))
112118
except Exception as e:
113119
logger.error(f"Failed to create image task: {e}")
@@ -168,6 +174,8 @@ async def create_image_task_sync(
168174
raise
169175

170176
except RuntimeError as e:
177+
if getattr(e, "original_error_type", "") == "ValueError":
178+
raise HTTPException(status_code=413, detail=str(e))
171179
raise HTTPException(status_code=503, detail=str(e))
172180
except HTTPException:
173181
raise
@@ -229,6 +237,8 @@ async def save_file_async(file: UploadFile, target_dir: Path) -> str:
229237
save_result_path=message.save_result_path,
230238
)
231239
except RuntimeError as e:
240+
if getattr(e, "original_error_type", "") == "ValueError":
241+
raise HTTPException(status_code=413, detail=str(e))
232242
raise HTTPException(status_code=503, detail=str(e))
233243
except Exception as e:
234244
logger.error(f"Failed to create image form task: {e}")

lightx2v/server/services/generation/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,10 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A
190190
)
191191
else:
192192
error_msg = result.get("error", "Inference failed")
193-
raise RuntimeError(error_msg)
193+
error_type = result.get("error_type", "")
194+
exc = RuntimeError(error_msg)
195+
exc.original_error_type = error_type
196+
raise exc
194197

195198
except Exception as e:
196199
logger.exception(f"Task {message.task_id} processing failed: {str(e)}")

lightx2v/server/services/generation/image.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[A
7878
)
7979
else:
8080
error_msg = result.get("error", "Inference failed")
81-
raise RuntimeError(error_msg)
81+
error_type = result.get("error_type", "")
82+
exc = RuntimeError(error_msg)
83+
exc.original_error_type = error_type
84+
raise exc
8285

8386
except Exception as e:
8487
logger.exception(f"Task {message.task_id} processing failed: {str(e)}")

lightx2v/server/services/inference/worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def init(self, args) -> bool:
6868
async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
6969
has_error = False
7070
error_msg = ""
71+
error_type = ""
7172
pipeline_return = None
7273

7374
try:
@@ -103,6 +104,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
103104
except Exception as e:
104105
has_error = True
105106
error_msg = str(e)
107+
error_type = type(e).__name__
106108
logger.exception(f"Rank {self.rank} inference failed: {error_msg}")
107109

108110
if self.world_size > 1:
@@ -114,6 +116,7 @@ async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
114116
"task_id": task_data.get("task_id", "unknown"),
115117
"status": "failed",
116118
"error": error_msg,
119+
"error_type": error_type,
117120
"message": f"Inference failed: {error_msg}",
118121
}
119122
else:

lightx2v/server/task_manager.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class TaskInfo:
2727
start_time: datetime = field(default_factory=datetime.now)
2828
end_time: Optional[datetime] = None
2929
error: Optional[str] = None
30+
error_type: Optional[str] = None
3031
save_result_path: Optional[str] = None
3132
result_png: Optional[bytes] = None
3233
stop_event: threading.Event = field(default_factory=threading.Event)
@@ -97,7 +98,7 @@ def complete_task(self, task_id: str, save_result_path: Optional[str] = None, re
9798
self.completed_tasks += 1
9899
self._emit_queue_metrics_unlocked()
99100

100-
def fail_task(self, task_id: str, error: str):
101+
def fail_task(self, task_id: str, error: str, error_type: Optional[str] = None):
101102
with self._lock:
102103
if task_id not in self._tasks:
103104
logger.warning(f"Task {task_id} not found for failure")
@@ -107,6 +108,7 @@ def fail_task(self, task_id: str, error: str):
107108
task.status = TaskStatus.FAILED
108109
task.end_time = datetime.now()
109110
task.error = error
111+
task.error_type = error_type
110112

111113
self.failed_tasks += 1
112114
self._emit_queue_metrics_unlocked()
@@ -154,7 +156,15 @@ def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
154156
if not task:
155157
return None
156158

157-
return {"task_id": task.task_id, "status": task.status.value, "start_time": task.start_time, "end_time": task.end_time, "error": task.error, "save_result_path": task.save_result_path}
159+
return {
160+
"task_id": task.task_id,
161+
"status": task.status.value,
162+
"start_time": task.start_time,
163+
"end_time": task.end_time,
164+
"error": task.error,
165+
"error_type": task.error_type or "",
166+
"save_result_path": task.save_result_path,
167+
}
158168

159169
def get_all_tasks(self):
160170
with self._lock:

0 commit comments

Comments
 (0)