@@ -3084,6 +3084,73 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
30843084 task = GSM8K (self .MODEL_NAME )
30853085 task .evaluate (llm )
30863086
3087+ @pytest .mark .skip_less_mpi_world_size (8 )
3088+ @skip_pre_blackwell
3089+ @pytest .mark .parametrize (
3090+ "tp_size,pp_size,ep_size,mtp_nextn,attention_dp,max_batch_size,moe_backend,fp8kv,chunked_prefill" ,
3091+ [
3092+ (8 , 1 , 8 , 0 , True , 24 , "CUTLASS" , False , False ),
3093+ (8 , 1 , 8 , 3 , False , 16 , "TRTLLM" , True , True ),
3094+ ],
3095+ ids = ["baseline" , "mtp3_fp8kv_chunked" ])
3096+ def test_nvfp4_multi_gpus_piecewise_cuda_graph (self , tp_size , pp_size ,
3097+ ep_size , mtp_nextn ,
3098+ attention_dp , max_batch_size ,
3099+ moe_backend , fp8kv ,
3100+ chunked_prefill ):
3101+ sm_version = get_sm_version ()
3102+ if moe_backend == "TRTLLM" and sm_version in (120 , 121 ):
3103+ pytest .skip (f"{ moe_backend } backend does not support SM 120 or 121" )
3104+
3105+ moe_config = MoeConfig (backend = moe_backend , max_num_tokens = 16384 )
3106+ kv_cache_config = KvCacheConfig (free_gpu_memory_fraction = 0.7 )
3107+ if fp8kv :
3108+ kv_cache_config .dtype = "fp8"
3109+ kv_cache_config .enable_block_reuse = True
3110+
3111+ cuda_graph_config = CudaGraphConfig (
3112+ enable_padding = True ,
3113+ max_batch_size = max_batch_size ,
3114+ )
3115+ torch_compile_config = TorchCompileConfig (
3116+ enable_piecewise_cuda_graph = True ,
3117+ capture_num_tokens = [2048 , 8192 ],
3118+ max_num_streams = 3 ,
3119+ )
3120+ pytorch_config = dict (
3121+ disable_overlap_scheduler = False ,
3122+ cuda_graph_config = cuda_graph_config ,
3123+ moe_config = moe_config ,
3124+ torch_compile_config = torch_compile_config ,
3125+ )
3126+
3127+ mtp_config = None
3128+ if mtp_nextn > 0 :
3129+ mtp_config = MTPDecodingConfig (num_nextn_predict_layers = mtp_nextn )
3130+
3131+ llm_kwargs = dict (
3132+ max_batch_size = max_batch_size ,
3133+ tensor_parallel_size = tp_size ,
3134+ pipeline_parallel_size = pp_size ,
3135+ moe_expert_parallel_size = ep_size ,
3136+ kv_cache_config = kv_cache_config ,
3137+ enable_attention_dp = attention_dp ,
3138+ speculative_config = mtp_config ,
3139+ )
3140+ if chunked_prefill :
3141+ llm_kwargs .update (
3142+ enable_chunked_prefill = True ,
3143+ max_num_tokens = 8192 ,
3144+ )
3145+
3146+ with LLM (f"{ llm_models_root ()} /DeepSeek-V3.2-Exp-FP4-v2" ,
3147+ ** pytorch_config , ** llm_kwargs ) as llm :
3148+
3149+ task = MMLU (self .MODEL_NAME )
3150+ task .evaluate (llm )
3151+ task = GSM8K (self .MODEL_NAME )
3152+ task .evaluate (llm )
3153+
30873154 @pytest .mark .skip_less_mpi_world_size (8 )
30883155 @skip_pre_blackwell
30893156 @pytest .mark .parametrize (
0 commit comments