Skip to content

Commit 5a2f9c4

Browse files
zhuyuegongchensu
authored and committed
Add SwiGLU operator Python interface and tests.
1 parent 6f16798 commit 5a2f9c4

10 files changed

Lines changed: 260 additions & 1 deletion

File tree

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
#include "ops/ones.hpp"
77
#include "ops/rearrange.hpp"
88
#include "ops/rms_norm.hpp"
9+
#include "ops/swiglu.hpp"

include/infinicore/ops/swiglu.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class SwiGLU {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor);
10+
static void execute(Tensor c, Tensor a, Tensor b);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor swiglu(Tensor a, Tensor b);
15+
void swiglu_(Tensor c, Tensor a, Tensor b);
16+
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from infinicore.ops.matmul import matmul
3131
from infinicore.ops.rearrange import rearrange
3232
from infinicore.ops.rms_norm import rms_norm
33+
from infinicore.ops.swiglu import swiglu
3334
from infinicore.tensor import (
3435
empty,
3536
from_blob,
@@ -74,6 +75,7 @@
7475
"matmul",
7576
"rearrange",
7677
"rms_norm",
78+
"swiglu",
7779
"empty",
7880
"from_blob",
7981
"ones",

python/infinicore/ops/swiglu.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def swiglu(input, other, *, out=None):
    """Apply the SwiGLU operator to ``input`` and ``other``.

    Args:
        input: First operand; its ``_underlying`` handle is handed to the
            native ``_infinicore`` binding.
        other: Second operand; handled the same way as ``input``.
        out: Optional destination tensor (keyword-only). When given, the
            result is written into it through ``_infinicore.swiglu_``.

    Returns:
        A new ``Tensor`` wrapping the native result when ``out`` is ``None``;
        otherwise ``out`` itself, so the call can be used as an expression in
        both modes.
    """
    if out is None:
        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))

    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
    # Return the destination instead of falling through with None.
    return out
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "infinicore/ops/swiglu.hpp"
2+
#include <stdexcept>
3+
4+
namespace infinicore::op {
5+
6+
/// Returns the dispatcher that stores the SwiGLU implementations.
///
/// The dispatcher is a function-local static, so it is constructed on the
/// first call. Callers that register implementations from their own static
/// initializers therefore always receive a fully constructed object.
common::OpDispatcher<SwiGLU::schema> &SwiGLU::dispatcher() {
    static common::OpDispatcher<SwiGLU::schema> dispatcher_;
    return dispatcher_;
}
10+
11+
/// Runs the SwiGLU implementation registered for the current context device.
///
/// @param c Output tensor; passed as the first argument of the implementation.
/// @param a First input tensor.
/// @param b Second input tensor.
/// @throws std::runtime_error if the dispatcher returns no implementation
///         for the current device type.
void SwiGLU::execute(Tensor c, Tensor a, Tensor b) {
    const auto device_type = context::getDevice().getType();
    const auto kernel = dispatcher().lookup(device_type);

    if (kernel != nullptr) {
        kernel(c, a, b);
        return;
    }

    // Nothing registered for this device type: report it by numeric id.
    throw std::runtime_error("No SwiGLU implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
21+
22+
/// Out-of-place SwiGLU.
///
/// Creates the result tensor with the shape, dtype and device of `a`, fills
/// it through swiglu_(), and returns it.
Tensor swiglu(Tensor a, Tensor b) {
    auto result = Tensor::empty(a->shape(), a->dtype(), a->device());
    swiglu_(result, a, b);
    return result;
}
28+
29+
/// SwiGLU into a caller-provided destination.
///
/// Thin forwarder: hands (c, a, b) unchanged to SwiGLU::execute, which picks
/// the implementation for the current device.
void swiglu_(Tensor c, Tensor a, Tensor b) {
    SwiGLU::execute(c, a, b);
}
32+
} // namespace infinicore::op
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/common/cache.hpp"
4+
#include "infinicore/ops/swiglu.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::swiglu_impl::infiniop {
8+
9+
// Thread-local cache of SwiGLU descriptors keyed by a size_t hash.
// The second constructor argument is the cleanup callback handed to OpCache;
// it destroys a non-null descriptor and resets the slot to nullptr.
// NOTE(review): presumably invoked on eviction/teardown - confirm against OpCache.
thread_local common::OpCache<size_t, infiniopSwiGLUDescriptor_t> caches(
    /*capacity=*/100,
    [](infiniopSwiGLUDescriptor_t &desc) {
        if (desc == nullptr) {
            return; // nothing to release
        }
        INFINICORE_CHECK_ERROR(infiniopDestroySwiGLUDescriptor(desc));
        desc = nullptr;
    });
17+
18+
/// infiniop-backed SwiGLU implementation with the SwiGLU::schema signature.
///
/// Reuses a cached infiniop descriptor when one exists for the hash of the
/// three tensors, otherwise creates and caches one. It then queries the
/// workspace size, allocates the workspace, and launches infiniopSwiGLU on
/// the context stream.
///
/// @param c Output tensor.
/// @param a First input tensor.
/// @param b Second input tensor.
void calculate(Tensor c, Tensor a, Tensor b) {
    // Cache key combined from the three tensors (argument order kept as c, b, a).
    size_t key = hash_combine(c, b, a);

    // Select the cache belonging to the current device type and index.
    auto &cache = caches.getCache(context::getDevice().getType(),
                                  context::getDevice().getIndex());

    infiniopSwiGLUDescriptor_t desc = nullptr;
    if (auto cached = cache.get(key)) {
        desc = *cached;
    } else {
        // Cache miss: build a descriptor from the tensor descriptors and store it.
        INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
            context::getInfiniopHandle(), &desc,
            c->desc(), a->desc(), b->desc()));
        cache.put(key, desc);
    }

