Commit 891b764
Merge pull request #3288 from AI-Hypercomputer:vladk/sft-completion-fix2
PiperOrigin-RevId: 878653657
2 parents 12fe4ce + fc2dab7

3 files changed: 56 additions & 17 deletions

File: src/maxtext/input_pipeline/input_pipeline_utils.py
Lines changed: 38 additions & 8 deletions
@@ -32,6 +32,7 @@
 
 Features = dict[str, tf.Tensor]
 AUTOTUNE = tf.data.experimental.AUTOTUNE
+INPUT_TOKENS_KEY = "input_ids"
 
 ########## Functions used by TFDS pipeline
 
@@ -171,6 +172,42 @@ def is_conversational(features, data_columns):
   return False
 
 
+def _get_completion_in_chat_template(tokenizer_model, round_msgs):
+  """
+  Calculates the completion part of a conversation turn when formatted with a chat template.
+
+  This function handles both older and current Hugging Face tokenizers. Modern tokenizers
+  may return a `BatchEncoding` object instead of a simple list of token IDs.
+
+  Args:
+    tokenizer_model: The tokenizer instance.
+    round_msgs: A list of messages for the current conversational turn, including the assistant's response.
+
+  Returns:
+    A string representing the completion formatted by the chat template.
+  """
+  prompt_completion_tokens = tokenizer_model.apply_chat_template(round_msgs, add_generation_prompt=False, tokenize=True)
+  # include generation_prompt as part of the prompt tokens
+  prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)
+
+  # attention masks in BatchEncoding are effectively ignored
+  if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY):
+    prompt_completion_ids = getattr(prompt_completion_tokens, INPUT_TOKENS_KEY)
+    prompt_ids = getattr(prompt_tokens, INPUT_TOKENS_KEY)
+  elif isinstance(prompt_completion_tokens, dict) and INPUT_TOKENS_KEY in prompt_completion_tokens:
+    prompt_completion_ids = prompt_completion_tokens[INPUT_TOKENS_KEY]
+    prompt_ids = prompt_tokens[INPUT_TOKENS_KEY]
+  elif isinstance(prompt_completion_tokens, list):
+    prompt_completion_ids = prompt_completion_tokens
+    prompt_ids = prompt_tokens
+  else:
+    raise ValueError(f"Can't handle the chat template output of type {type(prompt_completion_tokens)}")
+
+  completion_tokens = prompt_completion_ids[len(prompt_ids) :]
+  completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
+  return completion_in_chat_template
+
+
 def apply_chat_template(example, tokenizer_model, data_column_name):
   """Formats conversational data by applying the tokenizer's chat template
   and identifying prompt/completion segments for SFT masking.
@@ -210,14 +247,7 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
       is_prompt.append(True)
     elif message["role"] == "assistant":
       round_msgs.append(message)
-      prompt_completion_tokens = tokenizer_model.apply_chat_template(
-          round_msgs, add_generation_prompt=False, tokenize=True
-      )
-      # include generation_prompt as part of the prompt tokens
-      prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)
-      completion_tokens = prompt_completion_tokens[len(prompt_tokens) :]
-      completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
-      messages.append(completion_in_chat_template)
+      messages.append(_get_completion_in_chat_template(tokenizer_model, round_msgs))
       is_prompt.append(False)
       # Round ended, clearing the buffer.
       round_msgs.clear()
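Why the new helper branches on the return type: `apply_chat_template(..., tokenize=True)` returned a plain list of token IDs in older Hugging Face tokenizers, while newer ones may return a dict-like `BatchEncoding` keyed by `input_ids`. In both cases the completion is recovered by rendering the turn twice and slicing off the prompt prefix. A minimal sketch of that slicing logic, using a hypothetical stub tokenizer (not the MaxText code or a real Hugging Face class):

INPUT_TOKENS_KEY = "input_ids"


class StubTokenizer:
  """Hypothetical stand-in: 10 tokens per message, plus 2 for a generation prompt."""

  def __init__(self, as_batch_encoding):
    self.as_batch_encoding = as_batch_encoding

  def apply_chat_template(self, msgs, add_generation_prompt=False, tokenize=True):
    ids = list(range(10 * len(msgs) + (2 if add_generation_prompt else 0)))
    # Newer tokenizers wrap the ids dict-style (as a BatchEncoding); older ones return a list.
    return {INPUT_TOKENS_KEY: ids} if self.as_batch_encoding else ids


def completion_ids(tok, round_msgs):
  full = tok.apply_chat_template(round_msgs, add_generation_prompt=False, tokenize=True)
  # As in the patch, the generation prompt counts as part of the prompt.
  prompt = tok.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)
  full_ids = full[INPUT_TOKENS_KEY] if isinstance(full, dict) else full
  prompt_ids = prompt[INPUT_TOKENS_KEY] if isinstance(prompt, dict) else prompt
  # Assumes the prompt render is a strict prefix of the full render.
  return full_ids[len(prompt_ids):]


msgs = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
assert completion_ids(StubTokenizer(False), msgs) == completion_ids(StubTokenizer(True), msgs) == list(range(12, 20))

Both return shapes yield the same completion slice, which is exactly what the type dispatch in `_get_completion_in_chat_template` guarantees.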
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-{"data": {"messages": [{"role": "user", "content": "Hello, what is your name?"}, {"role": "assistant", "content": "I am a chatbot. How can I help?"}]}, "tokens": [1, 518, 25580, 29962, 15043, 29892, 825, 338, 596, 1024, 29973, 518, 29914, 25580, 29962, 306, 626, 263, 13563, 7451, 29889, 1128, 508, 306, 1371, 29973, 29871, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "token_log_probs": [-10.900627136230469, -11.805438995361328, -9.937392234802246, -10.478547096252441, -10.477544784545898, -10.665718078613281, -11.027463912963867, -10.303316116333008, -10.548932075500488, -10.392480850219727, -11.593963623046875, -11.837165832519531, -12.416250228881836, -10.1104736328125, -11.313142776489258, -12.341060638427734, -11.190383911132812, -9.143855094909668, -10.817261695861816, -11.793390274047852, -11.39107894897461, -11.716558456420898, -11.232498168945312, -12.146818161010742, -11.292530059814453, -10.039775848388672, -9.972617149353027]}
+{"data": {"messages": [{"role": "user", "content": "Hello, what is your name?"}, {"role": "assistant", "content": "I am a chatbot. How can I help?"}]}, "tokens": [1, 29961, 25580, 29962, 15043, 29892, 825, 338, 596, 1024, 29973, 518, 29914, 25580, 29962, 306, 626, 263, 13563, 7451, 29889, 1128, 508, 306, 1371, 29973, 29871, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "token_log_probs": [-10.360702514648438, -11.012994766235352, -10.751636505126953, -9.73588752746582, -11.174783706665039, -11.906787872314453, -10.50442123413086, -11.422593116760254, -12.447595596313477, -10.885910034179688, -11.982933044433594, -10.058539390563965, -10.950790405273438, -12.060896873474121, -10.68459701538086, -11.916288375854492, -12.050270080566406, -9.983818054199219, -10.710721015930176, -9.216376304626465, -11.008810043334961, -9.728713989257812, -12.391929626464844, -11.235883712768555, -9.664995193481445, -11.548173904418945, -10.014203071594238]}

File: tests/assets/logits_generation/generate_sft_golden_data.py
Lines changed: 17 additions & 8 deletions
@@ -39,7 +39,7 @@
 from trl import SFTConfig, SFTTrainer
 
 from maxtext.configs import pyconfig
-from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT
+from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT, MAXTEXT_ASSETS_ROOT
 from tests.integration.sft_trainer_correctness_test import get_maxtext_logits, get_token_log_probs, prepare_maxtext_inputs
 
 
@@ -54,7 +54,7 @@
 def initialize_maxtext_config(config):
   """Initializes configuration for MaxText."""
   cfg_with_ckpt = pyconfig.initialize(
-      [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "sft.yml")],
+      [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "post_train", "sft.yml")],
       run_name="compare_maxtext_with_trl_logits",
       model_name=config.model_name,
       tokenizer_path=config.tokenizer_path,
@@ -70,7 +70,7 @@ def initialize_maxtext_config(config):
   )
 
   cfg_without_ckpt = pyconfig.initialize(
-      [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "sft.yml")],
+      [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "post_train", "sft.yml")],
       run_name="generate_sft_golden_data",
       model_name="default",
       enable_checkpointing=False,
@@ -85,10 +85,10 @@ def initialize_maxtext_config(config):
   return cfg_with_ckpt, cfg_without_ckpt
 
 
-def get_hf_model(tokenizer_path):
+def get_hf_model(model_path):
   """Load model from Hugging Face."""
   return AutoModelForCausalLM.from_pretrained(
-      tokenizer_path,
+      model_path,
       torch_dtype=torch.float32,
   )
 
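The rename makes explicit that the model weights and the tokenizer no longer have to come from the same path: after this change the tokenizer is a local asset while the weights still come from the hub. A sketch of the resulting split, mirroring the new defaults below (the local path is illustrative):

# Sketch: tokenizer from a local asset, weights from the Hugging Face hub.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("src/maxtext/assets/llama2-chat-tokenizer")  # illustrative path
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float32)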

@@ -116,7 +116,7 @@ def setup_sft_trainer(data, hf_model, tokenizer, max_target_length):
       data_collator=None,
       args=SFTConfig(
           dataset_kwargs={"skip_prepare_dataset": True},
-          max_seq_length=max_target_length,
+          max_length=max_target_length,
           **training_args.to_dict(),
       ),
   )
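The `max_seq_length` to `max_length` change tracks a rename in TRL's `SFTConfig`: recent TRL releases accept `max_length` for sequence truncation, while older ones used `max_seq_length`. A minimal sketch of the updated construction (the `output_dir` value is illustrative):

# Sketch: SFTConfig with the renamed truncation argument.
from trl import SFTConfig

args = SFTConfig(
    output_dir="/tmp/sft_golden_data",  # illustrative
    max_length=1024,                    # formerly `max_seq_length` in older TRL releases
    dataset_kwargs={"skip_prepare_dataset": True},
)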
@@ -143,7 +143,7 @@ def prepare_trl_inputs(tokenizer_path, max_target_length):
 
 def get_trl_logits(config, trl_data, max_target_length):
   """Get logits generated by TRL."""
-  hf_model = get_hf_model(config.tokenizer_path)
+  hf_model = get_hf_model(config.hf_model_path)
   tokenizer = get_tokenizer(config.tokenizer_path, max_target_length)
   trl_trainer = setup_sft_trainer(trl_data, hf_model, tokenizer, max_target_length)
   _, trl_outputs = trl_trainer.compute_loss(hf_model, trl_data, return_outputs=True)
@@ -199,7 +199,16 @@ def test_with_trl_and_save_golden_data(config):
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model-name", type=str, required=False, default="llama2-7b")
-parser.add_argument("--tokenizer-path", type=str, required=False, default="meta-llama/Llama-2-7b-chat-hf")
+
+# Reasons to use the local tokenizer:
+# 1. In transformers==5.2.0 (at least), the Llama-2 tokenizer incorrectly injects an extra space
+#    before the <s> and [INST] tokens when applying its chat template.
+# 2. Consistency with the tokenizer used by sft_trainer_correctness_test.py,
+#    which depends on the golden data generated here.
+parser.add_argument(
+    "--tokenizer-path", type=str, required=False, default=os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")
+)
+parser.add_argument("--hf-model-path", type=str, required=False, default="meta-llama/Llama-2-7b-chat-hf")
 parser.add_argument(
     "--model-ckpt-path",
     type=str,
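The injected-space bug described in the comment is easy to observe directly; a hedged sketch, assuming a local Llama-2 tokenizer copy (the path is illustrative, and the exact output depends on the installed transformers version):

# Sketch: render one user turn and inspect the prefix. A correct Llama-2 render
# begins "<s>[INST]"; the buggy behavior described above yields "<s> [INST]",
# which shifts the golden token IDs (518 vs. 29961 in the record above).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("src/maxtext/assets/llama2-chat-tokenizer")  # illustrative path
text = tok.apply_chat_template([{"role": "user", "content": "Hello"}], tokenize=False)
print(repr(text[:12]))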
