diff --git a/tests/datasets/test_openai_tokenize_fn.py b/tests/datasets/test_openai_tokenize_fn.py index 8fe724282..c6e39134c 100644 --- a/tests/datasets/test_openai_tokenize_fn.py +++ b/tests/datasets/test_openai_tokenize_fn.py @@ -2,14 +2,65 @@ import parametrize from unittest import TestCase from transformers import AutoTokenizer - +import unittest +from packaging.version import Version +from transformers import __version__ as transformers_version +import json from xtuner.v1.datasets import OpenaiTokenizeFunctionConfig +import copy QWEN3_PATH = os.environ["QWEN3_VL_DENSE_PATH"] # We need instruct model class TestOpenaiTokenizeFunction(TestCase): + + @unittest.skipIf( + Version(transformers_version) < Version("5.2.0"), + f"transformers >= 5.2.0 is required, but got {transformers_version}" + ) + def test_qwen3p5_openai_tokenize_fn(self): + QWEN3P5_PATH = os.environ["QWEN3_5_MOE_PATH"] + demo_data_path = 'tests/resource/qwen35_tokenize_data.jsonl' + + tokenizer = AutoTokenizer.from_pretrained(QWEN3P5_PATH, trust_remote_code=True) + tokenizer_fn_cfg = OpenaiTokenizeFunctionConfig(chat_template="qwen3.5-vl") + tokenizer_fn = tokenizer_fn_cfg.build(tokenizer) + + all_data = [] + with open(demo_data_path, 'r') as f: + for line in f: + all_data.append(json.loads(line)) + + for data in all_data: + id = data["id"] + if id == 7: + data_with_system = copy.deepcopy(data) + input_ids_ref = tokenizer.apply_chat_template(data_with_system['messages'], tools=data.get('tools'), tokenize=True, add_generation_prompt=False)['input_ids'] + # 临时方案,为了和 hf 对齐 + data_with_system['messages'][2]["content"] = "\n\n\n\n我需要先调用一些工具才能知道\n" + data_with_system['messages'][-1]["content"] = "\n\n\n\n基于我的观察,今天北京的天气是35度。" + + input_ids = tokenizer_fn(data_with_system)['input_ids'] + prompt_ref = tokenizer.decode(input_ids_ref,skip_special_tokens=False) + prompt = tokenizer.decode(input_ids,skip_special_tokens=False) + self.assertEqual(prompt_ref, prompt) + self.assertEqual(input_ids_ref, input_ids) + + data_wo_system = copy.deepcopy(data) + del data_wo_system['messages'][0] + input_ids_ref = tokenizer.apply_chat_template(data_wo_system['messages'], tools=data.get('tools'), tokenize=True, add_generation_prompt=False)['input_ids'] + # 临时方案,为了和 hf 对齐 + data_wo_system['messages'][1]["content"] = "\n\n\n\n我需要先调用一些工具才能知道\n" + data_wo_system['messages'][-1]["content"] = "\n\n\n\n基于我的观察,今天北京的天气是35度。" + + input_ids = tokenizer_fn(data_wo_system)['input_ids'] + prompt_ref = tokenizer.decode(input_ids_ref,skip_special_tokens=False) + prompt = tokenizer.decode(input_ids,skip_special_tokens=False) + self.assertEqual(prompt_ref, prompt) + self.assertEqual(input_ids_ref, input_ids) + + # TODO: Remove this test later @parametrize.parametrize( "template_type, tokenizer_path", [ diff --git a/tests/resource/qwen35_tokenize_data.jsonl b/tests/resource/qwen35_tokenize_data.jsonl new file mode 100644 index 000000000..24f3b7902 --- /dev/null +++ b/tests/resource/qwen35_tokenize_data.jsonl @@ -0,0 +1,15 @@ +{"id":1,"messages": [{"role": "system", "content": "这是单轮无think例子"},{"role": "user", "content": "这是第一个问题"},{"role": "assistant", "content": "我需要先调用一些工具才能知道"}]} +{"id":2,"messages": [{"role": "system", "content": "这是单轮有think例子"},{"role": "user", "content": "这是第一个问题"},{"role": "assistant", "content": "我需要先调用一些工具才能知道","reasoning_content": "这是 reasoning_content 内容"}]} +{"id":3,"messages": [{"role": "system", "content": "这是单轮有think例子"},{"role": "user", "content": "这是第一个问题"},{"role": "assistant", "content": "\n我需要先调用一些工具才能知道","reasoning_content": "\n这是 reasoning_content 内容\n"}]} +{"id":4,"messages": [{"role": "system", "content": "这是多轮无think例子"},{"role": "user", "content": "这是第一个问题"},{"role": "assistant", "content": "我需要先调用一些工具才能知道"},{"role": "user", "content": "这是第二个问题"},{"role": "assistant", "content": "好的,我知道这是第二个问题"}]} +{"id":5,"messages": [{"role": "system", "content": "这是多轮有think例子"},{"role": "user", "content": "这是第一个问题"},{"role": "assistant", "content": "我需要先调用一些工具才能知道"},{"role": "user", "content": "这是第二个问题"},{"role": "assistant", "content": "好的,我知道这是第二个问题", "reasoning_content": "这是 reasoning_content 内容"}]} +{"id":6,"messages": [{"role": "system", "content": "这是多轮有think例子"},{"role": "user", "content": "这是第一个问题"},{"role": "assistant", "content": "我需要先调用一些工具才能知道", "reasoning_content": "这是 reasoning_content 内容 1"},{"role": "user", "content": "这是第二个问题"},{"role": "assistant", "content": "好的,我知道这是第二个问题"},{"role": "user", "content": "这是第三个问题"},{"role": "assistant", "content": "好的,我知道这是第三个问题", "reasoning_content": "这是 reasoning_content 内容 2"}]} +{"id":7,"messages": [{"role": "system", "content": "这是单轮无think+toolcall例子"},{"role": "user", "content": "北京今天的天气如何?"},{"role": "assistant", "content": "我需要先调用一些工具才能知道", "tool_calls":[{"id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "Boston"}}}]},{"role": "tool","content": "35"},{"role": "assistant", "content": "基于我的观察,今天北京的天气是35度。"}],"tools": [{"type":"function", "function": {"name": "get_current_temperature", "description": "Gets the temperature at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for"}}, "required": ["location"]}}},{"type": "function", "function": {"name":"get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters":{"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the wind speed for, in the format \"City, Country\""}}, "required": ["location"]}}}]} +{"id":8,"messages": [{"role": "system", "content": "这是单轮有think+toolcall例子"}, {"role": "user", "content": "北京今天的天气如何?"}, {"role": "assistant", "content": "我需要先调用一些工具才能知道", "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "Boston"}}}]}, {"role": "tool", "content": "35"}, {"role": "assistant", "content": "基于我的观察,今天北京的天气是35度。", "reasoning_content": "这是 reasoning_content 内容"}], "tools": [{"type": "function", "function": {"name": "get_current_temperature", "description": "Gets the temperature at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for"}}, "required": ["location"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the wind speed for, in the format \"City, Country\""}}, "required": ["location"]}}}]} +{"id":9,"messages": [{"role": "system", "content": "这是单轮有think+toolcall例子"}, {"role": "user", "content": "北京今天的天气如何?"}, {"role": "assistant", "content": "我需要先调用一些工具才能知道", "reasoning_content": "这是 reasoning_content 内容", "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "Boston"}}}]}, {"role": "tool", "content": "35"}, {"role": "assistant", "content": "基于我的观察,今天北京的天气是35度。","reasoning_content": "这是最后一个 reasoning_content 内容"}], "tools": [{"type": "function", "function": {"name": "get_current_temperature", "description": "Gets the temperature at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for"}}, "required": ["location"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the wind speed for, in the format \"City, Country\""}}, "required": ["location"]}}}]} +{"id":10,"messages": [{"role": "system", "content": "这是多轮无think+toolcall例子"}, {"role": "user", "content": "北京今天的天气如何?"}, {"role": "assistant", "content": "我需要先调用一些工具才能知道", "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "Boston"}}}]}, {"role": "tool", "content": "35"}, {"role": "assistant", "content": "基于我的观察,今天北京的天气是35度。"}, {"role": "user", "content": "这是第二个问题。上海的天气如何"}, {"role": "assistant", "content": "好的,我知道这是第二个问题。我需要先调用一些工具才能知道", "tool_calls": [{"id": "call_789", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "shanghai"}}}]}, {"role": "tool", "content": "25"}, {"role": "assistant", "content": "基于我的观察,今天上海的天气是25度。"}], "tools": [{"type": "function", "function": {"name": "get_current_temperature", "description": "Gets the temperature at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for"}}, "required": ["location"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the wind speed for, in the format \"City, Country\""}}, "required": ["location"]}}}]} +{"id":11,"messages": [{"role": "system", "content": "这是多轮有think+toolcall例子。只有一个用户 user 输入。只有一次真 user 输入 表示整个对话过程中只有 user message。此时中间的所有 think 过程都会保留"}, {"role": "user", "content": "北京和上海今天的天气如何?"}, {"role": "assistant", "content": "我需要先调用一些工具才能知道", "reasoning_content": "这是 reasoning_content 内容", "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "Boston"}}}]}, {"role": "tool", "content": "35"}, {"role": "assistant", "content": "我现在知道北京的天气了,我需要继续知道上海的天气", "reasoning_content": "这是 reasoning_content 内容 2", "tool_calls": [{"id": "call_789", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "shanghai"}}}]}, {"role": "tool", "content": "25"}, {"role": "assistant", "content": "基于我的观察,今天北京的天气是35度,上海的天气是25度。"}], "tools": [{"type": "function", "function": {"name": "get_current_temperature", "description": "Gets the temperature at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for"}}, "required": ["location"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the wind speed for, in the format \"City, Country\""}}, "required": ["location"]}}}]} +{"id":12,"messages": [{"role": "system", "content": "这是多轮有think+toolcall例子。有多个用户 user 输入。一旦再次来了一个新的真 user 输入,则之前的 think 内容会全部丢掉,因为相当于是一次新的回话"}, {"role": "user", "content": "北京今天天气如何?"}, {"role": "assistant", "content": "我需要先调用一些工具才能知道", "reasoning_content": "这是 reasoning_content 内容 1", "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "Boston"}}}]}, {"role": "tool", "content": "35"}, {"role": "assistant", "content": "基于我的观察,今天北京的天气是35度。"}, {"role": "user", "content": "这是第二个问题。上海的天气如何?"}, {"role": "assistant", "content": "现在是第二个问题了,我需要先调用一些工具才能知道", "reasoning_content": "这是 reasoning_content 内容 2", "tool_calls": [{"id": "call_789", "type": "function", "function": {"name": "get_weather", "arguments": {"location": "shanghai"}}}]}, {"role": "tool", "content": "25"}, {"role": "assistant", "content": "基于我的观察,今天上海的天气是25度。"}], "tools": [{"type": "function", "function": {"name": "get_current_temperature", "description": "Gets the temperature at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for"}}, "required": ["location"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the wind speed for, in the format \"City, Country\""}}, "required": ["location"]}}}]} +{"id":13,"messages": [{"role": "system", "content": "你是一个专业的图像分析助手,能够理解和分析多张图片。"}, {"role": "user", "content": [{"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}, {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}, {"type": "text", "text": "请描述这两张图片的内容,它们有什么相同点和不同点?"}]}, {"role": "assistant", "content": "我需要仔细对比两张图片的主体、背景、光线等要素。", "reasoning_content": "第一张图片和第二张图片的主体都是同一只猫,背景都是室内环境,光线也相似。它们的相同点是都展示了这只猫在窗台上休息的场景。不同点是第一张图片中猫的姿势是侧卧,而第二张图片中猫的姿势是仰卧。"}, {"role": "user", "content": [{"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}, {"type": "text", "text": "这张新图片和之前的图片相比,有什么新的元素出现?"}]}, {"role": "assistant", "content": "与前两张图片相比,这张新图片中出现了不同的构图角度和新的视觉元素。"}, {"role": "user", "content": [{"type": "text", "text": "综合以上三张图片,你认为它们想表达什么主题?"}]}, {"role": "assistant", "content": "需要从整体角度总结三张图片的共同叙事逻辑和情感表达。", "reasoning_content": "这三张图片共同表达了一个主题:猫在室内环境中的不同状态和情感。第一张图片展示了猫的安静和放松,第二张图片展示了猫的舒适和满足,而第三张图片则通过不同的构图和视觉元素,传达了猫在这个环境中的多样性和丰富性。整体上,这些图片共同描绘了猫在室内生活中的多样化表现,表达了对猫的喜爱和对其生活状态的关注。"}]} +{"id":14,"messages": [{"role": "system", "content": "你是一个专业的视频分析助手,能够理解和分析视频内容。"}, {"role": "user", "content": [{"type": "video", "video": "https://example.com/video/demo.mp4"}, {"type": "text", "text": "请描述这个视频的主要内容,并分析其中的关键事件。"}]}, {"role": "assistant", "content": "让我仔细观察这个视频的每一帧内容。", "reasoning_content": "视频开始时展示了一个城市街道的场景,有行人和车辆在移动。随后镜头切换到一家咖啡店的内部,可以看到顾客在排队点单。接着视频展示了咖啡制作的过程,包括研磨咖啡豆、萃取咖啡液和打奶泡。最后视频以一杯精美的拉花咖啡特写结束。整个视频的节奏流畅,画面清晰。"}]} +{"id":15,"messages": [{"role": "system", "content": "你是一个专业的视频分析助手,能够理解和分析多个视频内容。"}, {"role": "user", "content": [{"type": "video", "video": "https://example.com/video/cooking_tutorial.mp4"}, {"type": "video", "video": "https://example.com/video/cooking_result.mp4"}, {"type": "text", "text": "请分析这两个视频,第一个视频是烹饪教程,第二个是最终成品。它们之间有什么联系?"}]}, {"role": "assistant", "content": "我需要仔细对比这两个视频的内容和逻辑关系。", "reasoning_content": "首先看第一个烹饪教程视频:视频展示了准备食材的过程,包括切菜、腌制肉类等步骤。然后是烹饪过程,展示了翻炒、调味等操作。最后视频展示了装盘。第二个成品视频展示了最终菜品的摆盘和特写镜头。两个视频的联系在于:第一个视频是制作过程,第二个视频是制作完成后的成品展示。它们共同构成了一个完整的从制作到呈现的叙事链条。"}, {"role": "user", "content": [{"type": "video", "video": "https://example.com/video/failed_attempt.mp4"}, {"type": "text", "text": "这里还有一个失败尝试的视频,和前两个相比有什么问题?"}]}, {"role": "assistant", "content": "让我对比分析这个失败案例与之前的成功案例。", "reasoning_content": "通过对比可以看出几个关键问题:首先,火候控制不当,视频中可以看到食材有些焦糊。其次,调味顺序有问题,盐放得太早导致食材出水过多。第三,翻炒的频率不够,导致受热不均匀。相比之下,第一个成功视频中火候掌握得当,调味时机准确,翻炒动作连贯。这些细节差异最终导致了截然不同的结果。"}, {"role": "user", "content": [{"type": "text", "text": "基于这三个视频,总结一下成功烹饪这道菜的关键要点。"}]}, {"role": "assistant", "content": "需要从成功和失败的对比中提炼出关键要点。", "reasoning_content": "综合三个视频的分析,成功烹饪这道菜的关键要点包括:第一,火候控制是核心,需要保持中火避免焦糊;第二,调味顺序很重要,盐应在出锅前加入;第三,翻炒要频繁均匀,确保食材受热一致;第四,食材预处理要到位,切块的均匀度影响受热;第五,要有耐心,每个步骤都不能急于求成。失败视频恰恰反证了这些要点的重要性。"}]} \ No newline at end of file diff --git a/xtuner/v1/data_proto/messages/chat.py b/xtuner/v1/data_proto/messages/chat.py index fcc4adc64..cc6f7c5bd 100644 --- a/xtuner/v1/data_proto/messages/chat.py +++ b/xtuner/v1/data_proto/messages/chat.py @@ -74,13 +74,31 @@ def tool_formatter(tools: list[dict[str, Any]]) -> str: return tool_text -def function_formatter(functions: list[dict[str, Any]]) -> str: +def function_formatter(functions: list[dict[str, Any]], template_name: Optional[str] = None) -> str: function_texts = [] - for function in functions: - name = function["function"]["name"] - arguments = function["function"]["arguments"] - function_texts.append(json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)) - return "\n".join([f"\n{text}\n" for text in function_texts]) + if template_name is None: # qwen3 + for function in functions: + name = function["function"]["name"] + arguments = function["function"]["arguments"] + function_texts.append(json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)) + return "\n".join([f"\n{text}\n" for text in function_texts]) + elif template_name == "qwen3.5-vl": # qwen3.5 + for function in functions: + name = function["function"]["name"] + arguments = function["function"]["arguments"] + if isinstance(arguments, str): + arguments = json.loads(arguments) + prompt = f"\n" + for key, value in arguments.items(): + prompt += f"\n" + if not isinstance(value, str): + value = json.dumps(value, ensure_ascii=False) + prompt += f"\n{value}\n" + prompt += "\n\n" + function_texts.append(prompt) + return "\n".join(function_texts) + else: + raise NotImplementedError(f"function_formatter for template {template_name} is not implemented.") class ChatMsg(BaseModel): @@ -137,7 +155,7 @@ def get_prompt(self, chat_template: ChatTemplate) -> str: prompt = chat_template.decorate_tool_extractor(text) elif self.role == "assistant": if self.tool_calls is not None: - function_text = function_formatter(self.tool_calls) + function_text = function_formatter(self.tool_calls, chat_template.template_name) if text is not None and text != "" and not text.endswith("\n\n"): function_text = "\n" + function_text text += function_text @@ -204,7 +222,10 @@ def process_message(messages: List[ChatMsg], chat_template: ChatTemplate, tools: messages.insert(0, ChatMsg(role="system", content=tool_text, loss=False)) else: assert isinstance(messages[0].content, str), "system message content must be str." - messages[0].content += tool_text + if chat_template.template_name == "qwen3.5-vl": + messages[0].content = tool_text + "\n\n" + messages[0].content + else: + messages[0].content += tool_text class ChatMessages(BaseMessages): diff --git a/xtuner/v1/data_proto/templates/__init__.py b/xtuner/v1/data_proto/templates/__init__.py index a016798bf..3da422052 100644 --- a/xtuner/v1/data_proto/templates/__init__.py +++ b/xtuner/v1/data_proto/templates/__init__.py @@ -46,6 +46,19 @@ image_context_token="<|image_pad|>", video_context_token="<|video_pad|>", ), + "qwen3.5-vl": HybridChatTemplate( + template_name="qwen3.5-vl", + system="<|im_start|>system\n{system}<|im_end|>\n", + tool_prompt="# Tools\n\nYou have access to the following functions:\n\n{tool_text}\n\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n", + tool_extractor="<|im_start|>user\n\n{tool_extractor}\n<|im_end|>\n<|im_start|>assistant\n", + user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n", + stop_words=["<|im_end|>", "<|endoftext|>"], + assistant="{assistant}<|im_end|>", + image_start_token="<|vision_start|>", + image_end_token="<|vision_end|>", + image_context_token="<|image_pad|>", + video_context_token="<|video_pad|>", + ), "llama3": HybridChatTemplate( system="<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>", user=( diff --git a/xtuner/v1/data_proto/templates/chat.py b/xtuner/v1/data_proto/templates/chat.py index 7d800994a..0deec5808 100644 --- a/xtuner/v1/data_proto/templates/chat.py +++ b/xtuner/v1/data_proto/templates/chat.py @@ -21,6 +21,7 @@ class ChatTemplate(BaseModel): default_system: str | None = None tool_extractor: str | None = None # Tool extractor format tool_prompt: str | None = None # Tool prompt format + template_name: str | None = None # only compute loss on the last assistant response ignoring the multiple rounds of assistant only_last_assistant_loss: bool = False # gpt_oss is True diff --git a/xtuner/v1/data_proto/templates/hybrid.py b/xtuner/v1/data_proto/templates/hybrid.py index 0ec2ddfcf..28bb4234c 100644 --- a/xtuner/v1/data_proto/templates/hybrid.py +++ b/xtuner/v1/data_proto/templates/hybrid.py @@ -22,6 +22,7 @@ class HybridChatTemplate(BaseModel): default_system: Optional[str] = None tool_prompt: str | None = None # Tool prompt format tool_extractor: str | None = None # Tool extractor format + template_name: Optional[str] = None # only compute loss on the last assistant response ignoring the multiple rounds of assistant only_last_assistant_loss: bool = False # gpt_oss is True diff --git a/xtuner/v1/datasets/_hardcode_patch.py b/xtuner/v1/datasets/_hardcode_patch.py index de223e7a3..389013b95 100644 --- a/xtuner/v1/datasets/_hardcode_patch.py +++ b/xtuner/v1/datasets/_hardcode_patch.py @@ -43,8 +43,12 @@ def __init__(self, tokenizer: PreTrainedTokenizer, *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) self._tokenizer = tokenizer if "" in tokenizer.added_tokens_encoder: + # \n\n\n\n 不算 loss + # \n 不算 loss + # 原则上要报错,且不算 loss self._skip_seq = [ tokenizer.added_tokens_encoder[""], + *tokenizer.encode("\n\n"), tokenizer.added_tokens_encoder[""], *tokenizer.encode("\n\n"), ] @@ -58,6 +62,8 @@ def __init__(self, tokenizer: PreTrainedTokenizer, *args, **kwargs): self._skip_seq_fallback = [] self._think_token = None + self.skip_token = tokenizer.encode("\n")[0] + if not self._skip_seq: logger.warning("`SkipEmptyThink` is disabled because or token is not found in tokenizer.") else: @@ -80,17 +86,21 @@ def process_labels(self, input_ids: list[int], labels: list[int]): new_labels = [] buffer = [] label_buffer = [] + pre_label_id = None for label_id, token_id in zip(labels, input_ids): buffer.append(token_id) # Since is filled by chat template, never calculate loss on it if label_id == self._think_token: + pre_label_id = label_id label_buffer.append(-100) else: + if pre_label_id == self._think_token and token_id == self.skip_token: + label_id = -100 label_buffer.append(label_id) + pre_label_id = label_id - # Check if buffer matches skip_seq if buffer == self._skip_seq: # Complete match found, replace with -100 new_labels.extend([-100] * len(self._skip_seq))