    size_t workspace_bytes = 0;
    INFINICORE_CHECK_ERROR(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_bytes));
    std::shared_ptr<Memory> scratch = context::allocateMemory(workspace_bytes);

    INFINICORE_CHECK_ERROR(infiniopSwiGLU(
        desc, scratch->data(), workspace_bytes,
        c->data(), a->data(), b->data(), context::getStream()));
}
46+
47+
// Registers `calculate` with the SwiGLU dispatcher during static
// initialization of this translation unit; the flag exists only to trigger
// the call.
// NOTE(review): the meaning of the `false` argument is not visible here -
// confirm against OpDispatcher::registerAll.
static bool registered = (SwiGLU::dispatcher().registerAll(&calculate, false), true);
51+
52+
} // namespace infinicore::op::swiglu_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "ops/matmul.hpp"
88
#include "ops/rearrange.hpp"
99
#include "ops/rms_norm.hpp"
10+
#include "ops/swiglu.hpp"
1011

1112
namespace py = pybind11;
1213

@@ -18,6 +19,7 @@ inline void bind(py::module &m) {
1819
bind_matmul(m);
1920
bind_rearrange(m);
2021
bind_rms_norm(m);
22+
bind_swiglu(m);
2123
}
2224

2325
} // namespace infinicore::ops
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/swiglu.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
/// Exposes the SwiGLU operator on the given Python module.
///
/// Registers two functions: `swiglu(a, b)`, bound to op::swiglu, and
/// `swiglu_(c, a, b)`, bound to op::swiglu_. All parameters are usable as
/// keyword arguments under those names.
inline void bind_swiglu(py::module &m) {
    using namespace py::literals; // "name"_a is shorthand for py::arg("name")

    m.def("swiglu", &op::swiglu, "a"_a, "b"_a,
          R"doc(SwiGLU activation function.)doc");

    m.def("swiglu_", &op::swiglu_, "c"_a, "a"_a, "b"_a,
          R"doc(In-place SwiGLU activation function.)doc");
}
25+
26+
} // namespace infinicore::ops

