Skip to content

Commit 61fcdf7

Browse files
gongchensugongchensu
andauthored
feat(ops): add MetaX backend for Swiglu (#28)
Co-authored-by: gongchensu <zhuyue@qiyuanlab.com>
1 parent f44be6f commit 61fcdf7

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

src/metax/swiglu/kernel.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#ifndef INFINI_OPS_METAX_SWIGLU_KERNEL_H_
2+
#define INFINI_OPS_METAX_SWIGLU_KERNEL_H_
3+
4+
#include <utility>
5+
6+
// clang-format off
7+
#include <mcr/mc_runtime.h>
8+
// clang-format on
9+
10+
#include "cuda/swiglu/kernel.h"
11+
12+
namespace infini::ops {
13+
14+
namespace swiglu {
15+
16+
struct MetaxBackend {
17+
using stream_t = mcStream_t;
18+
19+
static constexpr auto malloc = [](auto&&... args) {
20+
return mcMalloc(std::forward<decltype(args)>(args)...);
21+
};
22+
23+
static constexpr auto memcpy = mcMemcpy;
24+
25+
static constexpr auto free = mcFree;
26+
27+
static constexpr auto memcpyH2D = mcMemcpyHostToDevice;
28+
};
29+
30+
} // namespace swiglu
31+
32+
template <>
33+
class Operator<Swiglu, Device::Type::kMetax>
34+
: public CudaSwiglu<swiglu::MetaxBackend> {
35+
public:
36+
using CudaSwiglu<swiglu::MetaxBackend>::CudaSwiglu;
37+
};
38+
39+
} // namespace infini::ops
40+
41+
#endif

0 commit comments

Comments
 (0)