Skip to content

Commit 5ef948b

Browse files
author
niushengxiao
committed
fix: fix an eos id bug
1 parent 630f7b8 commit 5ef948b

1 file changed

Lines changed: 22 additions & 2 deletions

File tree

lightllm/utils/config_utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,16 +238,36 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]:
238238

239239
# Qwen3.5 checkpoints can have an eos_token_id in config that differs from
240240
# tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable
241-
# stop id (<|im_end|>) for detokenization/stop behavior.
241+
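# stop id (<|im_end|>) for detokenization/stop behavior; additional stop ids (e.g. <|endoftext|>, presumably) are merged in from generation_config.json and the config's eos_token_id.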
242242
try:
243243
config_json = get_config_json(model_path)
244244
model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type")
245245
if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}:
246246
from transformers import AutoTokenizer
247247

248+
eos_token_ids = []
249+
248250
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)
249251
if tokenizer.eos_token_id is not None:
250-
return [int(tokenizer.eos_token_id)]
252+
eos_token_ids.append(int(tokenizer.eos_token_id))
253+
254+
generation_config_path = os.path.join(model_path, "generation_config.json")
255+
if os.path.exists(generation_config_path):
256+
with open(generation_config_path, "r") as file:
257+
generation_eos_token_id = json.load(file).get("eos_token_id")
258+
if isinstance(generation_eos_token_id, int):
259+
eos_token_ids.append(generation_eos_token_id)
260+
elif isinstance(generation_eos_token_id, list):
261+
eos_token_ids.extend(generation_eos_token_id)
262+
263+
config_eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"])
264+
if isinstance(config_eos_token_id, int):
265+
eos_token_ids.append(config_eos_token_id)
266+
elif isinstance(config_eos_token_id, list):
267+
eos_token_ids.extend(config_eos_token_id)
268+
269+
if eos_token_ids:
270+
return list(dict.fromkeys(int(eos_id) for eos_id in eos_token_ids))
251271
except Exception:
252272
pass
253273

0 commit comments

Comments
 (0)