-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathinterface_example.py
More file actions
154 lines (126 loc) · 4.85 KB
/
Copy pathinterface_example.py
File metadata and controls
154 lines (126 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright (c) EfficientMoE.
# SPDX-License-Identifier: Apache-2.0
# EfficientMoE Team
import argparse
import multiprocessing as mp
import os
import time
import warnings
from functools import partial
warnings.filterwarnings("ignore")
import datasets
import torch
from transformers import AutoTokenizer, LlamaTokenizerFast, TextStreamer
from moe_infinity import MoE
class StopWatch(TextStreamer):
def __init__(self, engine, *args, **kwargs):
super().__init__(*args, **kwargs)
self.start_prefilling = None
self.prefilling_time = None
self.start_decoding = None
self.decoding_time = None
self.decoding_iterations = 0
self.engine = engine
def put(self, value):
if self.start_prefilling is None:
self.start_prefilling = time.time()
return
elif self.prefilling_time is None:
self.prefilling_time = time.time() - self.start_prefilling
self.engine.expert_dispatcher.clear_expert_cache_counts()
self.start_decoding = time.time()
self.decoding_iterations += 1
if self.decoding_iterations % 100 == 0:
current_time = time.time()
print(f"Prefilling time: {self.prefilling_time} seconds")
print(f"Decoding time: {self.decoding_time} seconds")
print(f"Decoding iterations: {self.decoding_iterations}")
print(
f"Decoding time per iteration: {(current_time - self.start_decoding) / self.decoding_iterations} seconds"
)
return super().put(value)
def end(self):
if self.decoding_time is None and self.start_decoding is not None:
self.decoding_time = time.time() - self.start_decoding
return super().end()
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, required=True)
parser.add_argument("--offload_dir", type=str, required=True)
parser.add_argument("--device_memory_ratio", type=float, default=0.85)
parser.add_argument("--out_len", type=int, default=32)
args = parser.parse_args()
model_name = args.model_name_or_path.split("/")[-1]
config = {
"offload_path": os.path.join(args.offload_dir, model_name),
"device_memory_ratio": args.device_memory_ratio,
}
model = MoE(args.model_name_or_path, config)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path, trust_remote_code=True, use_fast=False
)
dataset = datasets.load_dataset("openai/gsm8k", "main", split="test")
all_inputs = dataset["question"]
# dataset_name = "openai/gsm8k"
# names = datasets.get_dataset_config_names(dataset_name)
# pool = mp.Pool(mp.cpu_count())
# all_inputs = [None] * len(names)
# all_inputs = pool.map(partial(datasets.load_dataset, dataset_name), names)
# print(all_inputs)
# text_list = []
# for dataset in all_inputs:
# if "test" not in dataset:
# continue
# for i, text in enumerate(dataset["test"]["question"]):
# text_list.append(text)
# print(len(text_list))
# all_inputs = text_list
custom_kwargs = {}
if "nllb" in args.model_name_or_path.lower():
custom_kwargs = {"forced_bos_token_id": 256057} # translate to French
elif "mixtral" in args.model_name_or_path.lower():
custom_kwargs = {"pad_token_id": tokenizer.eos_token_id}
elif (
"deepseek" in args.model_name_or_path.lower()
or "qwen3" in args.model_name_or_path.lower()
):
custom_kwargs = {"pad_token_id": tokenizer.eos_token_id}
else:
raise ValueError(f"Model {args.model_name_or_path} not supported")
tokenizer.pad_token = tokenizer.eos_token
cnt = 0
max_seq_length = 512
for input_text in all_inputs:
prompt = tokenizer.apply_chat_template(
conversation=[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": input_text,
},
],
tokenize=False,
add_generation_prompt=True,
)
print(f"prompt: {prompt}")
token_ids = tokenizer.encode(prompt, return_tensors="pt")
token_ids = token_ids.to("cuda:0")
streamer = StopWatch(model.engine, tokenizer)
with torch.no_grad():
print("outputs_text ...")
outputs = model.generate(
token_ids.to("cuda:0"),
# attention_mask=inputs.attention_mask.to("cuda:0"),
streamer=streamer,
max_new_tokens=args.out_len,
min_new_tokens=args.out_len,
do_sample=False,
# use_cache=False,
**custom_kwargs,
)
print(f"Prefilling time: {streamer.prefilling_time} seconds")
print(f"Decoding time: {streamer.decoding_time} seconds")
print(f"Decoding iterations: {streamer.decoding_iterations}")
print(
f"Decoding time per iteration: {streamer.decoding_time / streamer.decoding_iterations} seconds"
)
# print(f"Input tokens: {len(inputs.input_ids[0])}")