|
| 1 | +import os, sys |
| 2 | +from typing import Dict, List, Optional, Union |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +import transformers |
| 7 | + |
| 8 | +from ais_bench.benchmark.models.base import BaseModel |
| 9 | +from ais_bench.benchmark.models.base_api import APITemplateParser |
| 10 | +from ais_bench.benchmark.registry import MODELS |
| 11 | +from ais_bench.benchmark.utils.logging import get_logger |
| 12 | +from ais_bench.benchmark.utils.prompt import PromptList |
| 13 | + |
| 14 | +from mindspore import Tensor, Model |
| 15 | +from mindformers import MindFormerConfig, build_context |
| 16 | +from mindformers.models import build_network |
| 17 | +from mindformers.core.parallel_config import build_parallel_config |
| 18 | +from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert |
| 19 | +from mindformers.trainer.utils import transform_and_load_checkpoint |
| 20 | + |
| 21 | +PromptType = Union[PromptList, str, dict] |
| 22 | + |
| 23 | + |
| 24 | +class MultiTokenEOSCriteria(transformers.StoppingCriteria): |
| 25 | + """Criteria to stop on the specified multi-token sequence.""" |
| 26 | + |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + sequence: str, |
| 30 | + tokenizer: transformers.PreTrainedTokenizer, |
| 31 | + batch_size: int, |
| 32 | + ): |
| 33 | + self.done_tracker = [False] * batch_size |
| 34 | + self.sequence = sequence |
| 35 | + self.sequence_ids = tokenizer.encode(sequence, |
| 36 | + add_special_tokens=False) |
| 37 | + self.sequence_id_len = len(self.sequence_ids) |
| 38 | + self.tokenizer = tokenizer |
| 39 | + |
| 40 | + def __call__(self, input_ids, scores, **kwargs) -> bool: |
| 41 | + # compare the last len(stop) tokens |
| 42 | + lookback_ids_batch = input_ids[:, -self.sequence_id_len:] |
| 43 | + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) |
| 44 | + for i, done in enumerate(self.done_tracker): |
| 45 | + if done: |
| 46 | + continue |
| 47 | + self.done_tracker[i] = self.sequence in lookback_tokens_batch[i] |
| 48 | + return False not in self.done_tracker |
| 49 | + |
| 50 | + |
| 51 | +def drop_error_generation_kwargs(generation_kwargs: dict) -> dict: |
| 52 | + for key in ['is_synthetic', 'batch_size', 'do_performance']: |
| 53 | + if key in generation_kwargs: |
| 54 | + generation_kwargs.pop(key) |
| 55 | + return generation_kwargs |
| 56 | + |
| 57 | + |
| 58 | +@MODELS.register_module() |
| 59 | +class MindFormerModel(BaseModel): |
| 60 | + |
| 61 | + def __init__(self, |
| 62 | + path: str, |
| 63 | + checkpoint: Optional[str] = None, |
| 64 | + yaml_cfg_file: Optional[str] = None, |
| 65 | + batch_size: int = 1, |
| 66 | + max_seq_len: int = 2048, |
| 67 | + tokenizer_path: Optional[str] = None, |
| 68 | + tokenizer_kwargs: dict = dict(), |
| 69 | + tokenizer_only: bool = False, |
| 70 | + generation_kwargs: dict = dict(), |
| 71 | + meta_template: Optional[Dict] = None, |
| 72 | + extract_pred_after_decode: bool = False, |
| 73 | + batch_padding: bool = False, |
| 74 | + pad_token_id: Optional[int] = None, |
| 75 | + mode: str = 'none', |
| 76 | + use_fastchat_template: bool = False, |
| 77 | + end_str: Optional[str] = None, |
| 78 | + **kwargs): |
| 79 | + super().__init__(path=path, |
| 80 | + max_seq_len=max_seq_len, |
| 81 | + tokenizer_only=tokenizer_only, |
| 82 | + meta_template=meta_template) |
| 83 | + self.logger = get_logger() |
| 84 | + self.batch_size = batch_size |
| 85 | + self.pad_token_id = pad_token_id |
| 86 | + self.pretrained_model_path = path |
| 87 | + if mode not in ['none', 'mid']: |
| 88 | + raise ValueError(f"mode must be 'none' or 'mid', but got {mode}") |
| 89 | + self.mode = mode |
| 90 | + if not yaml_cfg_file: |
| 91 | + raise ValueError('`yaml_cfg_file` is required for MindFormerModel') |
| 92 | + self.config = MindFormerConfig(yaml_cfg_file) |
| 93 | + self.checkpoint = checkpoint |
| 94 | + self._load_tokenizer(path=path, |
| 95 | + tokenizer_path=tokenizer_path, |
| 96 | + tokenizer_kwargs=tokenizer_kwargs) |
| 97 | + self.batch_padding = batch_padding |
| 98 | + self.extract_pred_after_decode = extract_pred_after_decode |
| 99 | + if not tokenizer_only: |
| 100 | + self._load_model(self.config, self.batch_size, self.max_seq_len) |
| 101 | + self.generation_kwargs = generation_kwargs |
| 102 | + self.use_fastchat_template = use_fastchat_template |
| 103 | + self.end_str = end_str |
| 104 | + |
| 105 | + def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], |
| 106 | + tokenizer_kwargs: dict): |
| 107 | + from transformers import AutoTokenizer, GenerationConfig |
| 108 | + |
| 109 | + DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', trust_remote_code=True) |
| 110 | + kwargs = DEFAULT_TOKENIZER_KWARGS.copy() |
| 111 | + kwargs.update(tokenizer_kwargs) |
| 112 | + |
| 113 | + load_path = tokenizer_path if tokenizer_path else path |
| 114 | + self.tokenizer = AutoTokenizer.from_pretrained(load_path, **kwargs) |
| 115 | + |
| 116 | + pad_token_id = self.pad_token_id |
| 117 | + |
| 118 | + # A patch for some models without pad_token_id |
| 119 | + if pad_token_id is not None: |
| 120 | + if self.tokenizer.pad_token_id is None: |
| 121 | + self.logger.debug(f'Using {pad_token_id} as pad_token_id') |
| 122 | + elif self.tokenizer.pad_token_id != pad_token_id: |
| 123 | + self.logger.warning(f'pad_token_id is not consistent. Using {pad_token_id} as pad_token_id') |
| 124 | + self.tokenizer.pad_token_id = pad_token_id |
| 125 | + return |
| 126 | + if self.tokenizer.pad_token_id is not None: |
| 127 | + return |
| 128 | + self.logger.warning('pad_token_id is not set for the tokenizer.') |
| 129 | + |
| 130 | + try: |
| 131 | + generation_config = GenerationConfig.from_pretrained(path) |
| 132 | + except Exception: |
| 133 | + generation_config = None |
| 134 | + |
| 135 | + if generation_config and generation_config.pad_token_id is not None: |
| 136 | + self.logger.warning(f'Using {generation_config.pad_token_id} as pad_token_id.') |
| 137 | + self.tokenizer.pad_token_id = generation_config.pad_token_id |
| 138 | + return |
| 139 | + if self.tokenizer.eos_token_id is not None: |
| 140 | + self.logger.warning(f'Using eos_token_id {self.tokenizer.eos_token_id} as pad_token_id.') |
| 141 | + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id |
| 142 | + return |
| 143 | + raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.') |
| 144 | + |
| 145 | + def _set_config_from_yaml(self): |
| 146 | + if self.checkpoint is not None: |
| 147 | + self.config.load_checkpoint = self.checkpoint |
| 148 | + elif self.checkpoint is None and self.config.load_checkpoint is None: |
| 149 | + self.config.load_checkpoint = self.path |
| 150 | + self.config.model.pretrained_model_dir = self.pretrained_model_path |
| 151 | + self.config.model.model_config.seq_length = self.max_seq_len |
| 152 | + build_context(self.config) |
| 153 | + build_parallel_config(self.config) |
| 154 | + |
| 155 | + def _load_model(self, config, batch_size, max_seq_len): |
| 156 | + |
| 157 | + self._set_config_from_yaml() |
| 158 | + try: |
| 159 | + self.model = build_network( |
| 160 | + config.model, |
| 161 | + default_args={ |
| 162 | + "parallel_config": config.parallel_config, |
| 163 | + "moe_config": config.moe_config |
| 164 | + }) |
| 165 | + self.logger.info("..........Network Built Successfully..........") |
| 166 | + self.model.set_train(False) |
| 167 | + config.load_checkpoint = get_load_path_after_hf_convert(config, self.model) |
| 168 | + self.logger.info(f"load checkpoint path : {config.load_checkpoint}") |
| 169 | + run_mode = config.get("run_mode", None) |
| 170 | + if run_mode == "predict": |
| 171 | + self.model.load_weights(config.load_checkpoint) |
| 172 | + else: |
| 173 | + model = Model(self.model) |
| 174 | + input_ids = Tensor(np.ones((batch_size, max_seq_len), dtype=np.int32)) |
| 175 | + infer_data = self.model.prepare_inputs_for_predict_layout(input_ids) |
| 176 | + transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True) |
| 177 | + |
| 178 | + self.logger.info("..........Checkpoint Load Successfully..........") |
| 179 | + except ValueError as e: |
| 180 | + raise ValueError('Failed to load MindFormers model, please check configuration') from e |
| 181 | + |
| 182 | + |
| 183 | + def generate(self, |
| 184 | + inputs: List[str], |
| 185 | + max_out_len: int, |
| 186 | + min_out_len: Optional[int] = None, |
| 187 | + stopping_criteria: List[str] = [], |
| 188 | + **kwargs) -> List[str]: |
| 189 | + """Generate results given a list of inputs. |
| 190 | +
|
| 191 | + Args: |
| 192 | + inputs (List[str]): A list of strings. |
| 193 | + max_out_len (int): The maximum length of the output. |
| 194 | + min_out_len (Optional[int]): The minimum length of the output. |
| 195 | +
|
| 196 | + Returns: |
| 197 | + List[str]: A list of generated strings. |
| 198 | + """ |
| 199 | + generation_kwargs = kwargs.copy() |
| 200 | + generation_kwargs.update(self.generation_kwargs) |
| 201 | + |
| 202 | + messages = list(inputs) |
| 203 | + batch_size = len(messages) |
| 204 | + prompt_char_lens = None |
| 205 | + |
| 206 | + if self.extract_pred_after_decode: |
| 207 | + prompt_char_lens = [len(text) for text in messages] |
| 208 | + |
| 209 | + if self.use_fastchat_template: |
| 210 | + try: |
| 211 | + from fastchat.model import get_conversation_template |
| 212 | + except ModuleNotFoundError: |
| 213 | + raise ModuleNotFoundError( |
| 214 | + 'Fastchat is not implemented. You can use ' |
| 215 | + "'pip install \"fschat[model_worker,webui]\"' " |
| 216 | + 'to implement fastchat.') |
| 217 | + for idx, text in enumerate(messages): |
| 218 | + conv = get_conversation_template('vicuna') |
| 219 | + conv.append_message(conv.roles[0], text) |
| 220 | + conv.append_message(conv.roles[1], None) |
| 221 | + messages[idx] = conv.get_prompt() |
| 222 | + if self.mode == 'mid': |
| 223 | + assert len(messages) == 1 |
| 224 | + tokens = self.tokenizer(messages, padding=False, truncation=False, return_tensors='np') |
| 225 | + input_ids = tokens['input_ids'] |
| 226 | + if input_ids.shape[-1] > self.max_seq_len: |
| 227 | + input_ids = np.concatenate([input_ids[:, : self.max_seq_len // 2], input_ids[:, - self.max_seq_len // 2:]], axis=-1) |
| 228 | + tokens = {'input_ids': input_ids} |
| 229 | + else: |
| 230 | + tokenize_kwargs = dict( |
| 231 | + padding=True, |
| 232 | + truncation=True, |
| 233 | + max_length=self.max_seq_len, |
| 234 | + return_tensors='np' |
| 235 | + ) |
| 236 | + tokens = self.tokenizer(messages, **tokenize_kwargs) |
| 237 | + |
| 238 | + input_ids = tokens['input_ids'] |
| 239 | + if len(messages) > 1: |
| 240 | + attention_mask = tokens.get('attention_mask') |
| 241 | + prompt_token_lens = ( |
| 242 | + attention_mask.sum(axis=1).astype(int).tolist() |
| 243 | + if attention_mask is not None else |
| 244 | + [input_ids.shape[1]] * batch_size |
| 245 | + ) |
| 246 | + else: |
| 247 | + prompt_token_lens = [len(ids) for ids in input_ids] |
| 248 | + |
| 249 | + input_ids_tensor = Tensor(input_ids) |
| 250 | + |
| 251 | + if min_out_len is not None: |
| 252 | + generation_kwargs['min_new_tokens'] = min_out_len |
| 253 | + generation_kwargs['max_new_tokens'] = max_out_len |
| 254 | + generation_kwargs.setdefault('top_k', 1) |
| 255 | + generation_kwargs.setdefault('return_dict_in_generate', False) |
| 256 | + |
| 257 | + origin_stopping_criteria = list(stopping_criteria) |
| 258 | + if stopping_criteria: |
| 259 | + if self.tokenizer.eos_token is not None: |
| 260 | + stopping_criteria = stopping_criteria + [ |
| 261 | + self.tokenizer.eos_token |
| 262 | + ] |
| 263 | + stopping_list = transformers.StoppingCriteriaList([ |
| 264 | + *[ |
| 265 | + MultiTokenEOSCriteria(sequence, self.tokenizer, |
| 266 | + input_ids_tensor.shape[0]) |
| 267 | + for sequence in stopping_criteria |
| 268 | + ], |
| 269 | + ]) |
| 270 | + generation_kwargs['stopping_criteria'] = stopping_list |
| 271 | + |
| 272 | + generation_kwargs = drop_error_generation_kwargs(generation_kwargs) |
| 273 | + |
| 274 | + outputs = self.model.generate(input_ids=input_ids_tensor, |
| 275 | + **generation_kwargs) |
| 276 | + |
| 277 | + if isinstance(outputs, dict): |
| 278 | + outputs = outputs.get('sequences', outputs) |
| 279 | + if outputs is None: |
| 280 | + raise ValueError("Model output dictionary is missing 'sequence' key.") |
| 281 | + |
| 282 | + sequences = [seq.tolist() for seq in outputs] |
| 283 | + |
| 284 | + if not self.extract_pred_after_decode: |
| 285 | + sequences = [ |
| 286 | + seq[prompt_len:] |
| 287 | + for seq, prompt_len in zip(sequences, prompt_token_lens) |
| 288 | + ] |
| 289 | + |
| 290 | + decodeds = [ |
| 291 | + self.tokenizer.decode(seq, skip_special_tokens=True) |
| 292 | + for seq in sequences |
| 293 | + ] |
| 294 | + |
| 295 | + if self.extract_pred_after_decode and prompt_char_lens is not None: |
| 296 | + decodeds = [ |
| 297 | + text[length:] |
| 298 | + for text, length in zip(decodeds, prompt_char_lens) |
| 299 | + ] |
| 300 | + |
| 301 | + if self.end_str: |
| 302 | + decodeds = [text.split(self.end_str)[0] for text in decodeds] |
| 303 | + if origin_stopping_criteria: |
| 304 | + for token in origin_stopping_criteria: |
| 305 | + decodeds = [text.split(token)[0] for text in decodeds] |
| 306 | + return decodeds |
0 commit comments