-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[feat] Add UserLM template support #9021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,16 +3,18 @@ | |
| import datetime as dt | ||
| import torch | ||
| import torch.nn as nn | ||
| from dataclasses import dataclass, field | ||
| from copy import deepcopy | ||
| from dataclasses import asdict, dataclass, field | ||
| from typing import Any, Dict, List, Literal, Optional | ||
|
|
||
| from swift.utils import get_env_args | ||
| from ..base import Template | ||
| from ..constant import LLMTemplateType, MLLMTemplateType | ||
| from ..register import TemplateMeta, register_template | ||
| from ..template_inputs import StdTemplateInputs | ||
| from ..utils import Context, Prompt, Word, findall | ||
| from ..template_inputs import StdTemplateInputs, TemplateInputs | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| from ..utils import Context, ContextType, Prompt, Word, findall | ||
| from ..vision_utils import load_batch | ||
| from .utils import EmptyTemplateMeta | ||
|
|
||
| # ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py | ||
| LLAMA_DEFAULT_SYSTEM = ( | ||
|
|
@@ -48,6 +50,82 @@ class Llama3TemplateMeta(TemplateMeta): | |
| register_template(Llama3TemplateMeta(LLMTemplateType.llama3)) | ||
|
|
||
|
|
||
| class UserLMTemplate(Template): | ||
|
|
||
| def encode(self, inputs, return_template_inputs: bool = False, return_length: bool = False): | ||
| from swift.infer_engine import InferRequest | ||
| assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: ' | ||
| 'template.init_processor(processor).') | ||
| if isinstance(inputs, InferRequest): | ||
| inputs = asdict(inputs) | ||
| if isinstance(inputs, dict): | ||
| inputs = TemplateInputs.from_dict(inputs) | ||
| elif isinstance(inputs, TemplateInputs): | ||
| inputs = deepcopy(inputs) | ||
| return super().encode(inputs, return_template_inputs=return_template_inputs, return_length=return_length) | ||
|
|
||
|
Comment on lines +55 to +66
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| def _swift_encode(self, inputs: StdTemplateInputs): | ||
| system = self._get_system(inputs) | ||
| messages = inputs.messages | ||
| assert messages, 'UserLM expects non-empty messages.' | ||
|
|
||
| res_context_list: List[Context] = [] | ||
| res_context_types: List[ContextType] = [] | ||
| loss_scale_list: List[float] = [] | ||
|
|
||
| if self.template_meta.auto_add_bos: | ||
| all_tokens = self.tokenizer.encode('a') | ||
| single_token = self.tokenizer.encode('a', add_special_tokens=False) | ||
| assert len(single_token) == 1 | ||
| idx = all_tokens.index(single_token[0]) | ||
| bos_token = all_tokens[:idx] | ||
| if bos_token: | ||
| res_context_list.append(bos_token) | ||
| res_context_types.append(ContextType.OTHER) | ||
| loss_scale_list.append(0.) | ||
|
|
||
| if system: | ||
| self._concat_context_list( | ||
| self.template_meta.system_prefix, res_context_list, res_context_types, system=system) | ||
| loss_scale_list.extend([0.] * (len(res_context_list) - len(loss_scale_list))) | ||
|
|
||
| last_role = messages[-1]['role'] | ||
| is_training_target = self.is_training and last_role == 'user' | ||
| history_messages = messages[:-1] if is_training_target else messages | ||
|
|
||
| for message in history_messages: | ||
| role = message['role'] | ||
| content = message['content'] | ||
| assert role in {'user', 'assistant'}, f'role: "{role}"' | ||
| header = f'<|start_header_id|>{role}<|end_header_id|>\n\n' | ||
| res_context_list.extend([header, content, '<|eot_id|>']) | ||
| res_context_types.extend([ContextType.OTHER, ContextType.OTHER, ContextType.OTHER]) | ||
| loss_scale_list.extend([0., 0., 0.]) | ||
|
|
||
| target_header = '<|start_header_id|>user<|end_header_id|>\n\n' | ||
| res_context_list.append(target_header) | ||
| res_context_types.append(ContextType.OTHER) | ||
| loss_scale_list.append(0.) | ||
|
|
||
| if is_training_target: | ||
| res_context_list.extend([messages[-1]['content'], '<|eot_id|>']) | ||
| res_context_types.extend([ContextType.RESPONSE, ContextType.SUFFIX]) | ||
| loss_scale_list.extend([1., 1.]) | ||
|
|
||
| answer_len = 2 if is_training_target else 0 | ||
| return res_context_list, loss_scale_list, answer_len | ||
|
|
||
|
|
||
| @dataclass | ||
| class UserLMTemplateMeta(EmptyTemplateMeta): | ||
| agent_template: str = 'llama3' | ||
| system_prefix: Optional[Prompt] = field( | ||
| default_factory=lambda: ['<|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>']) | ||
|
|
||
|
|
||
| register_template(UserLMTemplateMeta(LLMTemplateType.userlm, template_cls=UserLMTemplate)) | ||
|
|
||
|
|
||
| def _get_llama3_2_prefix() -> Prompt: | ||
| now = dt.datetime.now() | ||
| date_string = now.strftime('%d %b %Y') | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
The imports `deepcopy` and `asdict` are only used in the redundant `encode` method override. If that method is removed, these imports should also be removed to keep the code clean.