@@ -780,6 +780,63 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
780780 self .MODEL_PATH ) as llm :
781781 run_accuracy_test (llm , self .MODEL_NAME , ["GSM8K" ])
782782
783+ @pytest .mark .skip_less_device (2 )
784+ @skip_pre_hopper
785+ def test_gen_only_spec_dec (self ):
786+ speculative_decoding_config = {
787+ "decoding_type" : "Eagle" ,
788+ "max_draft_len" : 4 ,
789+ "speculative_model" :
790+ f"{ llm_models_root ()} /EAGLE3-LLaMA3.1-Instruct-8B" ,
791+ "eagle3_one_model" : True ,
792+ }
793+ ctx_server_config = {
794+ "disable_overlap_scheduler" :
795+ True , # BS=1 does not need overlap scheduling
796+ "kv_cache_config" : {
797+ "free_gpu_memory_fraction" : 0.5 ,
798+ "enable_block_reuse" : True # reuse on context requests
799+ },
800+ "max_num_tokens" : 13393 * 2 ,
801+ "max_batch_size" : 1 ,
802+ "cache_transceiver_config" : {
803+ "backend" : "NIXL" ,
804+ "transceiver_runtime" : "PYTHON" ,
805+ "max_tokens_in_buffer" : 4096 ,
806+ },
807+ "cuda_graph_config" : None ,
808+ }
809+ gen_server_config = {
810+ "disable_overlap_scheduler" : False ,
811+ "speculative_config" : speculative_decoding_config ,
812+ "kv_cache_config" : {
813+ "free_gpu_memory_fraction" : 0.5 ,
814+ "enable_block_reuse" : False
815+ },
816+ "max_num_tokens" : 13393 * 2 ,
817+ "max_batch_size" : 16 ,
818+ "cache_transceiver_config" : {
819+ "backend" : "NIXL" ,
820+ "transceiver_runtime" : "PYTHON" ,
821+ "max_tokens_in_buffer" : 4096 ,
822+ },
823+ "cuda_graph_config" : None ,
824+ }
825+ disaggregated_server_config = {
826+ "hostname" : "localhost" ,
827+ "backend" : "pytorch" ,
828+ "context_servers" : {
829+ "num_instances" : 1
830+ },
831+ "generation_servers" : {
832+ "num_instances" : 1
833+ }
834+ }
835+ with launch_disaggregated_llm (disaggregated_server_config ,
836+ ctx_server_config , gen_server_config ,
837+ self .MODEL_PATH ) as llm :
838+ run_accuracy_test (llm , self .MODEL_NAME , ["GSM8K" ])
839+
783840 @pytest .mark .skip_less_device (2 )
784841 @pytest .mark .skip_less_device_memory (32000 )
785842 @pytest .mark .parametrize ("backend" , ["xgrammar" , "llguidance" ])
@@ -1001,6 +1058,39 @@ def test_gen_only_sync(self):
10011058 ) as llm :
10021059 run_accuracy_test (llm , self .MODEL_NAME , ["GSM8K" ])
10031060
1061+ @pytest .mark .skip_less_device (8 )
1062+ @skip_pre_hopper
1063+ def test_gen_only_spec_dec (self ):
1064+ ctx_server_config = {"disable_overlap_scheduler" : True }
1065+ gen_server_config = {"disable_overlap_scheduler" : False }
1066+ cache_transceiver_config = {
1067+ "backend" : "NIXL" ,
1068+ "max_tokens_in_buffer" : 4096 ,
1069+ "transceiver_runtime" : "PYTHON" ,
1070+ }
1071+ ctx_server_config ["cache_transceiver_config" ] = cache_transceiver_config
1072+ gen_server_config ["cache_transceiver_config" ] = cache_transceiver_config
1073+ gen_server_config ["speculative_config" ] = {
1074+ "decoding_type" : "MTP" ,
1075+ "max_draft_len" : 2
1076+ }
1077+ disaggregated_server_config = {
1078+ "hostname" : "localhost" ,
1079+ "backend" : "pytorch" ,
1080+ "context_servers" : {
1081+ "num_instances" : 1
1082+ },
1083+ "generation_servers" : {
1084+ "num_instances" : 1
1085+ }
1086+ }
1087+ with launch_disaggregated_llm (disaggregated_server_config ,
1088+ ctx_server_config ,
1089+ gen_server_config ,
1090+ self .MODEL_PATH ,
1091+ tensor_parallel_size = 4 ) as llm :
1092+ run_accuracy_test (llm , self .MODEL_NAME , ["MMLU" , "GSM8K" ])
1093+
10041094 @pytest .mark .skip_less_device (8 )
10051095 @parametrize_with_ids ("overlap_scheduler" , [True , False ])
10061096 @parametrize_with_ids ("mtp_nextn" , [0 , 2 ])
0 commit comments