Skip to content

Commit f4faae1

Browse files
committed
issue/1153 - add fused FFN operator with multi-backend support
1 parent 4aec970 commit f4faae1

13 files changed

Lines changed: 2030 additions & 0 deletions

File tree

include/infiniop/ops/fused_ffn.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef __INFINIOP_FUSED_FFN_API_H__
2+
#define __INFINIOP_FUSED_FFN_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
/* Handle type for fused FFN operator descriptors: a pointer to the generic InfiniopDescriptor struct. */
typedef struct InfiniopDescriptor *infiniopFusedFFNDescriptor_t;
7+
8+
/**
 * Creates a fused FFN operator descriptor and returns it through desc_ptr.
 *
 * Judging by the CPU backend added in the same commit, the operator computes
 * RMSNorm(in) -> gate/up projection -> SwiGLU -> down projection, and adds the
 * residual when one is present.
 *
 * @param handle              Library handle; the CPU backend reads its device type and device id.
 * @param desc_ptr            Receives the newly created descriptor.
 * @param out_desc            Descriptor of the output tensor.
 * @param in_desc             Descriptor of the input tensor.
 * @param residual_desc       Descriptor of the residual tensor.
 *                            NOTE(review): presumably optional - confirm how "no residual" is expressed.
 * @param norm_weight_desc    Descriptor of the RMSNorm weight.
 * @param gate_up_weight_desc Descriptor of the fused gate/up weight (the CPU backend indexes it as [2*di, d]).
 * @param down_weight_desc    Descriptor of the down-projection weight (the CPU backend indexes it as [d, di]).
 * @param epsilon             Added to the mean of squares before the reciprocal square root in RMSNorm.
 * @return An infiniStatus_t status code.
 */
__INFINI_C __export infiniStatus_t infiniopCreateFusedFFNDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopFusedFFNDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t out_desc,
12+
infiniopTensorDescriptor_t in_desc,
13+
infiniopTensorDescriptor_t residual_desc,
14+
infiniopTensorDescriptor_t norm_weight_desc,
15+
infiniopTensorDescriptor_t gate_up_weight_desc,
16+
infiniopTensorDescriptor_t down_weight_desc,
17+
float epsilon);
18+
19+
/**
 * Queries the scratch workspace size a fused FFN descriptor needs.
 *
 * @param desc Fused FFN descriptor to query.
 * @param size Output location for the workspace size.
 *             NOTE(review): presumably in bytes (the CPU backend computes its size in bytes) - confirm.
 * @return An infiniStatus_t status code.
 */
__INFINI_C __export infiniStatus_t infiniopGetFusedFFNWorkspaceSize(
20+
infiniopFusedFFNDescriptor_t desc, size_t *size);
21+
22+
/**
 * Runs the fused FFN operator described by desc.
 *
 * Notes taken from the CPU backend added in the same commit:
 *  - it returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when workspace_size is
 *    smaller than the size stored in the descriptor;
 *  - it returns INFINI_STATUS_BAD_TENSOR_DTYPE for dtype combinations it does
 *    not dispatch;
 *  - when out and residual are the same pointer, the residual is folded into
 *    the down projection; otherwise it is added in a separate pass;
 *  - it does not use stream.
 *
 * @param desc           Fused FFN descriptor.
 * @param workspace      Scratch buffer for intermediate results.
 * @param workspace_size Size of the scratch buffer.
 * @param out            Output data.
 * @param in             Input data.
 * @param residual       Residual data added to the result when the descriptor has a residual.
 * @param norm_weight    RMSNorm weight data.
 * @param gate_up_weight Fused gate/up projection weight data.
 * @param down_weight    Down-projection weight data.
 * @param stream         Backend stream/queue handle. NOTE(review): semantics are backend specific - confirm.
 * @return An infiniStatus_t status code.
 */
