@@ -151,6 +151,11 @@ struct MoeGemmState {
151151 // Workspace
152152 void * workspace_dev = nullptr ;
153153 size_t workspace_size = 0 ;
154+
155+ // Persistent GEMM object: avoids stack allocation per call, keeps
156+ // params_ alive for CUDA graph replay. init() triggers the one-time
157+ // cudaFuncSetAttribute call; run() reuses the object.
158+ Gemm gemm;
154159};
155160
156161static MoeGemmState s_state;
@@ -267,13 +272,20 @@ extern "C" int cgemm_nvfp4_moe_sm100_init(
267272 arguments.epilogue .thread .alpha = 1 .0f ;
268273 arguments.epilogue .thread .beta = 0 .0f ;
269274
270- Gemm gemm;
271- auto status = gemm.can_implement (arguments);
275+ auto status = st.gemm .can_implement (arguments);
272276 if (status != cutlass::Status::kSuccess ) {
273277 fprintf (stderr, " MoE GEMM can_implement failed: %d\n " , (int )status);
274278 return -1 ;
275279 }
276280
281+ // Initialize the persistent Gemm object: triggers cudaFuncSetAttribute
282+ // (one-time, not graph-safe) and fills internal params_ with dummy pointers.
283+ status = st.gemm .initialize (arguments, st.workspace_dev , stream);
284+ if (status != cutlass::Status::kSuccess ) {
285+ fprintf (stderr, " MoE GEMM initial initialize failed: %d\n " , (int )status);
286+ return -2 ;
287+ }
288+
277289 st.initialized = true ;
278290 return 0 ;
279291
@@ -324,13 +336,17 @@ extern "C" size_t cgemm_nvfp4_moe_sm100_workspace_size(
324336// SFA_dev: activation scale factors (batched swizzled layout)
325337// SFB_dev: weight scale factors (batched swizzled layout)
326338// D_dev: output (num_experts, max_M, N_output) BF16, row-major per expert
339+ // alpha_dev: device pointer to float alpha (= act_scale * weight_scale)
340+ //
341+ // Intended to be graph-safe: host-side param building + kernel launch.
342+ // NOTE(review): run still calls initialize(); confirm it is capture-safe.
327343extern " C" int cgemm_nvfp4_moe_sm100_run (
328344 const void * A_dev, // activations (packed FP4)
329345 const void * B_dev, // weights (packed FP4)
330346 const void * SFA_dev, // activation scale factors
331347 const void * SFB_dev, // weight scale factors
332348 void * D_dev, // output (BF16)
333- float alpha,
349+ const float * alpha_dev, // device pointer to alpha scalar
334350 cudaStream_t stream
335351) {
336352#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
@@ -363,18 +379,23 @@ extern "C" int cgemm_nvfp4_moe_sm100_run(
363379 static_cast <ElementD*>(D_dev), st.stride_D },
364380 st.hw_info
365381 };
366- arguments.epilogue .thread .alpha = alpha;
382+ // Device-side alpha: if alpha_dev is non-null, kernel reads from device ptr.
383+ // alpha_ptr takes precedence over the scalar alpha value.
384+ arguments.epilogue .thread .alpha = 1 .0f ; // fallback (ignored when alpha_ptr set)
385+ arguments.epilogue .thread .alpha_ptr = alpha_dev;
367386 arguments.epilogue .thread .beta = 0 .0f ;
368387
369- Gemm gemm;
370-
371- auto status = gemm.initialize (arguments, st.workspace_dev , stream);
388+ // Rebuild params_ from the current arguments on the persistent gemm object.
389+ // NOTE(review): this is the same initialize() call that _init relies on to
390+ // trigger cudaFuncSetAttribute, so it is not necessarily free of CUDA API
391+ // calls; confirm it is safe under graph capture or use a params-only update.
392+ auto status = st.gemm .initialize (arguments, st.workspace_dev , stream);
372393 if (status != cutlass::Status::kSuccess ) {
373394 fprintf (stderr, " MoE GEMM initialize failed: %d\n " , (int )status);
374395 return -2 ;
375396 }
376397
377- status = gemm.run (stream);
398+ status = st. gemm .run (stream);
378399 if (status != cutlass::Status::kSuccess ) {
379400 fprintf (stderr, " MoE GEMM run failed: %d\n " , (int )status);
380401 return -3 ;
0 commit comments