-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathtest_model.py
More file actions
65 lines (59 loc) · 3.15 KB
/
test_model.py
File metadata and controls
65 lines (59 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import torch
import argparse
import json
from autosmoothquant.models import (Int8LlamaForCausalLM, Int8OPTForCausalLM,
Int8BaichuanForCausalLM, Int8MixtralForCausalLM,
Int8PhiForCausalLM,Int8Qwen2ForCausalLM)
from autosmoothquant.utils import parse_quant_config
from transformers import AutoTokenizer
def parse_args(argv=None):
    """Parse the command-line options for the int8 model smoke test.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, so existing callers that use
            ``parse_args()`` behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``model_path``,
        ``tokenizer_path``, ``model_class`` and ``prompt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-path', type=str, default='int8-models/llama-13b',
        help='path contains model weight and quant config')
    parser.add_argument(
        '--tokenizer-path', type=str, default='int8-models/llama-13b',
        help='path contains tokenizer')
    # Keep this list in sync with the model classes dispatched on in main().
    parser.add_argument(
        '--model-class', type=str, default='llama',
        help='currently support: llama, baichuan, opt, mixtral, phi2, qwen2')
    parser.add_argument(
        '--prompt', type=str, default='You are right, But Genshin Impact is',
        help='prompts')
    return parser.parse_args(argv)
@torch.no_grad()
def main():
    """Load an int8-quantized causal LM and print a short completion.

    Reads the CLI options, parses ``quant_config.json`` from the model
    directory, loads the tokenizer (using the EOS token as the pad token),
    instantiates the quantized model class selected by ``--model-class``,
    then generates up to 20 new tokens for the prompt on ``"cuda"`` and
    prints the decoded text.

    Raises:
        ValueError: if ``--model-class`` is not one of the supported names.
    """
    cli = parse_args()
    quant_config = parse_quant_config(
        os.path.join(cli.model_path, "quant_config.json"))

    tokenizer = AutoTokenizer.from_pretrained(
        cli.tokenizer_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Apart from the int8 GEMMs, quantized models currently run in
    # torch.float32. Calling torch.set_default_dtype(torch.float16) here
    # would speed things up, but may decrease model performance.

    # Map each accepted --model-class value to its quantized model class.
    model_classes = {
        "llama": Int8LlamaForCausalLM,
        "baichuan": Int8BaichuanForCausalLM,
        "opt": Int8OPTForCausalLM,
        "mixtral": Int8MixtralForCausalLM,
        "phi2": Int8PhiForCausalLM,
        "qwen2": Int8Qwen2ForCausalLM,
    }
    model_cls = model_classes.get(cli.model_class)
    if model_cls is None:
        raise ValueError(
            f"Model type {cli.model_class} are not supported for now.")

    # Every supported class is loaded with the same arguments.
    model = model_cls.from_pretrained(
        cli.model_path,
        quant_config,
        attn_implementation="eager",
        device_map="sequential",
    )

    encoded = tokenizer(
        cli.prompt,
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt",
    )
    encoded = encoded.to("cuda")

    generated_ids = model.generate(**encoded, max_new_tokens=20)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
# Run the smoke test only when this file is executed as a script.
if "__main__" == __name__:
    main()