Skip to content

Commit ae7ee73

Browse files
committed
fix inference server's base config
1 parent 5b543d5 commit ae7ee73

File tree

2 files changed

+24
-117
lines changed

2 files changed

+24
-117
lines changed

python/infinilm/base_config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def __init__(self):
2323
self.enable_graph = self.args.enable_graph
2424
self.cache_type = self.args.cache_type
2525
self.enable_paged_attn = self.args.enable_paged_attn
26+
27+
# When enable_paged_attn is True, automatically set attn to "paged-attn"
28+
if self.enable_paged_attn and self.attn == "default":
29+
self.attn = "paged-attn"
2630
self.paged_kv_block_size = self.args.paged_kv_block_size
2731
self.num_blocks = self.args.num_blocks
2832
self.block_size = self.args.block_size
@@ -70,6 +74,7 @@ def __init__(self):
7074
self.port = self.args.port
7175
self.endpoint = self.args.endpoint
7276

77+
self.ignore_eos = self.args.ignore_eos
7378
def _add_common_args(self):
7479
# --- base configuration ---
7580
self.parser.add_argument("--model", type=str, required=True)
@@ -79,8 +84,8 @@ def _add_common_args(self):
7984

8085
# --- Infer backend optimization ---
8186
self.parser.add_argument("--attn", type=str, default="default", choices=["default", "paged-attn", "flash-attn"])
82-
self.parser.add_argument("--enable-graph", action="store_true")
83-
self.parser.add_argument("--cache-type", type=str, default="paged", choices=["paged", "static"])
87+
self.parser.add_argument("--enable-graph", action="store_false")
88+
self.parser.add_argument("--cache-type", type=str, default="paged", choices=["paged", "static"])
8489
self.parser.add_argument("--enable-paged-attn", action="store_true", help="use paged cache",)
8590
self.parser.add_argument("--paged-kv-block-size", type=int, default=256)
8691
self.parser.add_argument("--num-blocks", type=int, default=512, help="number of KV cache blocks")
@@ -131,6 +136,7 @@ def _add_common_args(self):
131136
self.parser.add_argument("--port", type=int, default=8000, help="server port")
132137
self.parser.add_argument("--endpoint", type=str, default="/completions", help="API endpoint")
133138

139+
self.parser.add_argument("--ignore-eos", action="store_true", dest="ignore_eos", default=False, help="Ignore EOS token and continue generation",)
134140

135141
def get_device_str(self, device):
136142
"""Convert device name to backend string (cuda/cpu/musa/mlu)"""

python/infinilm/server/inference_server.py

Lines changed: 16 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -550,130 +550,31 @@ def setup_logging(log_level: str = "INFO"):
550550
)
551551

552552

553-
def parse_args():
554-
"""Parse command line arguments."""
555-
parser = argparse.ArgumentParser(description="InfiniLM Inference Server")
556-
parser.add_argument(
557-
"--model_path", type=str, required=True, help="Path to model directory"
558-
)
559-
parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree")
560-
parser.add_argument(
561-
"--cache_type",
562-
type=str,
563-
default="paged",
564-
choices=["paged", "static"],
565-
help="Cache type: paged or static",
566-
)
567-
parser.add_argument(
568-
"--max_tokens",
569-
type=int,
570-
default=512,
571-
help="Maximum number of tokens to generate",
572-
)
573-
parser.add_argument(
574-
"--max_batch_size",
575-
type=int,
576-
default=8,
577-
help="Maximum batch size (paged cache only)",
578-
)
579-
parser.add_argument(
580-
"--num_blocks",
581-
type=int,
582-
default=512,
583-
help="Number of blocks for KV cache (paged cache only)",
584-
)
585-
parser.add_argument(
586-
"--block_size",
587-
type=int,
588-
default=256,
589-
help="Block size for KV cache (paged cache only)",
590-
)
591-
parser.add_argument(
592-
"--max_cache_len",
593-
type=int,
594-
default=4096,
595-
help="Maximum sequence length (static cache only)",
596-
)
597-
parser.add_argument(
598-
"--dtype",
599-
type=str,
600-
default="float16",
601-
choices=["float32", "float16", "bfloat16"],
602-
help="Data type",
603-
)
604-
parser.add_argument(
605-
"--temperature", type=float, default=1.0, help="Sampling temperature"
606-
)
607-
parser.add_argument(
608-
"--top_p", type=float, default=0.8, help="Top-p sampling parameter"
609-
)
610-
parser.add_argument("--top_k", type=int, default=1, help="Top-k sampling parameter")
611-
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
612-
parser.add_argument("--port", type=int, default=8000, help="Server port")
613-
parser.add_argument("--cpu", action="store_true", help="Use CPU")
614-
parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU")
615-
parser.add_argument("--qy", action="store_true", help="Use QY GPU")
616-
parser.add_argument("--metax", action="store_true", help="Use MetaX device")
617-
parser.add_argument("--moore", action="store_true", help="Use Moore device")
618-
parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
619-
parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
620-
parser.add_argument("--ali", action="store_true", help="Use Ali PPU device")
621-
parser.add_argument("--hygon", action="store_true", help="Use Hygon DCU device")
622-
parser.add_argument(
623-
"--enable-graph",
624-
action="store_true",
625-
help="Enable graph compiling",
626-
)
627-
parser.add_argument(
628-
"--attn",
629-
type=str,
630-
default="default",
631-
choices=["default", "paged-attn", "flash-attn"],
632-
help="Attention backend to use: 'default' or 'flash-attn'",
633-
)
634-
parser.add_argument(
635-
"--log_level",
636-
type=str,
637-
default="INFO",
638-
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
639-
help="Logging level",
640-
)
641-
parser.add_argument(
642-
"--ignore-eos",
643-
action="store_true",
644-
dest="ignore_eos",
645-
default=False,
646-
help="Ignore EOS token and continue generation",
647-
)
648-
649-
return parser.parse_args()
650-
651553

652554
def main():
653-
# args = parse_args()
654555
cfg = BaseConfig()
655556
setup_logging(cfg.log_level)
656557
device = cfg.get_device_str(cfg.device)
657558

658559
server = InferenceServer(
659560
model_path=cfg.model,
660561
device=device,
661-
dtype=args.dtype,
662-
tensor_parallel_size=args.tp,
663-
cache_type=args.cache_type,
664-
max_tokens=args.max_tokens,
665-
max_batch_size=args.max_batch_size,
666-
num_blocks=args.num_blocks,
667-
block_size=args.block_size,
668-
max_cache_len=args.max_cache_len,
669-
temperature=args.temperature,
670-
top_p=args.top_p,
671-
top_k=args.top_k,
672-
host=args.host,
673-
port=args.port,
674-
enable_graph=args.enable_graph,
675-
attn_backend=args.attn,
676-
ignore_eos=args.ignore_eos,
562+
dtype=cfg.dtype,
563+
tensor_parallel_size=cfg.tp,
564+
cache_type=cfg.cache_type,
565+
max_tokens=cfg.max_tokens,
566+
max_batch_size=cfg.max_batch_size,
567+
num_blocks=cfg.num_blocks,
568+
block_size=cfg.block_size,
569+
max_cache_len=cfg.max_cache_len,
570+
temperature=cfg.temperature,
571+
top_p=cfg.top_p,
572+
top_k=cfg.top_k,
573+
host=cfg.host,
574+
port=cfg.port,
575+
enable_graph=cfg.enable_graph,
576+
attn_backend=cfg.attn,
577+
ignore_eos=cfg.ignore_eos,
677578
)
678579
server.start()
679580

0 commit comments

Comments (0)