Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/Instruction/Supported-models-and-datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@
|[swift/Meta-Llama-3-70B-Instruct-AWQ](https://modelscope.cn/models/swift/Meta-Llama-3-70B-Instruct-AWQ)|llama|llama3|-|✘|-|[study-hjt/Meta-Llama-3-70B-Instruct-AWQ](https://huggingface.co/study-hjt/Meta-Llama-3-70B-Instruct-AWQ)|
|[ChineseAlpacaGroup/llama-3-chinese-8b-instruct](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b-instruct)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b-instruct](https://huggingface.co/hfl/llama-3-chinese-8b-instruct)|
|[ChineseAlpacaGroup/llama-3-chinese-8b](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b](https://huggingface.co/hfl/llama-3-chinese-8b)|
|[microsoft/UserLM-8b](https://modelscope.cn/models/microsoft/UserLM-8b)|llama|userlm|-|✔|-|[microsoft/UserLM-8b](https://huggingface.co/microsoft/UserLM-8b)|
|[LLM-Research/Meta-Llama-3.1-8B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)|
|[LLM-Research/Meta-Llama-3.1-70B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-70B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)|
|[LLM-Research/Meta-Llama-3.1-405B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-405B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct)|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ The table below introduces the models integrated with ms-swift:
|[swift/Meta-Llama-3-70B-Instruct-AWQ](https://modelscope.cn/models/swift/Meta-Llama-3-70B-Instruct-AWQ)|llama|llama3|-|✘|-|[study-hjt/Meta-Llama-3-70B-Instruct-AWQ](https://huggingface.co/study-hjt/Meta-Llama-3-70B-Instruct-AWQ)|
|[ChineseAlpacaGroup/llama-3-chinese-8b-instruct](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b-instruct)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b-instruct](https://huggingface.co/hfl/llama-3-chinese-8b-instruct)|
|[ChineseAlpacaGroup/llama-3-chinese-8b](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b](https://huggingface.co/hfl/llama-3-chinese-8b)|
|[microsoft/UserLM-8b](https://modelscope.cn/models/microsoft/UserLM-8b)|llama|userlm|-|✔|-|[microsoft/UserLM-8b](https://huggingface.co/microsoft/UserLM-8b)|
|[LLM-Research/Meta-Llama-3.1-8B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)|
|[LLM-Research/Meta-Llama-3.1-70B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-70B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)|
|[LLM-Research/Meta-Llama-3.1-405B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-405B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct)|
Expand Down
3 changes: 3 additions & 0 deletions swift/model/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def get_config(self, model_dir):
Model('ChineseAlpacaGroup/llama-3-chinese-8b-instruct', 'hfl/llama-3-chinese-8b-instruct'),
Model('ChineseAlpacaGroup/llama-3-chinese-8b', 'hfl/llama-3-chinese-8b'),
], TemplateType.llama3),
ModelGroup([
Model('microsoft/UserLM-8b', 'microsoft/UserLM-8b'),
], TemplateType.userlm),
# llama3.1
ModelGroup(
[
Expand Down
1 change: 1 addition & 0 deletions swift/template/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LLMTemplateType:
llama = 'llama' # llama2
llama3 = 'llama3'
llama3_2 = 'llama3_2'
userlm = 'userlm'
reflection = 'reflection'
megrez = 'megrez'
yi_coder = 'yi_coder'
Expand Down
84 changes: 81 additions & 3 deletions swift/template/templates/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
import datetime as dt
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from copy import deepcopy
from dataclasses import asdict, dataclass, field
Comment on lines +6 to +7
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The imports deepcopy and asdict are only used in the redundant encode method override. If that method is removed, these imports should also be removed to keep the code clean.

Suggested change
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field

from typing import Any, Dict, List, Literal, Optional

from swift.utils import get_env_args
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, Word, findall
from ..template_inputs import StdTemplateInputs, TemplateInputs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

TemplateInputs is only used in the redundant encode method override and can be removed from the imports if the method is removed.

Suggested change
from ..template_inputs import StdTemplateInputs, TemplateInputs
from ..template_inputs import StdTemplateInputs

from ..utils import Context, ContextType, Prompt, Word, findall
from ..vision_utils import load_batch
from .utils import EmptyTemplateMeta

# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
LLAMA_DEFAULT_SYSTEM = (
Expand Down Expand Up @@ -48,6 +50,82 @@ class Llama3TemplateMeta(TemplateMeta):
register_template(Llama3TemplateMeta(LLMTemplateType.llama3))


class UserLMTemplate(Template):
    """Template for microsoft/UserLM-8b, a llama3-based model trained to
    generate the *user* side of a conversation (the next user turn), not
    the assistant side."""

    def encode(self, inputs, return_template_inputs: bool = False, return_length: bool = False):
        """Normalize ``inputs`` (InferRequest / dict / TemplateInputs) before
        delegating to ``Template.encode``.

        NOTE(review): the base ``Template.encode`` may already perform the
        processor check and the dict -> TemplateInputs conversion, and the
        InferRequest case is typically handled by the infer engine before the
        template is called — confirm; if so this override (and the deepcopy,
        which can be costly for large multimodal inputs) is redundant and can
        be dropped along with the ``asdict``/``deepcopy`` imports.
        """
        from swift.infer_engine import InferRequest
        assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: '
                                        'template.init_processor(processor).')
        if isinstance(inputs, InferRequest):
            inputs = asdict(inputs)
        if isinstance(inputs, dict):
            inputs = TemplateInputs.from_dict(inputs)
        elif isinstance(inputs, TemplateInputs):
            # deepcopy so encoding cannot mutate the caller's object.
            inputs = deepcopy(inputs)
        return super().encode(inputs, return_template_inputs=return_template_inputs, return_length=return_length)

Comment on lines +55 to +66
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The encode method override in UserLMTemplate is redundant. The base Template.encode already handles the conversion of a dict to TemplateInputs and verifies that the processor is initialized. Furthermore, the use of deepcopy on TemplateInputs is unnecessary and can significantly impact performance when dealing with large inputs (e.g., multi-modal data). The InferRequest handling is also typically managed by the InferEngine before it calls the template. Removing this override simplifies the implementation without loss of functionality.

def _swift_encode(self, inputs: StdTemplateInputs):
system = self._get_system(inputs)
messages = inputs.messages
assert messages, 'UserLM expects non-empty messages.'

res_context_list: List[Context] = []
res_context_types: List[ContextType] = []
loss_scale_list: List[float] = []

if self.template_meta.auto_add_bos:
all_tokens = self.tokenizer.encode('a')
single_token = self.tokenizer.encode('a', add_special_tokens=False)
assert len(single_token) == 1
idx = all_tokens.index(single_token[0])
bos_token = all_tokens[:idx]
if bos_token:
res_context_list.append(bos_token)
res_context_types.append(ContextType.OTHER)
loss_scale_list.append(0.)

if system:
self._concat_context_list(
self.template_meta.system_prefix, res_context_list, res_context_types, system=system)
loss_scale_list.extend([0.] * (len(res_context_list) - len(loss_scale_list)))

last_role = messages[-1]['role']
is_training_target = self.is_training and last_role == 'user'
history_messages = messages[:-1] if is_training_target else messages

for message in history_messages:
role = message['role']
content = message['content']
assert role in {'user', 'assistant'}, f'role: "{role}"'
header = f'<|start_header_id|>{role}<|end_header_id|>\n\n'
res_context_list.extend([header, content, '<|eot_id|>'])
res_context_types.extend([ContextType.OTHER, ContextType.OTHER, ContextType.OTHER])
loss_scale_list.extend([0., 0., 0.])

target_header = '<|start_header_id|>user<|end_header_id|>\n\n'
res_context_list.append(target_header)
res_context_types.append(ContextType.OTHER)
loss_scale_list.append(0.)

if is_training_target:
res_context_list.extend([messages[-1]['content'], '<|eot_id|>'])
res_context_types.extend([ContextType.RESPONSE, ContextType.SUFFIX])
loss_scale_list.extend([1., 1.])

answer_len = 2 if is_training_target else 0
return res_context_list, loss_scale_list, answer_len


@dataclass
class UserLMTemplateMeta(EmptyTemplateMeta):
    """Template meta for UserLM.

    Only the system prefix is declared here; the full turn layout is
    produced imperatively by ``UserLMTemplate._swift_encode``.
    """

    # UserLM is llama3-based, so reuse the llama3 agent/tool-call template.
    agent_template: str = 'llama3'
    # llama3-style system block; {{SYSTEM}} is substituted at encode time.
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])


register_template(UserLMTemplateMeta(LLMTemplateType.userlm, template_cls=UserLMTemplate))


def _get_llama3_2_prefix() -> Prompt:
now = dt.datetime.now()
date_string = now.strftime('%d %b %Y')
Expand Down
24 changes: 24 additions & 0 deletions tests/test_align/test_template/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,30 @@ def test_minimax_vl():
assert len(res['input_ids']) == 5877


def test_userlm():
    """UserLM should render history turns verbatim and leave an open user
    header at the end of the encoded prompt."""
    tokenizer = get_processor('microsoft/UserLM-8b')
    template = get_template(tokenizer)
    assert template.template_backend == 'swift'

    messages = [
        {
            'role': 'system',
            'content': 'You generate the next user turn in a conversation.'
        },
        {
            'role': 'user',
            'content': 'The assistant just said: Hello, how can I help you today?'
        },
        {
            'role': 'assistant',
            'content': 'I can help with planning, coding, or writing. What would you like to do?'
        },
    ]
    encoded = template.encode({'messages': messages})
    template.print_inputs(encoded)

    decoded = tokenizer.decode(encoded['input_ids'])
    # The assistant turn appears verbatim in the rendered prompt.
    assert '<|start_header_id|>assistant<|end_header_id|>' in decoded
    assert messages[-1]['content'] in decoded
    # Prompt ends with an open user header: the model writes the next user turn.
    assert decoded.endswith('<|start_header_id|>user<|end_header_id|>\n\n')


def test_deepseek_v3_1():
tokenizer = get_processor('deepseek-ai/DeepSeek-V3.1')
template = get_template(tokenizer)
Expand Down
Loading