test/infinicore/ops/swiglu.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5+
6+
import torch
7+
import infinicore
8+
from framework.base import BaseOperatorTest, TensorSpec, TestCase
9+
from framework.runner import GenericTestRunner
10+
11+
# ==============================================================================
12+
# Operator-specific configuration
13+
# ==============================================================================
14+
15+
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides)
16+
# SwiGLU operates element-wise on two tensors of the same shape
17+
_TEST_CASES_DATA = [
18+
# Basic 2D SwiGLU
19+
(TestCase.BOTH, (2, 4), None, None, None),
20+
(TestCase.BOTH, (128, 64), None, None, None),
21+
# 3D SwiGLU
22+
(TestCase.BOTH, (2, 4, 8), None, None, None),
23+
(TestCase.BOTH, (4, 48, 6), None, None, None),
24+
# Strided tensors
25+
(TestCase.BOTH, (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
26+
(TestCase.BOTH, (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
27+
# Mixed cases
28+
(TestCase.BOTH, (8, 16, 32), None, None, None),
29+
# Large tensors
30+
(TestCase.BOTH, (16, 5632), None, None, None),
31+
(TestCase.BOTH, (4, 4, 5632), None, None, None),
32+
]
33+
34+
35+
def parse_test_cases(data):
    """Turn one raw SwiGLU test tuple into a ``TestCase``.

    Expected layout of ``data``:
        (operation_mode, shape, a_strides, b_strides, c_strides)

    The three stride entries are optional; a missing or ``None`` entry yields
    a ``TensorSpec.from_tensor`` spec, otherwise
    ``TensorSpec.from_strided_tensor`` is used. Both inputs and the output
    share ``shape``.
    """
    operation_mode, shape = data[0], data[1]

    # Pad with None so that short tuples behave as if the strides were omitted.
    a_strides, b_strides, c_strides = (tuple(data[2:5]) + (None, None, None))[:3]

    def make_spec(strides):
        # One spec per tensor: strided only when strides were supplied.
        if strides is None:
            return TensorSpec.from_tensor(shape)
        return TensorSpec.from_strided_tensor(shape, strides)

    inputs = [make_spec(a_strides), make_spec(b_strides)]
    output = make_spec(c_strides)

    return TestCase(operation_mode, inputs, output)
68+
69+
70+
# Parse test cases
71+
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
72+
73+
# Data types
74+
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
75+
76+
# Tolerance
77+
_TOLERANCE_MAP = {
78+
infinicore.float16: {"atol": 1e-3, "rtol": 1e-3},
79+
infinicore.float32: {"atol": 1e-5, "rtol": 1e-5},
80+
infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2},
81+
}
82+
83+
84+
class OpTest(BaseOperatorTest):
    """Operator test for SwiGLU: compares ``infinicore.swiglu`` with a torch reference."""

    def __init__(self):
        super().__init__("SwiGLU")

    def get_test_cases(self):
        """Return the parsed module-level test cases."""
        return _TEST_CASES

    def get_tensor_dtypes(self):
        """Return the dtypes the cases are run with."""
        return _TENSOR_DTYPES

    def get_tolerance_map(self):
        """Return the per-dtype atol/rtol settings."""
        return _TOLERANCE_MAP

    def torch_operator(self, a, b, out=None, **kwargs):
        """Torch reference: ``a * b * sigmoid(b)``; written into ``out`` when given."""
        expected = a * b * torch.sigmoid(b)
        if out is None:
            return expected
        out.copy_(expected)
        return out

    def infinicore_operator(self, a, b, out=None, **kwargs):
        """Operator under test; ``out`` is forwarded as a keyword argument."""
        return infinicore.swiglu(a, b, out=out)
110+
111+
112+
def main():
    """Build a GenericTestRunner for OpTest and hand control to run_and_exit()."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()

xmake/test.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ target("infinicore-test")
8585
add_files(os.projectdir().."/src/infinicore/context/*.cc")
8686
add_files(os.projectdir().."/src/infinicore/context/*/*.cc")
8787
add_files(os.projectdir().."/src/infinicore/tensor/*.cc")
88-
add_files(os.projectdir().."/src/infinicore/op/*/*.cc")
88+
add_files(os.projectdir().."/src/infinicore/ops/*/*.cc")
8989

9090
add_files(os.projectdir().."/src/infinicore-test/*.cc")
9191

0 commit comments

Comments
 (0)