Skip to content

Commit fcbc174

Browse files
committed
Training
1 parent ea741a6 commit fcbc174

1 file changed

Lines changed: 21 additions & 7 deletions

File tree

squeez/training/train.py

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -47,14 +47,19 @@ def _prepare_text_tokenizer(model_name: str, tokenizer):
4747
logger.info("Extracting text tokenizer from VL processor")
4848
tokenizer = tokenizer.tokenizer
4949

50-
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
50+
im_end_token = "<|im_end|>"
51+
im_end_id = tokenizer.convert_tokens_to_ids(im_end_token)
5152
unk_id = getattr(tokenizer, "unk_token_id", None)
52-
if tokenizer.eos_token in {None, "<EOS_TOKEN>"} and im_end_id is not None and im_end_id != unk_id:
53-
tokenizer.eos_token = "<|im_end|>"
53+
if im_end_id is not None and im_end_id != unk_id:
54+
tokenizer.eos_token = im_end_token
5455
tokenizer.eos_token_id = im_end_id
56+
if hasattr(tokenizer, "init_kwargs"):
57+
tokenizer.init_kwargs["eos_token"] = im_end_token
5558

5659
if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
57-
tokenizer.chat_template = tokenizer.chat_template.replace("<EOS_TOKEN>", tokenizer.eos_token)
60+
tokenizer.chat_template = tokenizer.chat_template.replace(
61+
"<EOS_TOKEN>", tokenizer.eos_token
62+
)
5863

5964
if tokenizer.pad_token is None:
6065
tokenizer.pad_token = tokenizer.eos_token
@@ -66,15 +71,25 @@ def _prepare_text_tokenizer(model_name: str, tokenizer):
6671
f"Current eos_token={tokenizer.eos_token!r}, eos_token_id={tokenizer.eos_token_id!r}."
6772
)
6873

74+
logger.info(
75+
"Using tokenizer %s with eos_token=%r (id=%s), pad_token=%r (id=%s)",
76+
tokenizer.__class__.__name__,
77+
tokenizer.eos_token,
78+
tokenizer.eos_token_id,
79+
tokenizer.pad_token,
80+
tokenizer.pad_token_id,
81+
)
82+
6983
return tokenizer
7084

7185

7286
def train(args: argparse.Namespace):
7387
"""Run LoRA fine-tuning with Unsloth + SFTTrainer."""
88+
from unsloth import FastLanguageModel # noqa: I001
89+
from unsloth.chat_templates import train_on_responses_only
90+
7491
from datasets import Dataset
7592
from trl import SFTConfig, SFTTrainer
76-
from unsloth import FastLanguageModel
77-
from unsloth.chat_templates import train_on_responses_only
7893

7994
config = load_config(args.config)
8095

@@ -157,7 +172,6 @@ def train(args: argparse.Namespace):
157172
"report_to": "none",
158173
"seed": 42,
159174
"dataset_num_proc": 1,
160-
"eos_token": tokenizer.eos_token,
161175
}
162176
if eval_dataset:
163177
sft_config_kwargs["eval_strategy"] = "steps"

0 commit comments

Comments (0)