@@ -550,130 +550,31 @@ def setup_logging(log_level: str = "INFO"):
550550 )
551551
552552
def parse_args(argv=None):
    """Parse command line arguments for the InfiniLM inference server.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case argparse falls back to ``sys.argv[1:]``.
            Accepting an explicit list keeps the function testable without
            touching process-global state.

    Returns:
        argparse.Namespace with all server, cache, sampling, device and
        logging options.
    """
    parser = argparse.ArgumentParser(description="InfiniLM Inference Server")
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path to model directory"
    )
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree")
    parser.add_argument(
        "--cache_type",
        type=str,
        default="paged",
        choices=["paged", "static"],
        help="Cache type: paged or static",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=8,
        help="Maximum batch size (paged cache only)",
    )
    parser.add_argument(
        "--num_blocks",
        type=int,
        default=512,
        help="Number of blocks for KV cache (paged cache only)",
    )
    parser.add_argument(
        "--block_size",
        type=int,
        default=256,
        help="Block size for KV cache (paged cache only)",
    )
    parser.add_argument(
        "--max_cache_len",
        type=int,
        default=4096,
        help="Maximum sequence length (static cache only)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16"],
        help="Data type",
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_p", type=float, default=0.8, help="Top-p sampling parameter"
    )
    parser.add_argument("--top_k", type=int, default=1, help="Top-k sampling parameter")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    # Device-selection flags: one boolean switch per supported backend.
    parser.add_argument("--cpu", action="store_true", help="Use CPU")
    parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU")
    parser.add_argument("--qy", action="store_true", help="Use QY GPU")
    parser.add_argument("--metax", action="store_true", help="Use MetaX device")
    parser.add_argument("--moore", action="store_true", help="Use Moore device")
    parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
    parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
    parser.add_argument("--ali", action="store_true", help="Use Ali PPU device")
    parser.add_argument("--hygon", action="store_true", help="Use Hygon DCU device")
    parser.add_argument(
        "--enable-graph",
        action="store_true",
        help="Enable graph compiling",
    )
    parser.add_argument(
        "--attn",
        type=str,
        default="default",
        choices=["default", "paged-attn", "flash-attn"],
        # Help text previously omitted 'paged-attn' despite it being a valid choice.
        help="Attention backend to use: 'default', 'paged-attn' or 'flash-attn'",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level",
    )
    parser.add_argument(
        "--ignore-eos",
        action="store_true",
        dest="ignore_eos",
        default=False,
        help="Ignore EOS token and continue generation",
    )

    return parser.parse_args(argv)
651553
def main():
    """Entry point: build an InferenceServer from BaseConfig and run it.

    All settings (model path, device, cache layout, sampling parameters,
    network binding) are read from a single BaseConfig instance and
    forwarded to the server constructor.
    """
    cfg = BaseConfig()
    setup_logging(cfg.log_level)
    device_str = cfg.get_device_str(cfg.device)

    # Map each config field onto the corresponding server keyword argument.
    server = InferenceServer(
        model_path=cfg.model,
        device=device_str,
        dtype=cfg.dtype,
        tensor_parallel_size=cfg.tp,
        cache_type=cfg.cache_type,
        max_tokens=cfg.max_tokens,
        max_batch_size=cfg.max_batch_size,
        num_blocks=cfg.num_blocks,
        block_size=cfg.block_size,
        max_cache_len=cfg.max_cache_len,
        temperature=cfg.temperature,
        top_p=cfg.top_p,
        top_k=cfg.top_k,
        host=cfg.host,
        port=cfg.port,
        enable_graph=cfg.enable_graph,
        attn_backend=cfg.attn,
        ignore_eos=cfg.ignore_eos,
    )
    server.start()
0 commit comments