Skip to content

Commit c176cb9

Browse files
committed
Add UserLM template support
1 parent 963cb15 commit c176cb9

6 files changed

Lines changed: 110 additions & 3 deletions

File tree

docs/source/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@
320320
|[swift/Meta-Llama-3-70B-Instruct-AWQ](https://modelscope.cn/models/swift/Meta-Llama-3-70B-Instruct-AWQ)|llama|llama3|-|✘|-|[study-hjt/Meta-Llama-3-70B-Instruct-AWQ](https://huggingface.co/study-hjt/Meta-Llama-3-70B-Instruct-AWQ)|
321321
|[ChineseAlpacaGroup/llama-3-chinese-8b-instruct](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b-instruct)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b-instruct](https://huggingface.co/hfl/llama-3-chinese-8b-instruct)|
322322
|[ChineseAlpacaGroup/llama-3-chinese-8b](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b](https://huggingface.co/hfl/llama-3-chinese-8b)|
323+
|[microsoft/UserLM-8b](https://modelscope.cn/models/microsoft/UserLM-8b)|llama|userlm|-|✔|-|[microsoft/UserLM-8b](https://huggingface.co/microsoft/UserLM-8b)|
323324
|[LLM-Research/Meta-Llama-3.1-8B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)|
324325
|[LLM-Research/Meta-Llama-3.1-70B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-70B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)|
325326
|[LLM-Research/Meta-Llama-3.1-405B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-405B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct)|

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ The table below introduces the models integrated with ms-swift:
321321
|[swift/Meta-Llama-3-70B-Instruct-AWQ](https://modelscope.cn/models/swift/Meta-Llama-3-70B-Instruct-AWQ)|llama|llama3|-|✘|-|[study-hjt/Meta-Llama-3-70B-Instruct-AWQ](https://huggingface.co/study-hjt/Meta-Llama-3-70B-Instruct-AWQ)|
322322
|[ChineseAlpacaGroup/llama-3-chinese-8b-instruct](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b-instruct)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b-instruct](https://huggingface.co/hfl/llama-3-chinese-8b-instruct)|
323323
|[ChineseAlpacaGroup/llama-3-chinese-8b](https://modelscope.cn/models/ChineseAlpacaGroup/llama-3-chinese-8b)|llama|llama3|-|✔|-|[hfl/llama-3-chinese-8b](https://huggingface.co/hfl/llama-3-chinese-8b)|
324+
|[microsoft/UserLM-8b](https://modelscope.cn/models/microsoft/UserLM-8b)|llama|userlm|-|✔|-|[microsoft/UserLM-8b](https://huggingface.co/microsoft/UserLM-8b)|
324325
|[LLM-Research/Meta-Llama-3.1-8B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)|
325326
|[LLM-Research/Meta-Llama-3.1-70B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-70B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)|
326327
|[LLM-Research/Meta-Llama-3.1-405B-Instruct](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-405B-Instruct)|llama|llama3_2|transformers>=4.43|✔|-|[meta-llama/Meta-Llama-3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct)|

swift/model/models/llama.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ def get_config(self, model_dir):
148148
Model('ChineseAlpacaGroup/llama-3-chinese-8b-instruct', 'hfl/llama-3-chinese-8b-instruct'),
149149
Model('ChineseAlpacaGroup/llama-3-chinese-8b', 'hfl/llama-3-chinese-8b'),
150150
], TemplateType.llama3),
151+
ModelGroup([
152+
Model('microsoft/UserLM-8b', 'microsoft/UserLM-8b'),
153+
],
154+
TemplateType.userlm),
151155
# llama3.1
152156
ModelGroup(
153157
[

swift/template/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class LLMTemplateType:
2828
llama = 'llama' # llama2
2929
llama3 = 'llama3'
3030
llama3_2 = 'llama3_2'
31+
userlm = 'userlm'
3132
reflection = 'reflection'
3233
megrez = 'megrez'
3334
yi_coder = 'yi_coder'

swift/template/templates/llama.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33
import datetime as dt
44
import torch
55
import torch.nn as nn
6-
from dataclasses import dataclass, field
6+
from copy import deepcopy
7+
from dataclasses import asdict, dataclass, field
78
from typing import Any, Dict, List, Literal, Optional
89

910
from swift.utils import get_env_args
1011
from ..base import Template
1112
from ..constant import LLMTemplateType, MLLMTemplateType
1213
from ..register import TemplateMeta, register_template
13-
from ..template_inputs import StdTemplateInputs
14-
from ..utils import Context, Prompt, Word, findall
14+
from ..template_inputs import StdTemplateInputs, TemplateInputs
15+
from ..utils import Context, ContextType, Prompt, Word, findall
1516
from ..vision_utils import load_batch
17+
from .utils import EmptyTemplateMeta
1618

1719
# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
1820
LLAMA_DEFAULT_SYSTEM = (
@@ -47,6 +49,80 @@ class Llama3TemplateMeta(TemplateMeta):
4749

4850
register_template(Llama3TemplateMeta(LLMTemplateType.llama3))
4951

52+
class UserLMTemplate(Template):
    """Template for microsoft/UserLM-8b, a model that generates the *user* side
    of a conversation.

    Unlike standard chat templates, the generation target here is the next
    user turn: the prompt always ends with an open
    ``<|start_header_id|>user<|end_header_id|>`` header, and during training
    the loss is applied to a trailing user message rather than an assistant one.
    """

    def encode(self, inputs, return_template_inputs: bool = False, return_length: bool = False):
        """Normalize `inputs` to a TemplateInputs instance, then defer to the base encode.

        Accepts an InferRequest, a plain dict, or a TemplateInputs object.
        A TemplateInputs argument is deep-copied so the caller's object is not
        mutated by downstream processing.
        """
        # Local import to avoid a circular dependency between template and infer_engine modules.
        from swift.infer_engine import InferRequest
        assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: '
                                        'template.init_processor(processor).')
        if isinstance(inputs, InferRequest):
            inputs = asdict(inputs)
        if isinstance(inputs, dict):
            inputs = TemplateInputs.from_dict(inputs)
        elif isinstance(inputs, TemplateInputs):
            inputs = deepcopy(inputs)
        return super().encode(inputs, return_template_inputs=return_template_inputs, return_length=return_length)

    def _swift_encode(self, inputs: StdTemplateInputs):
        """Build the context pieces for a UserLM prompt.

        Returns a tuple ``(res_context_list, loss_scale_list, answer_len)``:
        the ordered contexts (token-id lists and strings), a parallel
        per-context loss scale (1.0 only on the training target), and the
        number of trailing contexts that form the answer (2 when training on a
        final user turn, else 0).
        """
        system = self._get_system(inputs)
        messages = inputs.messages
        assert messages, 'UserLM expects non-empty messages.'

        res_context_list: List[Context] = []
        res_context_types: List[ContextType] = []
        loss_scale_list: List[float] = []

        if self.template_meta.auto_add_bos:
            # Probe the tokenizer for its BOS prefix: encode a dummy char with and
            # without special tokens, and treat everything before that char's token
            # as the BOS sequence.  Avoids hard-coding a BOS id.
            all_tokens = self.tokenizer.encode('a')
            single_token = self.tokenizer.encode('a', add_special_tokens=False)
            assert len(single_token) == 1
            idx = all_tokens.index(single_token[0])
            bos_token = all_tokens[:idx]
            if bos_token:
                res_context_list.append(bos_token)
                res_context_types.append(ContextType.OTHER)
                loss_scale_list.append(0.)

        if system:
            # NOTE(review): _concat_context_list appears to append to both lists;
            # loss_scale_list is padded afterwards to stay aligned — confirm against base class.
            self._concat_context_list(self.template_meta.system_prefix, res_context_list, res_context_types, system=system)
            loss_scale_list.extend([0.] * (len(res_context_list) - len(loss_scale_list)))

        # When training and the conversation ends with a user turn, that final
        # user message becomes the loss target; otherwise the full history is
        # context and the model is prompted to generate a fresh user turn.
        last_role = messages[-1]['role']
        is_training_target = self.is_training and last_role == 'user'
        history_messages = messages[:-1] if is_training_target else messages

        for message in history_messages:
            role = message['role']
            content = message['content']
            assert role in {'user', 'assistant'}, f'role: "{role}"'
            # llama3-style turn framing: header, body, end-of-turn marker.
            header = f'<|start_header_id|>{role}<|end_header_id|>\n\n'
            res_context_list.extend([header, content, '<|eot_id|>'])
            res_context_types.extend([ContextType.OTHER, ContextType.OTHER, ContextType.OTHER])
            loss_scale_list.extend([0., 0., 0.])

        # The prompt always ends with an open user header — the model's job is
        # to produce the user turn that follows it.
        target_header = '<|start_header_id|>user<|end_header_id|>\n\n'
        res_context_list.append(target_header)
        res_context_types.append(ContextType.OTHER)
        loss_scale_list.append(0.)

        if is_training_target:
            # Supervise only the final user message and its end-of-turn suffix.
            res_context_list.extend([messages[-1]['content'], '<|eot_id|>'])
            res_context_types.extend([ContextType.RESPONSE, ContextType.SUFFIX])
            loss_scale_list.extend([1., 1.])

        # answer_len counts the trailing contexts that constitute the answer span.
        answer_len = 2 if is_training_target else 0
        return res_context_list, loss_scale_list, answer_len
115+
116+
117+
def _userlm_system_prefix() -> Prompt:
    """Default system-prompt wrapper for UserLM (llama3-style framing)."""
    return ['<|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>']


@dataclass
class UserLMTemplateMeta(EmptyTemplateMeta):
    """Template metadata for microsoft/UserLM-8b.

    Prompt assembly is handled entirely by UserLMTemplate._swift_encode, so
    this meta only supplies the agent template name and the system prefix.
    """
    agent_template: str = 'llama3'
    system_prefix: Optional[Prompt] = field(default_factory=_userlm_system_prefix)


register_template(UserLMTemplateMeta(LLMTemplateType.userlm, template_cls=UserLMTemplate))
125+
50126

51127
def _get_llama3_2_prefix() -> Prompt:
52128
now = dt.datetime.now()

tests/test_align/test_template/test_template.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,30 @@ def test_minimax_vl():
102102
assert len(res['input_ids']) == 5877
103103

104104

105+
def test_userlm():
    """UserLM template: history is context; the prompt ends with an open user header."""
    # Fix: use the registered hub model id (see swift/model/models/llama.py),
    # not a machine-specific absolute path — consistent with sibling tests
    # such as test_deepseek_v3_1.
    tokenizer = get_processor('microsoft/UserLM-8b')
    template = get_template(tokenizer)
    assert template.template_backend == 'swift'
    inputs = {
        'messages': [{
            'role': 'system',
            'content': 'You generate the next user turn in a conversation.'
        }, {
            'role': 'user',
            'content': 'The assistant just said: Hello, how can I help you today?'
        }, {
            'role': 'assistant',
            'content': 'I can help with planning, coding, or writing. What would you like to do?'
        }]
    }
    res = template.encode(inputs)
    template.print_inputs(res)
    text = tokenizer.decode(res['input_ids'])
    # Full history (including the assistant turn) is encoded as context.
    assert '<|start_header_id|>assistant<|end_header_id|>' in text
    assert 'I can help with planning, coding, or writing. What would you like to do?' in text
    # Inference prompt must end with an open user header for the model to complete.
    assert text.endswith('<|start_header_id|>user<|end_header_id|>\n\n')
127+
128+
105129
def test_deepseek_v3_1():
106130
tokenizer = get_processor('deepseek-ai/DeepSeek-V3.1')
107131
template = get_template(tokenizer)

0 commit comments

Comments
 (0)