Skip to content

Commit 42501e0

Browse files
author
李宜杰
committed
local eval add mindformers model
1 parent 2d794cf commit 42501e0

5 files changed

Lines changed: 586 additions & 9 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from ais_bench.benchmark.models import MindFormerModel
2+
3+
# Model configuration list: a single entry describing a locally run MindFormerModel.
models = [
4+
dict(
5+
attr="local", # local or service
6+
type=MindFormerModel, # NOTE(review): the original comment read "use this for transformers < 4.33.0; tries AutoModelForCausalLM.from_pretrained first and falls back to AutoModel.from_pretrained" — this looks copied from a HuggingFace config; confirm it applies to MindFormerModel
7+
abbr='mindformer-model',
8+
path='THUDM/chatglm-6b', # path to model dir; the current value is just an example
9+
checkpoint = 'THUDM/your_checkpoint', # path to checkpoint file; the current value is just an example
10+
yaml_cfg_file = 'THUDM/your.yaml', # presumably the path to the model's YAML config file; placeholder value — TODO confirm
11+
tokenizer_path='THUDM/chatglm-6b', # path to tokenizer dir; the current value is just an example
12+
model_kwargs=dict( # for model parameters see huggingface.co/docs/transformers/v4.50.0/en/model_doc/auto#transformers.AutoModel.from_pretrained
13+
device_map='npu',
14+
),
15+
tokenizer_kwargs=dict( # for tokenizer parameters see huggingface.co/docs/transformers/v4.50.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase
16+
padding_side='right',
17+
),
18+
generation_kwargs = dict( # for generation (post-processing) parameters see huggingface.co/docs/transformers/main_classes/text_generation
19+
temperature = 0.5,
20+
top_k = 10,
21+
top_p = 0.95,
22+
do_sample = True,
23+
seed = None,
24+
repetition_penalty = 1.03,
25+
),
26+
run_cfg = dict(num_gpus=1, num_procs=1), # multi-card / multi-node multi-card parameters; per the original comment the task is launched with torchrun — NOTE(review): confirm this holds for MindFormerModel
27+
max_out_len=100, # maximum number of output tokens
28+
batch_size=2, # batch size for each inference call
29+
max_seq_len=2048,
30+
batch_padding=True,
31+
)
32+
]

ais_bench/benchmark/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
from ais_bench.benchmark.models.api_models.triton_api import TritonCustomAPIStream # noqa: F401
1515
from ais_bench.benchmark.models.api_models.tgi_api import TGICustomAPIStream # noqa: F401
1616
from ais_bench.benchmark.models.api_models.vllm_custom_api_chat import VllmMultiturnAPIChatStream # noqa: F401
17-
from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel
17+
from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel
18+
from ais_bench.benchmark.models.local_models.mindformers_model import MindFormerModel

0 commit comments

Comments
 (0)