|
3 | 3 | import datetime as dt |
4 | 4 | import torch |
5 | 5 | import torch.nn as nn |
6 | | -from dataclasses import dataclass, field |
| 6 | +from copy import deepcopy |
| 7 | +from dataclasses import asdict, dataclass, field |
7 | 8 | from typing import Any, Dict, List, Literal, Optional |
8 | 9 |
|
9 | 10 | from swift.utils import get_env_args |
10 | 11 | from ..base import Template |
11 | 12 | from ..constant import LLMTemplateType, MLLMTemplateType |
12 | 13 | from ..register import TemplateMeta, register_template |
13 | | -from ..template_inputs import StdTemplateInputs |
14 | | -from ..utils import Context, Prompt, Word, findall |
| 14 | +from ..template_inputs import StdTemplateInputs, TemplateInputs |
| 15 | +from ..utils import Context, ContextType, Prompt, Word, findall |
15 | 16 | from ..vision_utils import load_batch |
| 17 | +from .utils import EmptyTemplateMeta |
16 | 18 |
|
17 | 19 | # ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py |
18 | 20 | LLAMA_DEFAULT_SYSTEM = ( |
@@ -47,6 +49,80 @@ class Llama3TemplateMeta(TemplateMeta): |
47 | 49 |
|
# Standard llama3 chat template registration.
register_template(Llama3TemplateMeta(LLMTemplateType.llama3))
49 | 51 |
|
class UserLMTemplate(Template):
    """Template whose generation target is the *user* turn.

    History turns (both roles) are encoded with zero loss scale; when
    training and the conversation ends on a ``user`` message, that final
    user message becomes the supervised response. At inference time the
    prompt ends with an open user header so the model produces a user turn.
    """

    def encode(self, inputs, return_template_inputs: bool = False, return_length: bool = False):
        # NOTE(review): import path `swift.infer_engine` — confirm against the
        # package layout (InferRequest is commonly re-exported via swift.llm).
        from swift.infer_engine import InferRequest
        assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: '
                                        'template.init_processor(processor).')
        # Normalize the accepted input forms to a fresh TemplateInputs so the
        # base encode never mutates the caller's object.
        if isinstance(inputs, InferRequest):
            inputs = asdict(inputs)
        if isinstance(inputs, dict):
            inputs = TemplateInputs.from_dict(inputs)
        elif isinstance(inputs, TemplateInputs):
            inputs = deepcopy(inputs)
        return super().encode(inputs, return_template_inputs=return_template_inputs, return_length=return_length)

    def _swift_encode(self, inputs: StdTemplateInputs):
        system = self._get_system(inputs)
        messages = inputs.messages
        assert messages, 'UserLM expects non-empty messages.'

        contexts: List[Context] = []
        context_types: List[ContextType] = []
        loss_scales: List[float] = []

        def _push(ctx, ctx_type, scale: float) -> None:
            # Append one context element with its type and loss scale in lockstep.
            contexts.append(ctx)
            context_types.append(ctx_type)
            loss_scales.append(scale)

        if self.template_meta.auto_add_bos:
            # Recover the tokenizer's BOS ids by diffing an encode with and
            # without special tokens around a single-token probe string.
            with_special = self.tokenizer.encode('a')
            plain = self.tokenizer.encode('a', add_special_tokens=False)
            assert len(plain) == 1
            bos_ids = with_special[:with_special.index(plain[0])]
            if bos_ids:
                _push(bos_ids, ContextType.OTHER, 0.)

        if system:
            self._concat_context_list(self.template_meta.system_prefix, contexts, context_types, system=system)
            # _concat_context_list does not touch loss_scales; pad to match.
            loss_scales.extend([0.] * (len(contexts) - len(loss_scales)))

        last_role = messages[-1]['role']
        is_training_target = self.is_training and last_role == 'user'
        # In training-target mode the final user message is held back as the answer.
        history = messages[:-1] if is_training_target else messages

        for msg in history:
            role, content = msg['role'], msg['content']
            assert role in {'user', 'assistant'}, f'role: "{role}"'
            _push(f'<|start_header_id|>{role}<|end_header_id|>\n\n', ContextType.OTHER, 0.)
            _push(content, ContextType.OTHER, 0.)
            _push('<|eot_id|>', ContextType.OTHER, 0.)

        # Open a user header: either the answer prefix (training) or the
        # generation prompt tail (inference).
        _push('<|start_header_id|>user<|end_header_id|>\n\n', ContextType.OTHER, 0.)

        if is_training_target:
            _push(messages[-1]['content'], ContextType.RESPONSE, 1.)
            _push('<|eot_id|>', ContextType.SUFFIX, 1.)

        # Number of trailing context entries that belong to the answer.
        answer_len = 2 if is_training_target else 0
        return contexts, loss_scales, answer_len
| 115 | + |
| 116 | + |
@dataclass
class UserLMTemplateMeta(EmptyTemplateMeta):
    """Template meta for the UserLM template.

    Reuses the llama3 agent template and a llama3-format system prefix;
    all remaining fields are inherited from EmptyTemplateMeta.
    """
    agent_template: str = 'llama3'
    # {{SYSTEM}} placeholder is substituted with the resolved system text at encode time.
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])
| 122 | + |
| 123 | + |
# Register UserLM: a llama3-style template that generates the user turn.
register_template(UserLMTemplateMeta(LLMTemplateType.userlm, template_cls=UserLMTemplate))
| 125 | + |
50 | 126 |
|
51 | 127 | def _get_llama3_2_prefix() -> Prompt: |
52 | 128 | now = dt.datetime.now() |
|
0 commit comments