Skip to content

Commit 66edeac

Browse files
author
李宜杰
committed
local eval add mindformers model
1 parent 2d794cf commit 66edeac

5 files changed

Lines changed: 379 additions & 9 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from ais_bench.benchmark.models import MindFormerModel
2+
3+
# Example configuration for evaluating a local model through MindFormerModel.
# Every path below is a placeholder and must be replaced with a real location.
models = [
    dict(
        attr="local",  # local or service
        # MindFormers-backed local model: the network is built from
        # `yaml_cfg_file` and weights are loaded from `checkpoint`.
        type=MindFormerModel,
        abbr='mindformer-model',
        path='THUDM/chatglm-6b',  # path to model dir, current value is just an example
        checkpoint='THUDM/your_checkpoint',  # path to checkpoint file, current value is just an example
        yaml_cfg_file='THUDM/your.yaml',  # MindFormers YAML config; MindFormerModel raises ValueError without it
        tokenizer_path='THUDM/chatglm-6b',  # path to tokenizer dir, current value is just an example
        # Model parameters, see huggingface.co/docs/transformers/v4.50.0/en/model_doc/auto#transformers.AutoModel.from_pretrained
        # NOTE(review): MindFormerModel.__init__ declares no `model_kwargs`
        # parameter, so this entry is absorbed by its **kwargs - confirm whether
        # it is still needed here.
        model_kwargs=dict(
            device_map='npu',
        ),
        # Tokenizer parameters, see huggingface.co/docs/transformers/v4.50.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase
        tokenizer_kwargs=dict(
            padding_side='right',
        ),
        # Generation (post-processing) parameters, see huggingface.co/docs/transformers/main_classes/test_generation
        generation_kwargs=dict(
            temperature=0.5,
            top_k=10,
            top_p=0.95,
            do_sample=True,
            seed=None,
            repetition_penalty=1.03,
        ),
        run_cfg=dict(num_gpus=1, num_procs=1),  # multi-device / multi-node settings; tasks are launched with torchrun
        max_out_len=100,  # maximum number of output tokens
        batch_size=2,  # batch size of each inference call
        max_seq_len=2048,
        batch_padding=True,
    )
]

ais_bench/benchmark/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
from ais_bench.benchmark.models.api_models.triton_api import TritonCustomAPIStream # noqa: F401
1515
from ais_bench.benchmark.models.api_models.tgi_api import TGICustomAPIStream # noqa: F401
1616
from ais_bench.benchmark.models.api_models.vllm_custom_api_chat import VllmMultiturnAPIChatStream # noqa: F401
17-
from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel
17+
from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel
18+
from ais_bench.benchmark.models.local_models.mindformers_model import MindFormerModel
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
import os, sys
2+
from typing import Dict, List, Optional, Union
3+
4+
import numpy as np
5+
import torch
6+
import transformers
7+
8+
from ais_bench.benchmark.models.base import BaseModel
9+
from ais_bench.benchmark.models.base_api import APITemplateParser
10+
from ais_bench.benchmark.registry import MODELS
11+
from ais_bench.benchmark.utils.logging import get_logger
12+
from ais_bench.benchmark.utils.prompt import PromptList
13+
14+
from mindspore import Tensor, Model
15+
from mindformers import MindFormerConfig, build_context
16+
from mindformers.models import build_network
17+
from mindformers.core.parallel_config import build_parallel_config
18+
from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert
19+
from mindformers.trainer.utils import transform_and_load_checkpoint
20+
21+
# Accepted prompt representations: a PromptList, a plain string, or a dict.
PromptType = Union[PromptList, str, dict]
22+
23+
24+
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Stop generation once every batch row has produced the stop string.

    The stop string may span several tokens, so each check decodes the last
    ``sequence_id_len`` tokens of every row and searches the decoded text.
    """

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        batch_size: int,
    ):
        # One flag per batch row; a row stays finished once it has matched.
        self.done_tracker = [False for _ in range(batch_size)]
        self.sequence = sequence
        # Token ids of the stop string (no special tokens); their count is
        # the size of the lookback window used in ``__call__``.
        stop_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_ids = stop_ids
        self.sequence_id_len = len(stop_ids)
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True when all rows have emitted the stop string."""
        # Only the trailing window can contain a newly completed stop string.
        tail_ids = input_ids[:, -self.sequence_id_len:]
        tail_texts = self.tokenizer.batch_decode(tail_ids)
        for row, finished in enumerate(self.done_tracker):
            if not finished:
                self.done_tracker[row] = self.sequence in tail_texts[row]
        return all(self.done_tracker)
