Skip to content

Commit 3e8ddf5

Browse files
authored
Merge pull request #293 from MoringLotus/qy_interview
创建统一的参数配置基类
2 parents 0f8270a + e5503f7 commit 3e8ddf5

File tree

8 files changed

+268
-768
lines changed

8 files changed

+268
-768
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ python/infinilm/lib/*.so
99
# Vscode
1010
.vscode/
1111

12+
*.sh
13+
model_weight/
14+
1215
# Python
1316
__pycache__/
1417
*.egg-info/

examples/bench.py

Lines changed: 23 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from infinilm.modeling_utils import load_model_state_dict_by_file
44
from infinilm.distributed import DistConfig
55
from infinilm.infer_engine import GenerationConfig, InferEngine
6+
from infinilm.base_config import BaseConfig
67
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
78
import argparse
89
import sys
@@ -125,150 +126,6 @@ def get_test_cases(
125126

126127
return case_dict
127128

128-
129-
def get_args():
130-
parser = argparse.ArgumentParser(description="run Llama args")
131-
132-
parser.add_argument(
133-
"--cpu",
134-
action="store_true",
135-
help="Run cpu test",
136-
)
137-
parser.add_argument(
138-
"--nvidia",
139-
action="store_true",
140-
help="Run nvidia test",
141-
)
142-
parser.add_argument(
143-
"--qy",
144-
action="store_true",
145-
help="Run qy test",
146-
)
147-
parser.add_argument(
148-
"--metax",
149-
action="store_true",
150-
help="Run metax test",
151-
)
152-
parser.add_argument(
153-
"--moore",
154-
action="store_true",
155-
help="Run moore test",
156-
)
157-
parser.add_argument(
158-
"--iluvatar",
159-
action="store_true",
160-
help="Run iluvatar test",
161-
)
162-
parser.add_argument(
163-
"--cambricon",
164-
action="store_true",
165-
help="Run cambricon test",
166-
)
167-
parser.add_argument(
168-
"--ali",
169-
action="store_true",
170-
help="Run alippu test",
171-
)
172-
parser.add_argument(
173-
"--hygon",
174-
action="store_true",
175-
help="Run hygon test",
176-
)
177-
parser.add_argument(
178-
"--model",
179-
type=str,
180-
required=True,
181-
help="model path",
182-
)
183-
parser.add_argument(
184-
"--batch-size",
185-
type=parse_list,
186-
default=1,
187-
help="number of prompts in a batch (can be an int or a list of ints, e.g., '1' or '[1,2,4]' or '1,2,4')",
188-
)
189-
parser.add_argument(
190-
"--tensor-parallel-size",
191-
"--tp",
192-
type=int,
193-
default=1,
194-
help="total rank for tensor parallel",
195-
)
196-
parser.add_argument(
197-
"--input-len",
198-
type=parse_list,
199-
default=10,
200-
help="output tokens",
201-
)
202-
203-
parser.add_argument(
204-
"--output-len",
205-
type=parse_list,
206-
default=20,
207-
help="output tokens",
208-
)
209-
parser.add_argument(
210-
"--skip-load",
211-
action="store_true",
212-
help="skip loading model weights",
213-
)
214-
parser.add_argument(
215-
"--top-k",
216-
type=int,
217-
default=1,
218-
help="top k sampling",
219-
)
220-
221-
parser.add_argument(
222-
"--top-p",
223-
type=float,
224-
default=1.0,
225-
help="top p sampling",
226-
)
227-
228-
parser.add_argument(
229-
"--temperature",
230-
type=float,
231-
default=1.0,
232-
help="sampling temperature",
233-
)
234-
parser.add_argument(
235-
"--enable-paged-attn",
236-
action="store_true",
237-
help="use paged cache",
238-
)
239-
parser.add_argument(
240-
"--paged-kv-block-size",
241-
type=int,
242-
default=256,
243-
help="num tokens each kv block can hold",
244-
)
245-
parser.add_argument(
246-
"--enable-graph",
247-
action="store_true",
248-
help="enable graph compiling",
249-
)
250-
parser.add_argument(
251-
"--warmup",
252-
action="store_true",
253-
help="Perform a warmup run before benchmarking/inference.",
254-
)
255-
parser.add_argument(
256-
"--attn",
257-
type=str,
258-
default="default",
259-
choices=["default", "paged-attn", "flash-attn"],
260-
help="attention backend to use: 'default' or 'flash-attn'",
261-
)
262-
parser.add_argument(
263-
"--kv-cache-dtype",
264-
type=str,
265-
default=None,
266-
choices=["int8"],
267-
)
268-
269-
return parser.parse_args()
270-
271-
272129
with open("examples/bench_prompt.md", "r") as f:
273130
prompt = f.read()
274131

@@ -305,7 +162,7 @@ def __init__(
305162
cache_config=cache_config,
306163
enable_graph_compiling=enable_graph,
307164
attention_backend=attn_backend,
308-
kv_cache_dtype=args.kv_cache_dtype,
165+
kv_cache_dtype=cfg.kv_cache_dtype,
309166
)
310167

311168
# ---------------------------------------------------------------------------- #
@@ -396,52 +253,28 @@ def run(
396253

397254

398255
if __name__ == "__main__":
399-
args = get_args()
400-
print(args)
401-
402-
# Parse command line arguments
403-
device_str = "cpu"
404-
if args.cpu:
405-
device_str = "cpu"
406-
elif args.nvidia:
407-
device_str = "cuda"
408-
elif args.qy:
409-
device_str = "cuda"
410-
elif args.metax:
411-
device_str = "cuda"
412-
elif args.moore:
413-
device_str = "musa"
414-
elif args.iluvatar:
415-
device_str = "cuda"
416-
elif args.cambricon:
417-
device_str = "mlu"
418-
elif args.ali:
419-
device_str = "cuda"
420-
elif args.hygon:
421-
device_str = "cuda"
422-
else:
423-
print(
424-
"python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50"
425-
)
426-
sys.exit(1)
427-
_PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size
256+
cfg = BaseConfig()
257+
258+
device_str = cfg.get_device_str(cfg.device)
259+
260+
_PAGED_KV_BLOCK_SIZE = cfg.paged_kv_block_size
428261
# -------------------------------------------------------- #
429262
# 解析参数
430263
# -------------------------------------------------------- #
431-
model_path = args.model
264+
model_path = cfg.model
432265

433266
infini_device = infinicore.device(device_str, 0)
434267

435-
tp = args.tensor_parallel_size
268+
tp = cfg.tp
436269

437-
skip_load = args.skip_load
270+
skip_load = cfg.skip_load
438271

439-
batch_size = args.batch_size
440-
input_len = args.input_len
441-
output_len = args.output_len
442-
enable_paged_attn = args.enable_paged_attn
443-
enable_graph = args.enable_graph
444-
attn_backend = args.attn
272+
batch_size = cfg.batch_size
273+
input_len = cfg.input_len
274+
output_len = cfg.output_len
275+
enable_paged_attn = cfg.enable_paged_attn
276+
enable_graph = cfg.enable_graph
277+
attn_backend = cfg.attn
445278

446279
if isinstance(batch_size, int):
447280
batch_size = [batch_size]
@@ -488,7 +321,7 @@ def run(
488321
# ---------------------------------------------------------------------------- #
489322
# Warmup
490323
# ---------------------------------------------------------------------------- #
491-
if args.warmup:
324+
if cfg.warmup:
492325
warmup_steps = 1
493326

494327
# warmup cache capacity
@@ -518,9 +351,9 @@ def run(
518351
input_ids_infini,
519352
GenerationConfig(
520353
max_new_tokens=5, # decode kernel warmup
521-
temperature=args.temperature,
522-
top_k=args.top_k,
523-
top_p=args.top_p,
354+
temperature=cfg.temperature,
355+
top_k=cfg.top_k,
356+
top_p=cfg.top_p,
524357
stop_on_eos=False,
525358
),
526359
_measure_and_log_time=False,
@@ -557,7 +390,7 @@ def run(
557390
batch_size=batch_size,
558391
input_len=input_len,
559392
output_len=output_len,
560-
top_k=args.top_k,
561-
top_p=args.top_p,
562-
temperature=args.temperature,
393+
top_k=cfg.top_k,
394+
top_p=cfg.top_p,
395+
temperature=cfg.temperature,
563396
)

0 commit comments

Comments
 (0)