|
3 | 3 | from infinilm.modeling_utils import load_model_state_dict_by_file |
4 | 4 | from infinilm.distributed import DistConfig |
5 | 5 | from infinilm.infer_engine import GenerationConfig, InferEngine |
| 6 | +from infinilm.base_config import BaseConfig |
6 | 7 | from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig |
7 | 8 | import argparse |
8 | 9 | import sys |
@@ -125,150 +126,6 @@ def get_test_cases( |
125 | 126 |
|
126 | 127 | return case_dict |
127 | 128 |
|
128 | | - |
129 | | -def get_args(): |
130 | | - parser = argparse.ArgumentParser(description="run Llama args") |
131 | | - |
132 | | - parser.add_argument( |
133 | | - "--cpu", |
134 | | - action="store_true", |
135 | | - help="Run cpu test", |
136 | | - ) |
137 | | - parser.add_argument( |
138 | | - "--nvidia", |
139 | | - action="store_true", |
140 | | - help="Run nvidia test", |
141 | | - ) |
142 | | - parser.add_argument( |
143 | | - "--qy", |
144 | | - action="store_true", |
145 | | - help="Run qy test", |
146 | | - ) |
147 | | - parser.add_argument( |
148 | | - "--metax", |
149 | | - action="store_true", |
150 | | - help="Run metax test", |
151 | | - ) |
152 | | - parser.add_argument( |
153 | | - "--moore", |
154 | | - action="store_true", |
155 | | - help="Run moore test", |
156 | | - ) |
157 | | - parser.add_argument( |
158 | | - "--iluvatar", |
159 | | - action="store_true", |
160 | | - help="Run iluvatar test", |
161 | | - ) |
162 | | - parser.add_argument( |
163 | | - "--cambricon", |
164 | | - action="store_true", |
165 | | - help="Run cambricon test", |
166 | | - ) |
167 | | - parser.add_argument( |
168 | | - "--ali", |
169 | | - action="store_true", |
170 | | - help="Run alippu test", |
171 | | - ) |
172 | | - parser.add_argument( |
173 | | - "--hygon", |
174 | | - action="store_true", |
175 | | - help="Run hygon test", |
176 | | - ) |
177 | | - parser.add_argument( |
178 | | - "--model", |
179 | | - type=str, |
180 | | - required=True, |
181 | | - help="model path", |
182 | | - ) |
183 | | - parser.add_argument( |
184 | | - "--batch-size", |
185 | | - type=parse_list, |
186 | | - default=1, |
187 | | - help="number of prompts in a batch (can be an int or a list of ints, e.g., '1' or '[1,2,4]' or '1,2,4')", |
188 | | - ) |
189 | | - parser.add_argument( |
190 | | - "--tensor-parallel-size", |
191 | | - "--tp", |
192 | | - type=int, |
193 | | - default=1, |
194 | | - help="total rank for tensor parallel", |
195 | | - ) |
196 | | - parser.add_argument( |
197 | | - "--input-len", |
198 | | - type=parse_list, |
199 | | - default=10, |
200 | | - help="output tokens", |
201 | | - ) |
202 | | - |
203 | | - parser.add_argument( |
204 | | - "--output-len", |
205 | | - type=parse_list, |
206 | | - default=20, |
207 | | - help="output tokens", |
208 | | - ) |
209 | | - parser.add_argument( |
210 | | - "--skip-load", |
211 | | - action="store_true", |
212 | | - help="skip loading model weights", |
213 | | - ) |
214 | | - parser.add_argument( |
215 | | - "--top-k", |
216 | | - type=int, |
217 | | - default=1, |
218 | | - help="top k sampling", |
219 | | - ) |
220 | | - |
221 | | - parser.add_argument( |
222 | | - "--top-p", |
223 | | - type=float, |
224 | | - default=1.0, |
225 | | - help="top p sampling", |
226 | | - ) |
227 | | - |
228 | | - parser.add_argument( |
229 | | - "--temperature", |
230 | | - type=float, |
231 | | - default=1.0, |
232 | | - help="sampling temperature", |
233 | | - ) |
234 | | - parser.add_argument( |
235 | | - "--enable-paged-attn", |
236 | | - action="store_true", |
237 | | - help="use paged cache", |
238 | | - ) |
239 | | - parser.add_argument( |
240 | | - "--paged-kv-block-size", |
241 | | - type=int, |
242 | | - default=256, |
243 | | - help="num tokens each kv block can hold", |
244 | | - ) |
245 | | - parser.add_argument( |
246 | | - "--enable-graph", |
247 | | - action="store_true", |
248 | | - help="enable graph compiling", |
249 | | - ) |
250 | | - parser.add_argument( |
251 | | - "--warmup", |
252 | | - action="store_true", |
253 | | - help="Perform a warmup run before benchmarking/inference.", |
254 | | - ) |
255 | | - parser.add_argument( |
256 | | - "--attn", |
257 | | - type=str, |
258 | | - default="default", |
259 | | - choices=["default", "paged-attn", "flash-attn"], |
260 | | - help="attention backend to use: 'default' or 'flash-attn'", |
261 | | - ) |
262 | | - parser.add_argument( |
263 | | - "--kv-cache-dtype", |
264 | | - type=str, |
265 | | - default=None, |
266 | | - choices=["int8"], |
267 | | - ) |
268 | | - |
269 | | - return parser.parse_args() |
270 | | - |
271 | | - |
272 | 129 | with open("examples/bench_prompt.md", "r") as f: |
273 | 130 | prompt = f.read() |
274 | 131 |
|
@@ -305,7 +162,7 @@ def __init__( |
305 | 162 | cache_config=cache_config, |
306 | 163 | enable_graph_compiling=enable_graph, |
307 | 164 | attention_backend=attn_backend, |
308 | | - kv_cache_dtype=args.kv_cache_dtype, |
| 165 | + kv_cache_dtype=cfg.kv_cache_dtype, |
309 | 166 | ) |
310 | 167 |
|
311 | 168 | # ---------------------------------------------------------------------------- # |
@@ -396,52 +253,28 @@ def run( |
396 | 253 |
|
397 | 254 |
|
398 | 255 | if __name__ == "__main__": |
399 | | - args = get_args() |
400 | | - print(args) |
401 | | - |
402 | | - # Parse command line arguments |
403 | | - device_str = "cpu" |
404 | | - if args.cpu: |
405 | | - device_str = "cpu" |
406 | | - elif args.nvidia: |
407 | | - device_str = "cuda" |
408 | | - elif args.qy: |
409 | | - device_str = "cuda" |
410 | | - elif args.metax: |
411 | | - device_str = "cuda" |
412 | | - elif args.moore: |
413 | | - device_str = "musa" |
414 | | - elif args.iluvatar: |
415 | | - device_str = "cuda" |
416 | | - elif args.cambricon: |
417 | | - device_str = "mlu" |
418 | | - elif args.ali: |
419 | | - device_str = "cuda" |
420 | | - elif args.hygon: |
421 | | - device_str = "cuda" |
422 | | - else: |
423 | | - print( |
424 | | - "python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50" |
425 | | - ) |
426 | | - sys.exit(1) |
427 | | - _PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size |
| 256 | + cfg = BaseConfig() |
| 257 | + |
| 258 | + device_str = cfg.get_device_str(cfg.device) |
| 259 | + |
| 260 | + _PAGED_KV_BLOCK_SIZE = cfg.paged_kv_block_size |
428 | 261 | # -------------------------------------------------------- # |
429 | 262 | # Parse arguments |
430 | 263 | # -------------------------------------------------------- # |
431 | | - model_path = args.model |
| 264 | + model_path = cfg.model |
432 | 265 |
|
433 | 266 | infini_device = infinicore.device(device_str, 0) |
434 | 267 |
|
435 | | - tp = args.tensor_parallel_size |
| 268 | + tp = cfg.tp |
436 | 269 |
|
437 | | - skip_load = args.skip_load |
| 270 | + skip_load = cfg.skip_load |
438 | 271 |
|
439 | | - batch_size = args.batch_size |
440 | | - input_len = args.input_len |
441 | | - output_len = args.output_len |
442 | | - enable_paged_attn = args.enable_paged_attn |
443 | | - enable_graph = args.enable_graph |
444 | | - attn_backend = args.attn |
| 272 | + batch_size = cfg.batch_size |
| 273 | + input_len = cfg.input_len |
| 274 | + output_len = cfg.output_len |
| 275 | + enable_paged_attn = cfg.enable_paged_attn |
| 276 | + enable_graph = cfg.enable_graph |
| 277 | + attn_backend = cfg.attn |
445 | 278 |
|
446 | 279 | if isinstance(batch_size, int): |
447 | 280 | batch_size = [batch_size] |
@@ -488,7 +321,7 @@ def run( |
488 | 321 | # ---------------------------------------------------------------------------- # |
489 | 322 | # Warmup |
490 | 323 | # ---------------------------------------------------------------------------- # |
491 | | - if args.warmup: |
| 324 | + if cfg.warmup: |
492 | 325 | warmup_steps = 1 |
493 | 326 |
|
494 | 327 | # warmup cache capacity |
@@ -518,9 +351,9 @@ def run( |
518 | 351 | input_ids_infini, |
519 | 352 | GenerationConfig( |
520 | 353 | max_new_tokens=5, # decode kernel warmup |
521 | | - temperature=args.temperature, |
522 | | - top_k=args.top_k, |
523 | | - top_p=args.top_p, |
| 354 | + temperature=cfg.temperature, |
| 355 | + top_k=cfg.top_k, |
| 356 | + top_p=cfg.top_p, |
524 | 357 | stop_on_eos=False, |
525 | 358 | ), |
526 | 359 | _measure_and_log_time=False, |
@@ -557,7 +390,7 @@ def run( |
557 | 390 | batch_size=batch_size, |
558 | 391 | input_len=input_len, |
559 | 392 | output_len=output_len, |
560 | | - top_k=args.top_k, |
561 | | - top_p=args.top_p, |
562 | | - temperature=args.temperature, |
| 393 | + top_k=cfg.top_k, |
| 394 | + top_p=cfg.top_p, |
| 395 | + temperature=cfg.temperature, |
563 | 396 | ) |
0 commit comments