 from threading import current_thread
 from typing import Any, Iterable, TYPE_CHECKING
 
+from jinja2 import TemplateError
+
 if TYPE_CHECKING:
   import datasets
   import tensorflow as tf
@@ -178,6 +180,103 @@ def is_conversational(features, data_columns):
   return False
 
 
+def _extract_token_ids(tokens):
+  """Extracts token IDs from various tokenizer output formats.
+
+  This helper function standardizes the extraction of tokenized integer IDs
+  from common return types of Hugging Face tokenizers, including
+  `BatchEncoding` objects, dictionaries, and plain lists.
+
+  Args:
+    tokens: The object containing token IDs. Supported types include:
+      - A list of integers.
+      - A dictionary containing the `INPUT_TOKENS_KEY`.
+      - An object (e.g., `BatchEncoding`) with an attribute named `INPUT_TOKENS_KEY`.
+
+  Returns:
+    A list of integer token IDs.
+
+  Raises:
+    ValueError: If the input type is not supported or does not contain the expected key.
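+
+  Example:
+    A minimal illustrative sketch, assuming `INPUT_TOKENS_KEY == "input_ids"`
+    (the usual Hugging Face key; the constant is defined elsewhere in this module):
+
+    >>> _extract_token_ids([1, 2, 3])
+    [1, 2, 3]
+    >>> _extract_token_ids({"input_ids": [4, 5, 6]})
+    [4, 5, 6]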
201+ """
202+ # attention masks in BatchEncoding are effectively ignored
203+ if hasattr (tokens , INPUT_TOKENS_KEY ):
204+ return getattr (tokens , INPUT_TOKENS_KEY )
205+ elif isinstance (tokens , dict ) and INPUT_TOKENS_KEY in tokens :
206+ return tokens [INPUT_TOKENS_KEY ]
207+ elif isinstance (tokens , list ):
208+ return tokens
209+ else :
210+ raise ValueError (f"Can't extract token_ids from type { type (tokens )} " )
211+
212+
+def verify_chat_template_generation_prompt_logic(tokenizer_model):
+  """Verifies the tokenizer's chat template for correct SFT loss masking.
+
+  This function ensures that the tokens added by `add_generation_prompt=True`
+  are identical to the tokens that begin an assistant's turn in a complete
+  conversation, which is critical for masking prompt tokens during SFT loss
+  calculation.
+
+  Example of a mismatch:
+    A `ValueError` is raised if the generation prompt and the actual
+    assistant prefix do not match. For example:
+
+    - `add_generation_prompt=True` on a user message produces a prompt ending in:
+      `...<|im_start|>generation\n`
+    - A full turn with an assistant message starts the reply with:
+      `...<|im_start|>assistant\n...`
+
+    This function would fail because the tokens for "generation" do not
+    match the tokens for "assistant".
+
+  Args:
+    tokenizer_model: The Hugging Face tokenizer instance to verify.
+
+  Raises:
+    ValueError: If the `add_generation_prompt` tokens do not exactly
+      match the beginning of an assistant message in the template.
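+
+  Example:
+    A hypothetical usage sketch (the checkpoint name below is purely
+    illustrative):
+
+    >>> from transformers import AutoTokenizer
+    >>> tok = AutoTokenizer.from_pretrained("some-org/some-chat-model")
+    >>> verify_chat_template_generation_prompt_logic(tok)  # raises ValueError on mismatch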
239+ """
240+ dummy_msgs = [{"role" : "system" , "content" : "System message" }, {"role" : "user" , "content" : "Test message" }]
241+
242+ try :
243+ prompt_wo_gen_tokens = tokenizer_model .apply_chat_template (dummy_msgs , add_generation_prompt = False , tokenize = True )
244+ except TemplateError :
245+ max_logging .info (
246+ "Tokenizer failed to apply chat template with 'system' role. "
247+ "Falling back to 'user' role only for chat template verification."
248+ )
249+ dummy_msgs .pop (0 )
250+ prompt_wo_gen_tokens = tokenizer_model .apply_chat_template (dummy_msgs , add_generation_prompt = False , tokenize = True )
251+ prompt_wo_gen_ids = _extract_token_ids (prompt_wo_gen_tokens )
252+
253+ prompt_w_gen_tokens = tokenizer_model .apply_chat_template (dummy_msgs , add_generation_prompt = True , tokenize = True )
254+ prompt_w_gen_ids = _extract_token_ids (prompt_w_gen_tokens )
255+
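+  # Sanity check: tokenizing with the generation prompt must produce a strict
+  # extension of the tokens produced without it, otherwise the prefix cannot be isolated.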
+  if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
+    raise ValueError("Unable to extract generation prompt tokens.")
+  # Extract the tokenized generation prompt (the expected assistant prefix)
+  assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
+  full_turn_tokens = tokenizer_model.apply_chat_template(
+      dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
+  )
+  full_turn_ids = _extract_token_ids(full_turn_tokens)
+  # Extract the actual tokens that appear right after the user message in the full turn
+  actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]
+
+  if actual_prefix_in_full_turn != assistant_prefix:
+    expected_str = tokenizer_model.decode(assistant_prefix)
+    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
+    raise ValueError(
+        "Chat template generation prompt mismatch!\n"
+        f"Expected assistant prefix tokens: {assistant_prefix} ('{expected_str}')\n"
+        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
+        "This means the tokenizer's chat template will break the SFT masking logic."
+    )
+
+
 def _get_completion_in_chat_template(tokenizer_model, round_msgs):
   """
   Calculates the completion part of a conversation turn when formatted with a chat template.
@@ -196,18 +295,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
   # include generation_prompt as part of the prompt tokens
   prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)
 
-  # attention masks in BatchEncoding are effectively ignored
-  if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY):
-    prompt_completion_ids = getattr(prompt_completion_tokens, INPUT_TOKENS_KEY)
-    prompt_ids = getattr(prompt_tokens, INPUT_TOKENS_KEY)
-  elif isinstance(prompt_completion_tokens, dict) and INPUT_TOKENS_KEY in prompt_completion_tokens:
-    prompt_completion_ids = prompt_completion_tokens[INPUT_TOKENS_KEY]
-    prompt_ids = prompt_tokens[INPUT_TOKENS_KEY]
-  elif isinstance(prompt_completion_tokens, list):
-    prompt_completion_ids = prompt_completion_tokens
-    prompt_ids = prompt_tokens
-  else:
-    raise ValueError(f"Can't handle the chat template output of type {type(prompt_completion_tokens)}")
+  prompt_completion_ids = _extract_token_ids(prompt_completion_tokens)
+  prompt_ids = _extract_token_ids(prompt_tokens)
 
   completion_tokens = prompt_completion_ids[len(prompt_ids) :]
   completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)