49+
50+
51+
def drop_error_generation_kwargs(generation_kwargs: dict) -> dict:
    """Strip benchmark-only keys from ``generation_kwargs`` in place.

    The keys ``is_synthetic``, ``batch_size`` and ``do_performance`` are
    removed when present. The same dict object is returned.
    """
    unsupported_keys = ('is_synthetic', 'batch_size', 'do_performance')
    for name in unsupported_keys:
        generation_kwargs.pop(name, None)
    return generation_kwargs
56+
57+
58+
@MODELS.register_module()
class MindFormerModel(BaseModel):
    """Local evaluation model that runs generation on a MindFormers network.

    The tokenizer is loaded with ``transformers.AutoTokenizer``; the network
    is built from a MindFormers YAML config via ``build_network`` and its
    weights are loaded from a checkpoint.

    Args:
        path (str): Model directory. Forwarded to ``BaseModel``, written to
            ``config.model.pretrained_model_dir`` and used as the tokenizer
            location when ``tokenizer_path`` is not given.
        checkpoint (Optional[str]): Checkpoint path. When given it overrides
            ``load_checkpoint`` from the YAML config.
        yaml_cfg_file (Optional[str]): MindFormers YAML config file. Required.
        batch_size (int): Batch dimension of the dummy input used when the
            checkpoint is loaded through ``transform_and_load_checkpoint``.
        max_seq_len (int): Written to ``config.model.model_config.seq_length``
            and used as the tokenizer truncation length.
        tokenizer_path (Optional[str]): Tokenizer directory; defaults to
            ``path``.
        tokenizer_kwargs (dict): Overrides for the default tokenizer kwargs
            (``padding_side='left'``, ``truncation_side='left'``,
            ``trust_remote_code=True``).
        tokenizer_only (bool): If True, the network is not built or loaded.
        generation_kwargs (dict): Keyword arguments merged into every
            ``generate`` call; they take precedence over per-call kwargs.
        meta_template (Optional[Dict]): Forwarded to ``BaseModel``.
        extract_pred_after_decode (bool): If True, the prompt is removed from
            the decoded text by character count instead of by token count.
        batch_padding (bool): Stored on the instance; not read by this class.
        pad_token_id (Optional[int]): Explicit pad token id for the tokenizer.
        mode (str): ``'none'`` truncates via the tokenizer; ``'mid'`` keeps
            the head and tail of an over-long single prompt.
        use_fastchat_template (bool): Wrap each prompt in the fastchat
            ``vicuna`` conversation template.
        end_str (Optional[str]): Outputs are cut at its first occurrence.
        **kwargs: Accepted for config compatibility and ignored.

    Raises:
        ValueError: If ``mode`` is invalid or ``yaml_cfg_file`` is missing.
    """

    def __init__(self,
                 path: str,
                 checkpoint: Optional[str] = None,
                 yaml_cfg_file: Optional[str] = None,
                 batch_size: int = 1,
                 max_seq_len: int = 2048,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 generation_kwargs: dict = dict(),
                 meta_template: Optional[Dict] = None,
                 extract_pred_after_decode: bool = False,
                 batch_padding: bool = False,
                 pad_token_id: Optional[int] = None,
                 mode: str = 'none',
                 use_fastchat_template: bool = False,
                 end_str: Optional[str] = None,
                 **kwargs):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         tokenizer_only=tokenizer_only,
                         meta_template=meta_template)
        self.logger = get_logger()
        self.batch_size = batch_size
        self.pad_token_id = pad_token_id
        self.pretrained_model_path = path
        if mode not in ['none', 'mid']:
            raise ValueError(f"mode must be 'none' or 'mid', but got {mode}")
        self.mode = mode
        if not yaml_cfg_file:
            raise ValueError('`yaml_cfg_file` is required for MindFormerModel')
        self.config = MindFormerConfig(yaml_cfg_file)
        self.checkpoint = checkpoint
        self.batch_padding = batch_padding
        self.extract_pred_after_decode = extract_pred_after_decode
        # Copy so the shared default dict (and the caller's dict) can never be
        # modified through this instance.
        self.generation_kwargs = dict(generation_kwargs or {})
        self.use_fastchat_template = use_fastchat_template
        self.end_str = end_str
        self._load_tokenizer(path=path,
                             tokenizer_path=tokenizer_path,
                             tokenizer_kwargs=tokenizer_kwargs)
        if not tokenizer_only:
            self._load_model(self.config, self.batch_size, self.max_seq_len)

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        """Load the tokenizer and make sure it has a ``pad_token_id``.

        The pad token is resolved in this order: the explicit
        ``pad_token_id`` given to the constructor, the tokenizer's own value,
        the model's ``GenerationConfig``, and finally the tokenizer's
        ``eos_token_id``.

        Raises:
            ValueError: If no pad token id can be determined.
        """
        from transformers import AutoTokenizer, GenerationConfig

        DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', trust_remote_code=True)
        kwargs = DEFAULT_TOKENIZER_KWARGS.copy()
        kwargs.update(tokenizer_kwargs or {})

        load_path = tokenizer_path if tokenizer_path else path
        self.tokenizer = AutoTokenizer.from_pretrained(load_path, **kwargs)

        pad_token_id = self.pad_token_id

        # A patch for some models without pad_token_id
        if pad_token_id is not None:
            if self.tokenizer.pad_token_id is None:
                self.logger.debug(f'Using {pad_token_id} as pad_token_id')
            elif self.tokenizer.pad_token_id != pad_token_id:
                self.logger.warning(f'pad_token_id is not consistent. Using {pad_token_id} as pad_token_id')
            self.tokenizer.pad_token_id = pad_token_id
            return
        if self.tokenizer.pad_token_id is not None:
            return
        self.logger.warning('pad_token_id is not set for the tokenizer.')

        # Best effort: a missing or unreadable generation config is not an
        # error, the eos token is tried next.
        try:
            generation_config = GenerationConfig.from_pretrained(path)
        except Exception:
            generation_config = None

        if generation_config and generation_config.pad_token_id is not None:
            self.logger.warning(f'Using {generation_config.pad_token_id} as pad_token_id.')
            self.tokenizer.pad_token_id = generation_config.pad_token_id
            return
        if self.tokenizer.eos_token_id is not None:
            self.logger.warning(f'Using eos_token_id {self.tokenizer.eos_token_id} as pad_token_id.')
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            return
        raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _set_config_from_yaml(self):
        """Apply constructor arguments to the YAML config and build contexts."""
        if self.checkpoint is not None:
            self.config.load_checkpoint = self.checkpoint
        elif not self.config.load_checkpoint:
            # No checkpoint given and none (or an empty string) in the YAML:
            # fall back to the model directory.
            self.config.load_checkpoint = self.path
        self.config.model.pretrained_model_dir = self.pretrained_model_path
        self.config.model.model_config.seq_length = self.max_seq_len
        build_context(self.config)
        build_parallel_config(self.config)

    def _load_model(self, config, batch_size, max_seq_len):
        """Build the network described by ``config`` and load its weights.

        With ``run_mode == "predict"`` the weights are loaded directly via
        ``load_weights``; otherwise they go through
        ``transform_and_load_checkpoint`` using a dummy input of shape
        ``(batch_size, max_seq_len)``.

        Raises:
            ValueError: If building or loading raises ``ValueError``.
        """
        self._set_config_from_yaml()
        try:
            self.model = build_network(
                config.model,
                default_args={
                    "parallel_config": config.parallel_config,
                    "moe_config": config.moe_config
                })
            self.logger.info("..........Network Built Successfully..........")
            self.model.set_train(False)
            config.load_checkpoint = get_load_path_after_hf_convert(config, self.model)
            self.logger.info(f"load checkpoint path : {config.load_checkpoint}")
            run_mode = config.get("run_mode", None)
            if run_mode == "predict":
                self.model.load_weights(config.load_checkpoint)
            else:
                model = Model(self.model)
                input_ids = Tensor(np.ones((batch_size, max_seq_len), dtype=np.int32))
                infer_data = self.model.prepare_inputs_for_predict_layout(input_ids)
                transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True)

            self.logger.info("..........Checkpoint Load Successfully..........")
        except ValueError as e:
            raise ValueError('Failed to load MindFormers model, please check configuration') from e

    def generate(self,
                 inputs: List[str],
                 max_out_len: int,
                 min_out_len: Optional[int] = None,
                 stopping_criteria: List[str] = [],
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.
            min_out_len (Optional[int]): The minimum length of the output.
            stopping_criteria (List[str]): Stop strings. They are turned into
                ``MultiTokenEOSCriteria`` (together with the tokenizer's eos
                token) and also used to cut the decoded outputs.
            **kwargs: Extra generation arguments; entries of
                ``self.generation_kwargs`` override them.

        Returns:
            List[str]: A list of generated strings.

        Raises:
            ValueError: If the model returns a dict without ``'sequences'``.
            ModuleNotFoundError: If the fastchat template is requested but
                fastchat is not installed.
        """
        messages = list(inputs)
        if not messages:
            # Nothing to tokenize or generate.
            return []

        generation_kwargs = kwargs.copy()
        generation_kwargs.update(self.generation_kwargs)

        batch_size = len(messages)
        prompt_char_lens = None

        if self.extract_pred_after_decode:
            prompt_char_lens = [len(text) for text in messages]

        if self.use_fastchat_template:
            try:
                from fastchat.model import get_conversation_template
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    'Fastchat is not implemented. You can use '
                    "'pip install \"fschat[model_worker,webui]\"' "
                    'to implement fastchat.') from err
            for idx, text in enumerate(messages):
                conv = get_conversation_template('vicuna')
                conv.append_message(conv.roles[0], text)
                conv.append_message(conv.roles[1], None)
                messages[idx] = conv.get_prompt()

        if self.mode == 'mid':
            assert len(messages) == 1, "mode 'mid' only supports a single input per call"
            tokens = self.tokenizer(messages, padding=False, truncation=False, return_tensors='np')
            input_ids = tokens['input_ids']
            if input_ids.shape[-1] > self.max_seq_len:
                # Keep the head and the tail of the prompt and drop the
                # middle; the tail takes the extra token when max_seq_len is
                # odd, so exactly max_seq_len tokens remain.
                head_len = self.max_seq_len // 2
                tail_len = self.max_seq_len - head_len
                input_ids = np.concatenate(
                    [input_ids[:, :head_len], input_ids[:, -tail_len:]],
                    axis=-1)
            tokens = {'input_ids': input_ids}
        else:
            tokenize_kwargs = dict(
                padding=True,
                truncation=True,
                max_length=self.max_seq_len,
                return_tensors='np'
            )
            tokens = self.tokenizer(messages, **tokenize_kwargs)

        input_ids = tokens['input_ids']
        if len(messages) > 1:
            # In a padded batch the real prompt length of each row is the
            # number of attended positions.
            attention_mask = tokens.get('attention_mask')
            prompt_token_lens = (
                attention_mask.sum(axis=1).astype(int).tolist()
                if attention_mask is not None else
                [input_ids.shape[1]] * batch_size
            )
        else:
            prompt_token_lens = [len(ids) for ids in input_ids]

        input_ids_tensor = Tensor(input_ids)

        if min_out_len is not None:
            generation_kwargs['min_new_tokens'] = min_out_len
        generation_kwargs['max_new_tokens'] = max_out_len
        generation_kwargs.setdefault('top_k', 1)
        generation_kwargs.setdefault('return_dict_in_generate', False)

        origin_stopping_criteria = list(stopping_criteria)
        if stopping_criteria:
            if self.tokenizer.eos_token is not None:
                stopping_criteria = stopping_criteria + [
                    self.tokenizer.eos_token
                ]
            stopping_list = transformers.StoppingCriteriaList([
                MultiTokenEOSCriteria(sequence, self.tokenizer,
                                      input_ids_tensor.shape[0])
                for sequence in stopping_criteria
            ])
            generation_kwargs['stopping_criteria'] = stopping_list

        generation_kwargs = drop_error_generation_kwargs(generation_kwargs)

        outputs = self.model.generate(input_ids=input_ids_tensor,
                                      **generation_kwargs)

        if isinstance(outputs, dict):
            # Look the key up without a fallback so a missing entry is
            # reported instead of iterating over the dict's keys below.
            sequences_out = outputs.get('sequences')
            if sequences_out is None:
                raise ValueError("Model output dictionary is missing 'sequences' key.")
            outputs = sequences_out

        # Accept both array-like rows and plain Python sequences.
        sequences = [
            seq.tolist() if hasattr(seq, 'tolist') else list(seq)
            for seq in outputs
        ]

        if not self.extract_pred_after_decode:
            # Drop the prompt tokens so only newly generated tokens are decoded.
            sequences = [
                seq[prompt_len:]
                for seq, prompt_len in zip(sequences, prompt_token_lens)
            ]

        decodeds = [
            self.tokenizer.decode(seq, skip_special_tokens=True)
            for seq in sequences
        ]

        if self.extract_pred_after_decode and prompt_char_lens is not None:
            # Remove the prompt by character count; the lengths were measured
            # on the raw inputs before any chat template was applied.
            decodeds = [
                text[length:]
                for text, length in zip(decodeds, prompt_char_lens)
            ]

        if self.end_str:
            decodeds = [text.split(self.end_str)[0] for text in decodeds]
        if origin_stopping_criteria:
            for token in origin_stopping_criteria:
                decodeds = [text.split(token)[0] for text in decodeds]
        return decodeds

0 commit comments

Comments
 (0)