@@ -86,9 +86,9 @@ struct Descriptor::Opaque {
8686
8787 // Workspace slab sizes (bytes), padded to kWsAlign.
8888 size_t normalized_bytes = 0 ;
89- size_t gate_up_bytes = 0 ;
90- size_t hidden_bytes = 0 ;
91- size_t inner_ws_bytes = 0 ; // max of sub-descriptor workspaceSize()
89+ size_t gate_up_bytes = 0 ;
90+ size_t hidden_bytes = 0 ;
91+ size_t inner_ws_bytes = 0 ; // max of sub-descriptor workspaceSize()
9292
9393 bool has_residual = false ;
9494
@@ -108,10 +108,10 @@ struct Descriptor::Opaque {
108108 // Sub-descriptors owned by this fused op; each one is a standard
109109 // InfiniopDescriptor for the corresponding standalone operator.
110110 std::unique_ptr<op::rms_norm::nvidia::Descriptor> rms_norm;
111- std::unique_ptr<op::gemm::nvidia::Descriptor> gate_up_gemm;
112- std::unique_ptr<op::swiglu::nvidia::Descriptor> swiglu;
113- std::unique_ptr<op::gemm::nvidia::Descriptor> down_gemm;
114- std::unique_ptr<op::add::nvidia::Descriptor> residual_add;
111+ std::unique_ptr<op::gemm::nvidia::Descriptor> gate_up_gemm;
112+ std::unique_ptr<op::swiglu::nvidia::Descriptor> swiglu;
113+ std::unique_ptr<op::gemm::nvidia::Descriptor> down_gemm;
114+ std::unique_ptr<op::add::nvidia::Descriptor> residual_add;
115115};
116116
117117Descriptor::~Descriptor () {
@@ -138,12 +138,12 @@ infiniStatus_t Descriptor::create(
138138 auto handle = reinterpret_cast <device::nvidia::Handle *>(handle_);
139139
140140 auto opaque = std::make_unique<Opaque>();
141- opaque->internal = handle->internal ();
141+ opaque->internal = handle->internal ();
142142 opaque->has_residual = info.has_residual ;
143143
144144 const size_t ntok = info.ntok ();
145- const size_t d = info.d ();
146- const size_t di = info.di ();
145+ const size_t d = info.d ();
146+ const size_t di = info.di ();
147147 const size_t dtype_sz = infiniSizeOf (info.dtype );
148148
149149 // Profile-driven scheduler for the deep-fused kernel path.
@@ -164,7 +164,9 @@ infiniStatus_t Descriptor::create(
164164 const char *thr = std::getenv (" INFINIOP_FUSED_FFN_DEEP_MAX_NTOK" );
165165 if (thr != nullptr ) {
166166 max_ntok = static_cast <size_t >(std::atol (thr));
167- if (max_ntok == 0 ) max_ntok = 4 ;
167+ if (max_ntok == 0 ) {
168+ max_ntok = 4 ;
169+ }
168170 }
169171 opaque->use_deep_fused = (ntok <= max_ntok);
170172 }
@@ -199,9 +201,9 @@ infiniStatus_t Descriptor::create(
199201 // The compact hidden slab (stride=di instead of stride=2*di) gives the
200202 // Down-GEMM a tightly packed K dimension, which matters on BIV150 where
201203 // cuBLAS 10.2 tensor-core paths prefer aligned contiguous leading dims.
202- opaque->normalized_bytes = alignUp (ntok * d * dtype_sz, kWsAlign );
203- opaque->gate_up_bytes = alignUp (ntok * 2 * di * dtype_sz, kWsAlign );
204- opaque->hidden_bytes = alignUp (ntok * di * dtype_sz, kWsAlign );
204+ opaque->normalized_bytes = alignUp (ntok * d * dtype_sz, kWsAlign );
205+ opaque->gate_up_bytes = alignUp (ntok * 2 * di * dtype_sz, kWsAlign );
206+ opaque->hidden_bytes = alignUp (ntok * di * dtype_sz, kWsAlign );
205207
206208 DescScope scope;
207209
@@ -219,8 +221,7 @@ infiniStatus_t Descriptor::create(
219221 normalized_desc, in_view, norm_weight_desc,
220222 info.epsilon ));
221223 opaque->rms_norm .reset (sub);
222- opaque->inner_ws_bytes =
223- std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
224+ opaque->inner_ws_bytes = std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
224225 }
225226
226227 // ── GateUp GEMM sub-descriptor ──
@@ -235,8 +236,7 @@ infiniStatus_t Descriptor::create(
235236 CHECK_STATUS (op::gemm::nvidia::Descriptor::create (
236237 handle_, &sub, gate_up_c_desc, normalized_desc, gate_up_b_desc));
237238 opaque->gate_up_gemm .reset (sub);
238- opaque->inner_ws_bytes =
239- std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
239+ opaque->inner_ws_bytes = std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
240240 }
241241
242242 // ── SwiGLU sub-descriptor ──
@@ -256,8 +256,7 @@ infiniStatus_t Descriptor::create(
256256 CHECK_STATUS (op::swiglu::nvidia::Descriptor::create (
257257 handle_, &sub, hidden_desc, {half_desc, half_desc}));
258258 opaque->swiglu .reset (sub);
259- opaque->inner_ws_bytes =
260- std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
259+ opaque->inner_ws_bytes = std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
261260 }
262261
263262 // ── Down GEMM sub-descriptor ──
@@ -274,8 +273,7 @@ infiniStatus_t Descriptor::create(
274273 CHECK_STATUS (op::gemm::nvidia::Descriptor::create (
275274 handle_, &sub, out_view, hidden_desc, down_b_desc));
276275 opaque->down_gemm .reset (sub);
277- opaque->inner_ws_bytes =
278- std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
276+ opaque->inner_ws_bytes = std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
279277 }
280278
281279 // ── Residual add sub-descriptor (optional) ──
@@ -293,15 +291,10 @@ infiniStatus_t Descriptor::create(
293291 handle_, &sub,
294292 out_view_for_add, {out_view_for_add, residual_view}));
295293 opaque->residual_add .reset (sub);
296- opaque->inner_ws_bytes =
297- std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
294+ opaque->inner_ws_bytes = std::max (opaque->inner_ws_bytes , sub->workspaceSize ());
298295 }
299296
300- const size_t workspace_size =
301- opaque->normalized_bytes +
302- opaque->gate_up_bytes +
303- opaque->hidden_bytes +
304- alignUp (opaque->inner_ws_bytes , kWsAlign );
297+ const size_t workspace_size = opaque->normalized_bytes + opaque->gate_up_bytes + opaque->hidden_bytes + alignUp (opaque->inner_ws_bytes , kWsAlign );
305298
306299 *desc_ptr = new Descriptor (
307300 opaque.release (),
@@ -325,16 +318,19 @@ infiniStatus_t Descriptor::calculate(
325318 return INFINI_STATUS_INSUFFICIENT_WORKSPACE ;
326319 }
327320
328- const size_t di = _info.di ();
321+ const size_t di = _info.di ();
329322 const size_t dtype_sz = infiniSizeOf (_info.dtype );
330323
331324 // Partition the workspace into the three persistent slabs plus an
332325 // inner scratch buffer shared by all sub-descriptors.
333326 char *ws = static_cast <char *>(workspace);
334- void *normalized_buf = ws; ws += _opaque->normalized_bytes ;
335- void *gate_up_buf = ws; ws += _opaque->gate_up_bytes ;
336- void *hidden_buf = ws; ws += _opaque->hidden_bytes ;
337- void *inner_ws = ws;
327+ void *normalized_buf = ws;
328+ ws += _opaque->normalized_bytes ;
329+ void *gate_up_buf = ws;
330+ ws += _opaque->gate_up_bytes ;
331+ void *hidden_buf = ws;
332+ ws += _opaque->hidden_bytes ;
333+ void *inner_ws = ws;
338334 const size_t inner_ws_size = _opaque->inner_ws_bytes ;
339335
340336 // gate and up are two halves of the interleaved gate_up buffer.
@@ -344,7 +340,7 @@ infiniStatus_t Descriptor::calculate(
344340 // shared half_desc at create time).
345341 const char *gu_bytes = static_cast <const char *>(gate_up_buf);
346342 const void *gate_ptr = gu_bytes;
347- const void *up_ptr = gu_bytes + di * dtype_sz;
343+ const void *up_ptr = gu_bytes + di * dtype_sz;
348344
349345 // Stage 1: RMSNorm
350346 CHECK_STATUS (_opaque->rms_norm ->calculate (
@@ -365,17 +361,17 @@ infiniStatus_t Descriptor::calculate(
365361 dim3 grid (static_cast <unsigned >(ntok), static_cast <unsigned >(di));
366362 dim3 block (kBlock );
367363
368- #define DEEP_FUSED_LAUNCH (TD, TW ) \
369- deepFusedGateUpSiluKernel<kBlock , float , TD , TW > \
370- <<<grid, block, 0 , cuda_stream>>> ( \
371- reinterpret_cast <TD *>(hidden_buf), \
372- reinterpret_cast <const TD *>(normalized_buf), \
373- reinterpret_cast <const TW *>(gate_up_weight), \
374- ntok, d, di, \
375- static_cast <ptrdiff_t >(d), \
376- static_cast <ptrdiff_t >(di), \
377- _opaque->gate_up_w_k_stride , \
378- _opaque->gate_up_w_col_stride , \
364+ #define DEEP_FUSED_LAUNCH (TD, TW ) \
365+ deepFusedGateUpSiluKernel<kBlock , float , TD , TW > \
366+ <<<grid, block, 0 , cuda_stream>>> ( \
367+ reinterpret_cast <TD *>(hidden_buf), \
368+ reinterpret_cast <const TD *>(normalized_buf), \
369+ reinterpret_cast <const TW *>(gate_up_weight), \
370+ ntok, d, di, \
371+ static_cast <ptrdiff_t >(d), \
372+ static_cast <ptrdiff_t >(di), \
373+ _opaque->gate_up_w_k_stride , \
374+ _opaque->gate_up_w_col_stride , \
379375 /* gate_col_base=*/ 0u , /* up_col_base=*/ di)
380376
381377 if (_info.dtype == INFINI_DTYPE_F16 && _info.mtype == INFINI_DTYPE_F16 ) {
@@ -416,8 +412,7 @@ infiniStatus_t Descriptor::calculate(
416412 // Stage 4: Down GEMM, with optional in-place residual fuse via beta=1.
417413 // fuse path : out = 1.0 * out + hidden_buf @ down_weight
418414 // plain path: out = 0.0 * out + hidden_buf @ down_weight
419- const bool fuse_residual =
420- _opaque->has_residual && (out == residual);
415+ const bool fuse_residual = _opaque->has_residual && (out == residual);
421416 CHECK_STATUS (_opaque->down_gemm ->calculate (
422417 inner_ws, inner_ws_size,
423418 out, /* beta=*/ fuse_residual ? 1 .f : 0 .f ,
0 commit comments