diff --git a/docs/source/Instruction/Supported-models-and-datasets.md b/docs/source/Instruction/Supported-models-and-datasets.md
index fbcb61aa55..39c4eefa98 100644
--- a/docs/source/Instruction/Supported-models-and-datasets.md
+++ b/docs/source/Instruction/Supported-models-and-datasets.md
@@ -320,6 +320,7 @@
 |[swift/Meta-Llama-3-70B-Instruct-AWQ](https://modelscope.cn/models/swift/Meta-Llama-3-70B-Instruct-AWQ)|llama|llama3|-|✘|-|[study-hjt/Meta-Llama-3-70B-Instruct-AWQ](https://huggingface.co/study-hjt/Meta-Llama-3-70B-Instruct-AWQ)|
 |[ChineseAlpacaGroup/llama-3-chinese-8b-instruct](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b-instruct)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b-instruct](https://huggingface.co/hfl/llama-3-chinese-8b-instruct)|
 |[ChineseAlpacaGroup/llama-3-chinese-8b](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b](https://huggingface.co/hfl/llama-3-chinese-8b)|
+|[microsoft/UserLM-8b](https://modelscope.cn/models/microsoft/UserLM-8b)|llama|userlm|-|✔|-|[microsoft/UserLM-8b](https://huggingface.co/microsoft/UserLM-8b)|
 |[LLM-Research/Meta-Llama-3.1-8B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)|
 |[LLM-Research/Meta-Llama-3.1-70B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-70B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)|
 |[LLM-Research/Meta-Llama-3.1-405B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-405B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct)|
diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md
index 099ebf4d19..3c379f41fd 100644
--- a/docs/source_en/Instruction/Supported-models-and-datasets.md
+++ b/docs/source_en/Instruction/Supported-models-and-datasets.md
@@ -321,6 +321,7 @@ The table below introduces the models integrated with ms-swift:
 |[swift/Meta-Llama-3-70B-Instruct-AWQ](https://modelscope.cn/models/swift/Meta-Llama-3-70B-Instruct-AWQ)|llama|llama3|-|✘|-|[study-hjt/Meta-Llama-3-70B-Instruct-AWQ](https://huggingface.co/study-hjt/Meta-Llama-3-70B-Instruct-AWQ)|
 |[ChineseAlpacaGroup/llama-3-chinese-8b-instruct](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b-instruct)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b-instruct](https://huggingface.co/hfl/llama-3-chinese-8b-instruct)|
 |[ChineseAlpacaGroup/llama-3-chinese-8b](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b](https://huggingface.co/hfl/llama-3-chinese-8b)|
+|[microsoft/UserLM-8b](https://modelscope.cn/models/microsoft/UserLM-8b)|llama|userlm|-|✔|-|[microsoft/UserLM-8b](https://huggingface.co/microsoft/UserLM-8b)|
 |[LLM-Research/Meta-Llama-3.1-8B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)|
 |[LLM-Research/Meta-Llama-3.1-70B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-70B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)|
 |[LLM-Research/Meta-Llama-3.1-405B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-405B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct)|
diff --git a/swift/model/models/llama.py b/swift/model/models/llama.py
index bef8714e4f..9750c3d3b7 100644
--- a/swift/model/models/llama.py
+++ b/swift/model/models/llama.py
@@ -148,6 +148,9 @@ def get_config(self, model_dir):
             Model('ChineseAlpacaGroup/llama-3-chinese-8b-instruct', 'hfl/llama-3-chinese-8b-instruct'),
             Model('ChineseAlpacaGroup/llama-3-chinese-8b', 'hfl/llama-3-chinese-8b'),
         ], TemplateType.llama3),
+        ModelGroup([
+            Model('microsoft/UserLM-8b', 'microsoft/UserLM-8b'),
+        ], TemplateType.userlm),
         # llama3.1
         ModelGroup(
             [
diff --git a/swift/template/constant.py b/swift/template/constant.py
index 5ce8e1a729..3e23f4b2ab 100644
--- a/swift/template/constant.py
+++ b/swift/template/constant.py
@@ -28,6 +28,7 @@ class LLMTemplateType:
     llama = 'llama'  # llama2
     llama3 = 'llama3'
    llama3_2 = 'llama3_2'
+    userlm = 'userlm'
     reflection = 'reflection'
     megrez = 'megrez'
     yi_coder = 'yi_coder'
diff --git a/swift/template/templates/llama.py b/swift/template/templates/llama.py
index 49c14eb188..dfee6d4477 100644
--- a/swift/template/templates/llama.py
+++ b/swift/template/templates/llama.py
@@ -3,16 +3,18 @@
 import datetime as dt
 import torch
 import torch.nn as nn
-from dataclasses import dataclass, field
+from copy import deepcopy
+from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, List, Literal, Optional
 
 from swift.utils import get_env_args
 from ..base import Template
 from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import TemplateMeta, register_template
-from ..template_inputs import StdTemplateInputs
-from ..utils import Context, Prompt, Word, findall
+from ..template_inputs import StdTemplateInputs, TemplateInputs
+from ..utils import Context, ContextType, Prompt, Word, findall
 from ..vision_utils import load_batch
+from .utils import EmptyTemplateMeta
 
 # ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
 LLAMA_DEFAULT_SYSTEM = (
@@ -48,6 +50,82 @@ class Llama3TemplateMeta(TemplateMeta):
 register_template(Llama3TemplateMeta(LLMTemplateType.llama3))
 
 
+class UserLMTemplate(Template):
+
+    def encode(self, inputs, return_template_inputs: bool = False, return_length: bool = False):
+        from swift.infer_engine import InferRequest
+        assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: '
+                                        'template.init_processor(processor).')
+        if isinstance(inputs, InferRequest):
+            inputs = asdict(inputs)
+        if isinstance(inputs, dict):
+            inputs = TemplateInputs.from_dict(inputs)
+        elif isinstance(inputs, TemplateInputs):
+            inputs = deepcopy(inputs)
+        return super().encode(inputs, return_template_inputs=return_template_inputs, return_length=return_length)
+
+    def _swift_encode(self, inputs: StdTemplateInputs):
+        system = self._get_system(inputs)
+        messages = inputs.messages
+        assert messages, 'UserLM expects non-empty messages.'
+
+        res_context_list: List[Context] = []
+        res_context_types: List[ContextType] = []
+        loss_scale_list: List[float] = []
+
+        if self.template_meta.auto_add_bos:
+            all_tokens = self.tokenizer.encode('a')
+            single_token = self.tokenizer.encode('a', add_special_tokens=False)
+            assert len(single_token) == 1
+            idx = all_tokens.index(single_token[0])
+            bos_token = all_tokens[:idx]
+            if bos_token:
+                res_context_list.append(bos_token)
+                res_context_types.append(ContextType.OTHER)
+                loss_scale_list.append(0.)
+
+        if system:
+            self._concat_context_list(
+                self.template_meta.system_prefix, res_context_list, res_context_types, system=system)
+            loss_scale_list.extend([0.] * (len(res_context_list) - len(loss_scale_list)))
+
+        last_role = messages[-1]['role']
+        is_training_target = self.is_training and last_role == 'user'
+        history_messages = messages[:-1] if is_training_target else messages
+
+        for message in history_messages:
+            role = message['role']
+            content = message['content']
+            assert role in {'user', 'assistant'}, f'role: "{role}"'
+            header = f'<|start_header_id|>{role}<|end_header_id|>\n\n'
+            res_context_list.extend([header, content, '<|eot_id|>'])
+            res_context_types.extend([ContextType.OTHER, ContextType.OTHER, ContextType.OTHER])
+            loss_scale_list.extend([0., 0., 0.])
+
+        target_header = '<|start_header_id|>user<|end_header_id|>\n\n'
+        res_context_list.append(target_header)
+        res_context_types.append(ContextType.OTHER)
+        loss_scale_list.append(0.)
+
+        if is_training_target:
+            res_context_list.extend([messages[-1]['content'], '<|eot_id|>'])
+            res_context_types.extend([ContextType.RESPONSE, ContextType.SUFFIX])
+            loss_scale_list.extend([1., 1.])
+
+        answer_len = 2 if is_training_target else 0
+        return res_context_list, loss_scale_list, answer_len
+
+
+@dataclass
+class UserLMTemplateMeta(EmptyTemplateMeta):
+    agent_template: str = 'llama3'
+    system_prefix: Optional[Prompt] = field(
+        default_factory=lambda: ['<|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])
+
+
+register_template(UserLMTemplateMeta(LLMTemplateType.userlm, template_cls=UserLMTemplate))
+
+
 def _get_llama3_2_prefix() -> Prompt:
     now = dt.datetime.now()
     date_string = now.strftime('%d %b %Y')
diff --git a/tests/test_align/test_template/test_template.py b/tests/test_align/test_template/test_template.py
index d336a8b13a..b0e298abb0 100644
--- a/tests/test_align/test_template/test_template.py
+++ b/tests/test_align/test_template/test_template.py
@@ -102,6 +102,30 @@ def test_minimax_vl():
     assert len(res['input_ids']) == 5877
 
 
+def test_userlm():
+    tokenizer = get_processor('microsoft/UserLM-8b')
+    template = get_template(tokenizer)
+    assert template.template_backend == 'swift'
+    inputs = {
+        'messages': [{
+            'role': 'system',
+            'content': 'You generate the next user turn in a conversation.'
+        }, {
+            'role': 'user',
+            'content': 'The assistant just said: Hello, how can I help you today?'
+        }, {
+            'role': 'assistant',
+            'content': 'I can help with planning, coding, or writing. What would you like to do?'
+        }]
+    }
+    res = template.encode(inputs)
+    template.print_inputs(res)
+    text = tokenizer.decode(res['input_ids'])
+    assert '<|start_header_id|>assistant<|end_header_id|>' in text
+    assert 'I can help with planning, coding, or writing. What would you like to do?' in text
+    assert text.endswith('<|start_header_id|>user<|end_header_id|>\n\n')
+
+
 def test_deepseek_v3_1():
     tokenizer = get_processor('deepseek-ai/DeepSeek-V3.1')
     template = get_template(tokenizer)
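For reviewers, a minimal, dependency-free sketch of the prompt layout that `UserLMTemplate._swift_encode` assembles (string rendering only; the `render_userlm_prompt` helper is hypothetical and not part of the diff, and BOS insertion, loss scales, and tokenization are deliberately omitted):

```python
# Hypothetical helper mirroring the string pieces _swift_encode emits:
# one '<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>'
# block per turn, then an open *user* header, since UserLM-8b generates
# the user side of the dialogue rather than the assistant side.
from typing import Dict, List, Optional


def render_userlm_prompt(messages: List[Dict[str, str]], system: Optional[str] = None) -> str:
    parts = []
    if system:
        parts.append(f'<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>')
    for message in messages:
        role, content = message['role'], message['content']
        assert role in {'user', 'assistant'}, f'role: "{role}"'
        parts.append(f'<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>')
    # The prompt always ends with an open user header for generation,
    # matching the endswith assertion in test_userlm above.
    parts.append('<|start_header_id|>user<|end_header_id|>\n\n')
    return ''.join(parts)


if __name__ == '__main__':
    prompt = render_userlm_prompt(
        [{
            'role': 'user',
            'content': 'The assistant just said: Hello, how can I help you today?'
        }, {
            'role': 'assistant',
            'content': 'I can help with planning, coding, or writing. What would you like to do?'
        }],
        system='You generate the next user turn in a conversation.')
    assert prompt.endswith('<|start_header_id|>user<|end_header_id|>\n\n')
    print(prompt)
```

At training time, when the last message is a `user` turn, the diff instead drops that turn from the rendered history and supervises it after the open header (`ContextType.RESPONSE` plus the `<|eot_id|>` suffix, loss scale 1.), so only the simulated user turn contributes to the loss.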