@@ -780,6 +780,63 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
780780 self .MODEL_PATH ) as llm :
781781 run_accuracy_test (llm , self .MODEL_NAME , ["GSM8K" ])
782782
@pytest.mark.skip_less_device(2)
@skip_pre_hopper
def test_gen_only_spec_dec(self):
    """Run the GSM8K accuracy check with speculation on the generation side only.

    The context server config carries no ``speculative_config``; the
    generation server config uses Eagle decoding with ``eagle3_one_model``
    enabled and a draft length of 4. Both servers set the cache
    transceiver backend to "NIXL" and ``transceiver_runtime`` to "PYTHON".
    """
    token_budget = 13393 * 2

    def make_transceiver_config():
        # A fresh dict per call, so each server config owns its own instance.
        return {
            "backend": "NIXL",
            "transceiver_runtime": "PYTHON",
            "max_tokens_in_buffer": 4096,
        }

    eagle_config = {
        "decoding_type": "Eagle",
        "max_draft_len": 4,
        "speculative_model":
        f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
        "eagle3_one_model": True,
    }

    context_cfg = {
        # BS=1 does not need overlap scheduling
        "disable_overlap_scheduler": True,
        "kv_cache_config": {
            "free_gpu_memory_fraction": 0.5,
            # reuse on context requests
            "enable_block_reuse": True,
        },
        "max_num_tokens": token_budget,
        "max_batch_size": 1,
        "cache_transceiver_config": make_transceiver_config(),
        "cuda_graph_config": None,
    }

    generation_cfg = {
        "disable_overlap_scheduler": False,
        # Only the generation server receives the speculative config.
        "speculative_config": eagle_config,
        "kv_cache_config": {
            "free_gpu_memory_fraction": 0.5,
            "enable_block_reuse": False,
        },
        "max_num_tokens": token_budget,
        "max_batch_size": 16,
        "cache_transceiver_config": make_transceiver_config(),
        "cuda_graph_config": None,
    }

    # One context instance and one generation instance on localhost.
    topology = {
        "hostname": "localhost",
        "backend": "pytorch",
        "context_servers": {"num_instances": 1},
        "generation_servers": {"num_instances": 1},
    }

    with launch_disaggregated_llm(topology, context_cfg, generation_cfg,
                                  self.MODEL_PATH) as llm:
        run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
839+
783840 @pytest .mark .skip_less_device (2 )
784841 @pytest .mark .skip_less_device_memory (32000 )
785842 @pytest .mark .parametrize ("backend" , ["xgrammar" , "llguidance" ])
@@ -1041,6 +1098,39 @@ def test_gen_only_sync(self):
10411098 ) as llm :
10421099 run_accuracy_test (llm , self .MODEL_NAME , ["GSM8K" ])
10431100
@pytest.mark.skip_less_device(8)
@skip_pre_hopper
def test_gen_only_spec_dec(self):
    """Run MMLU and GSM8K accuracy checks with MTP on the generation side only.

    The context server config carries no ``speculative_config``; the
    generation server config uses "MTP" decoding with a draft length of 2.
    The servers are launched with ``tensor_parallel_size=4``.
    """
    # Both server configs reference this one dict object.
    shared_transceiver = {
        "backend": "NIXL",
        "max_tokens_in_buffer": 4096,
        "transceiver_runtime": "PYTHON",
    }

    context_cfg = {
        "disable_overlap_scheduler": True,
        "cache_transceiver_config": shared_transceiver,
    }

    generation_cfg = {
        "disable_overlap_scheduler": False,
        "cache_transceiver_config": shared_transceiver,
        # Only the generation server receives the speculative config.
        "speculative_config": {
            "decoding_type": "MTP",
            "max_draft_len": 2,
        },
    }

    # One context instance and one generation instance on localhost.
    topology = {
        "hostname": "localhost",
        "backend": "pytorch",
        "context_servers": {"num_instances": 1},
        "generation_servers": {"num_instances": 1},
    }

    with launch_disaggregated_llm(topology,
                                  context_cfg,
                                  generation_cfg,
                                  self.MODEL_PATH,
                                  tensor_parallel_size=4) as llm:
        run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
1133+
10441134 @pytest .mark .skip_less_device (8 )
10451135 @parametrize_with_ids ("overlap_scheduler" , [True , False ])
10461136 @parametrize_with_ids ("mtp_nextn" , [0 , 2 ])
0 commit comments