__INFINI_C __export infiniStatus_t infiniopFusedFFN(
23+
infiniopFusedFFNDescriptor_t desc,
24+
void *workspace,
25+
size_t workspace_size,
26+
void *out,
27+
const void *in,
28+
const void *residual,
29+
const void *norm_weight,
30+
const void *gate_up_weight,
31+
const void *down_weight,
32+
void *stream);
33+
34+
/**
 * Destroys a fused FFN descriptor.
 *
 * @param desc Descriptor to destroy. NOTE(review): presumably one obtained from
 *             infiniopCreateFusedFFNDescriptor and not used afterwards - confirm.
 * @return An infiniStatus_t status code.
 */
__INFINI_C __export infiniStatus_t infiniopDestroyFusedFFNDescriptor(
35+
infiniopFusedFFNDescriptor_t desc);
36+
37+
#endif
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
#include "fused_ffn_cpu.h"
2+
#include "../../../../utils.h"
3+
#include "../../../devices/cpu/common_cpu.h"
4+
#include <cmath>
5+
#include <cstring>
6+
7+
namespace op::fused_ffn::cpu {
8+
9+
// Defaulted destructor: the CPU backend constructs its descriptor with a null opaque pointer,
// so there is no backend-specific state to release here.
Descriptor::~Descriptor() = default;
10+
11+
/**
 * Builds the CPU fused FFN descriptor.
 *
 * Validates the tensor descriptors through FusedFFNInfo::create, computes the
 * scratch space the CPU kernel needs, and stores a newly allocated Descriptor
 * in *desc_ptr. The CPU backend keeps no opaque state (nullptr is passed).
 *
 * Workspace layout (in bytes, element size taken from info.dtype):
 *   [ntok * d]      normalized activations
 *   [ntok * 2 * di] gate/up projection output (SwiGLU is applied in place)
 *
 * @return INFINI_STATUS_SUCCESS on success; otherwise CHECK_RESULT propagates
 *         the failure from FusedFFNInfo::create.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t in_desc,
    infiniopTensorDescriptor_t residual_desc,
    infiniopTensorDescriptor_t norm_weight_desc,
    infiniopTensorDescriptor_t gate_up_weight_desc,
    infiniopTensorDescriptor_t down_weight_desc,
    float epsilon) {

    auto info_result = FusedFFNInfo::create(
        out_desc, in_desc, residual_desc,
        norm_weight_desc, gate_up_weight_desc, down_weight_desc, epsilon);
    CHECK_RESULT(info_result);
    auto ffn_info = info_result.take();

    // Size the two scratch regions the kernel carves out of the workspace.
    const size_t elem_bytes = infiniSizeOf(ffn_info.dtype);
    const size_t tokens = ffn_info.ntok();
    const size_t hidden_dim = ffn_info.d();
    const size_t inter_dim = ffn_info.di();

    const size_t normalized_bytes = tokens * hidden_dim * elem_bytes;
    const size_t gate_up_bytes = tokens * 2 * inter_dim * elem_bytes;
    const size_t scratch_bytes = normalized_bytes + gate_up_bytes;

    *desc_ptr = new Descriptor(
        nullptr,
        std::move(ffn_info),
        scratch_bytes,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
46+
47+
template <typename Tdata, typename TnormWeight, typename TmatWeight>
48+
infiniStatus_t calculateTyped(
49+
const FusedFFNInfo &info,
50+
void *workspace, size_t workspace_size,
51+
void *out,
52+
const void *in,
53+
const void *residual,
54+
const void *norm_weight,
55+
const void *gate_up_weight,
56+
const void *down_weight) {
57+
58+
size_t ntok = info.ntok();
59+
size_t d = info.d();
60+
size_t di = info.di();
61+
62+
// Partition workspace (no separate hidden_buf needed, SwiGLU is in-place)
63+
char *ws_ptr = static_cast<char *>(workspace);
64+
Tdata *normalized_buf = reinterpret_cast<Tdata *>(ws_ptr);
65+
ws_ptr += ntok * d * sizeof(Tdata);
66+
Tdata *gate_up_buf = reinterpret_cast<Tdata *>(ws_ptr);
67+
68+
const Tdata *in_ptr = reinterpret_cast<const Tdata *>(in);
69+
const Tdata *residual_ptr = reinterpret_cast<const Tdata *>(residual);
70+
const TnormWeight *norm_w_ptr = reinterpret_cast<const TnormWeight *>(norm_weight);
71+
const TmatWeight *gate_up_w_ptr = reinterpret_cast<const TmatWeight *>(gate_up_weight);
72+
const TmatWeight *down_w_ptr = reinterpret_cast<const TmatWeight *>(down_weight);
73+
Tdata *out_ptr = reinterpret_cast<Tdata *>(out);
74+
75+
// Stage 1: RMSNorm
76+
for (size_t t = 0; t < ntok; t++) {
77+
const Tdata *x = in_ptr + t * info.in_stride;
78+
Tdata *norm = normalized_buf + t * d;
79+
80+
// Compute variance
81+
float sum_sq = 0.0f;
82+
for (size_t i = 0; i < d; i++) {
83+
float val = utils::cast<float>(x[i]);
84+
sum_sq += val * val;
85+
}
86+
87+
// Normalize
88+
float rms = 1.0f / std::sqrt(sum_sq / d + info.epsilon);
89+
for (size_t i = 0; i < d; i++) {
90+
float val = utils::cast<float>(x[i]) * utils::cast<float>(norm_w_ptr[i]) * rms;
91+
norm[i] = utils::cast<Tdata>(val);
92+
}
93+
}
94+
95+
// Stage 2: GateUp GEMM (C = A @ B^T)
96+
// normalized: [ntok, d], gate_up_weight: [2*di, d] -> gate_up: [ntok, 2*di]
97+
for (size_t t = 0; t < ntok; t++) {
98+
const Tdata *norm = normalized_buf + t * d;
99+
Tdata *gate_up = gate_up_buf + t * 2 * di;
100+
101+
for (size_t j = 0; j < 2 * di; j++) {
102+
float sum = 0.0f;
103+
for (size_t k = 0; k < d; k++) {
104+
sum += utils::cast<float>(norm[k]) * utils::cast<float>(gate_up_w_ptr[j * d + k]);
105+
}
106+
gate_up[j] = utils::cast<Tdata>(sum);
107+
}
108+
}
109+
110+
// Stage 3: SwiGLU (in-place, overwrites gate half of gate_up_buf)
111+
for (size_t t = 0; t < ntok; t++) {
112+
Tdata *gate_up = gate_up_buf + t * 2 * di;
113+
114+
for (size_t i = 0; i < di; i++) {
115+
float gate = utils::cast<float>(gate_up[i]);
116+
float up = utils::cast<float>(gate_up[di + i]);
117+
// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
118+
float silu = gate / (1.0f + std::exp(-gate));
119+
gate_up[i] = utils::cast<Tdata>(silu * up);
120+
}
121+
}
122+
123+
// Stage 4: Down GEMM (C = A @ B^T) + Residual Add (fused)
124+
// Read from gate_up_buf (stride 2*di) to match non-fused path's buffer layout
125+
{
126+
bool fuse_residual = info.has_residual && (out_ptr == residual_ptr);
127+
for (size_t t = 0; t < ntok; t++) {
128+
const Tdata *hidden = gate_up_buf + t * 2 * di; // stride = 2*di to match non-fused
129+
Tdata *o = out_ptr + t * info.out_stride;
130+
131+
if (fuse_residual) {
132+
const Tdata *res = residual_ptr + t * info.residual_stride;
133+
for (size_t j = 0; j < d; j++) {
134+
float sum = utils::cast<float>(res[j]);
135+
for (size_t k = 0; k < di; k++) {
136+
sum += utils::cast<float>(hidden[k]) * utils::cast<float>(down_w_ptr[j * di + k]);
137+
}
138+
o[j] = utils::cast<Tdata>(sum);
139+
}
140+
} else {
141+
for (size_t j = 0; j < d; j++) {
142+
float sum = 0.0f;
143+
for (size_t k = 0; k < di; k++) {
144+
sum += utils::cast<float>(hidden[k]) * utils::cast<float>(down_w_ptr[j * di + k]);
145+
}
146+
o[j] = utils::cast<Tdata>(sum);
147+
}
148+
}
149+
}
150+
}
151+
152+
// Stage 5: Residual Add (only when not fused into GEMM)
153+
if (info.has_residual && out_ptr != residual_ptr) {
154+
for (size_t t = 0; t < ntok; t++) {
155+
Tdata *o = out_ptr + t * info.out_stride;
156+
const Tdata *res = residual_ptr + t * info.residual_stride;
157+
for (size_t i = 0; i < d; i++) {
158+
float val = utils::cast<float>(o[i]) + utils::cast<float>(res[i]);
159+
o[i] = utils::cast<Tdata>(val);
160+
}
161+
}
162+
}
163+
164+
return INFINI_STATUS_SUCCESS;
165+
}
166+
167+
/**
 * Runs the fused FFN on the CPU.
 *
 * Checks the workspace size, then selects the calculateTyped instantiation
 * matching the descriptor's data type (dtype), norm-weight type (wtype) and
 * matrix-weight type (mtype):
 *   - F16 data:  each weight type may independently be F16 or F32;
 *   - BF16 data: each weight type may independently be BF16 or F32;
 *   - F32 data:  always the all-float kernel (weight types are not inspected).
 *
 * @param stream Not used by the CPU backend.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE if workspace_size is too small,
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for any other type combination,
 *         otherwise the status returned by the kernel.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out,
    const void *in,
    const void *residual,
    const void *norm_weight,
    const void *gate_up_weight,
    const void *down_weight,
    void *stream) const {

    if (_workspace_size > workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

// Returns from calculate() with the kernel instantiated for one type triple.
#define FUSED_FFN_CPU_RUN(TDATA, TNORM, TMAT)                \
    return calculateTyped<TDATA, TNORM, TMAT>(               \
        _info, workspace, workspace_size, out, in, residual, \
        norm_weight, gate_up_weight, down_weight)

    const bool norm_is_f32 = (_info.wtype == INFINI_DTYPE_F32);
    const bool mat_is_f32 = (_info.mtype == INFINI_DTYPE_F32);

    switch (_info.dtype) {
    case INFINI_DTYPE_F16: {
        const bool norm_is_f16 = (_info.wtype == INFINI_DTYPE_F16);
        const bool mat_is_f16 = (_info.mtype == INFINI_DTYPE_F16);
        if (norm_is_f16 && mat_is_f16) {
            FUSED_FFN_CPU_RUN(fp16_t, fp16_t, fp16_t);
        }
        if (norm_is_f32 && mat_is_f16) {
            FUSED_FFN_CPU_RUN(fp16_t, float, fp16_t);
        }
        if (norm_is_f16 && mat_is_f32) {
            FUSED_FFN_CPU_RUN(fp16_t, fp16_t, float);
        }
        if (norm_is_f32 && mat_is_f32) {
            FUSED_FFN_CPU_RUN(fp16_t, float, float);
        }
        break;
    }
    case INFINI_DTYPE_BF16: {
        const bool norm_is_bf16 = (_info.wtype == INFINI_DTYPE_BF16);
        const bool mat_is_bf16 = (_info.mtype == INFINI_DTYPE_BF16);
        if (norm_is_bf16 && mat_is_bf16) {
            FUSED_FFN_CPU_RUN(bf16_t, bf16_t, bf16_t);
        }
        if (norm_is_f32 && mat_is_bf16) {
            FUSED_FFN_CPU_RUN(bf16_t, float, bf16_t);
        }
        if (norm_is_bf16 && mat_is_f32) {
            FUSED_FFN_CPU_RUN(bf16_t, bf16_t, float);
        }
        if (norm_is_f32 && mat_is_f32) {
            FUSED_FFN_CPU_RUN(bf16_t, float, float);
        }
        break;
    }
    case INFINI_DTYPE_F32:
        FUSED_FFN_CPU_RUN(float, float, float);
    default:
        break;
    }

#undef FUSED_FFN_CPU_RUN

    // No supported (dtype, wtype, mtype) combination matched.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
208+
209+
} // namespace op::fused_ffn::cpu
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __FUSED_FFN_CPU_H__
2+
#define __FUSED_FFN_CPU_H__
3+
4+
#include "../fused_ffn.h"
5+
6+
// Expands the DESCRIPTOR macro from ../fused_ffn.h to declare op::fused_ffn::cpu::Descriptor.
DESCRIPTOR(cpu)
7+
8+
#endif
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef FUSED_FFN_H
2+
#define FUSED_FFN_H
3+
4+
#include "../../operator.h"
5+
#include "info.h"
#include <utility>
6+
7+
/*
 * DESCRIPTOR(NAMESPACE) declares op::fused_ffn::NAMESPACE::Descriptor, the
 * per-backend descriptor class of the fused FFN operator.
 *
 * Each backend expands this macro once and supplies the definitions of:
 *   - struct Descriptor::Opaque (backend-private state; may stay undefined
 *     when the backend only ever stores a null pointer),
 *   - ~Descriptor(),
 *   - create(): validates the tensor descriptors and allocates a Descriptor,
 *   - calculate(): runs the operator with caller-provided workspace.
 *
 * The constructor is private, so instances are only produced by create().
 * `info` is taken by value and moved into the member, so a caller passing
 * std::move(info) does not pay for an extra copy.
 *
 * workspaceSize() returns the workspace size recorded at creation time.
 */
#define DESCRIPTOR(NAMESPACE)                                    \
                                                                 \
    namespace op::fused_ffn::NAMESPACE {                         \
    class Descriptor final : public InfiniopDescriptor {         \
        struct Opaque;                                           \
        Opaque *_opaque;                                         \
        FusedFFNInfo _info;                                      \
        size_t _workspace_size;                                  \
                                                                 \
        Descriptor(                                              \
            Opaque *opaque,                                      \
            FusedFFNInfo info,                                   \
            size_t workspace_size,                               \
            infiniDevice_t device_type,                          \
            int device_id)                                       \
            : InfiniopDescriptor{device_type, device_id},        \
              _opaque(opaque),                                   \
              _info(std::move(info)),                            \
              _workspace_size(workspace_size) {}                 \
                                                                 \
    public:                                                      \
        ~Descriptor();                                           \
                                                                 \
        size_t workspaceSize() const { return _workspace_size; } \
                                                                 \
        static infiniStatus_t create(                            \
            infiniopHandle_t handle,                             \
            Descriptor **desc_ptr,                               \
            infiniopTensorDescriptor_t out_desc,                 \
            infiniopTensorDescriptor_t in_desc,                  \
            infiniopTensorDescriptor_t residual_desc,            \
            infiniopTensorDescriptor_t norm_weight_desc,         \
            infiniopTensorDescriptor_t gate_up_weight_desc,      \
            infiniopTensorDescriptor_t down_weight_desc,         \
            float epsilon);                                      \
                                                                 \
        infiniStatus_t calculate(                                \
            void *workspace, size_t workspace_size,              \
            void *out,                                           \
            const void *in,                                      \
            const void *residual,                                \
            const void *norm_weight,                             \
            const void *gate_up_weight,                          \
            const void *down_weight,                             \
            void *stream) const;                                 \
    };                                                           \
    }
54+
55+
#endif // FUSED_FFN_H

0 commit comments

Comments
 (0)