Skip to content

Commit 558066f

Browse files
committed
issue/497 - rms norm interface
1 parent 846d897 commit 558066f

8 files changed

Lines changed: 154 additions & 0 deletions

File tree

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
#include "ops/matmul.hpp"
55
#include "ops/ones.hpp"
66
#include "ops/rearrange.hpp"
7+
#include "ops/rms_norm.hpp"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class RMSNorm {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor, float);
10+
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
15+
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
16+
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from infinicore.ops.add import add
2828
from infinicore.ops.matmul import matmul
2929
from infinicore.ops.rearrange import rearrange
30+
from infinicore.ops.rms_norm import rms_norm
3031
from infinicore.tensor import (
3132
empty,
3233
from_blob,
@@ -67,6 +68,7 @@
6768
"add",
6869
"matmul",
6970
"rearrange",
71+
"rms_norm",
7072
"empty",
7173
"from_blob",
7274
"ones",

python/infinicore/ops/rms_norm.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def rms_norm(input, weight, epsilon=1e-5, *, out=None):
    """Apply RMS normalization to ``input`` using ``weight`` as the scale.

    Args:
        input: Tensor to normalize.
        weight: Scale tensor.
        epsilon: Small constant for numerical stability.
        out: Optional destination tensor. When given, the result is written
            into it instead of into a newly created tensor.

    Returns:
        A new ``Tensor`` wrapping the result when ``out`` is ``None``;
        otherwise ``out`` itself.
    """
    if out is None:
        return Tensor(
            _infinicore.rms_norm(input._underlying, weight._underlying, epsilon)
        )

    _infinicore.rms_norm_(
        out._underlying, input._underlying, weight._underlying, epsilon
    )
    # Hand the destination back so the ``out=`` form does not evaluate to None
    # and both call forms can be used inside an expression.
    return out
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "infinicore/ops/rms_norm.hpp"
2+
3+
namespace infinicore::op {
4+
5+
/// Returns the dispatcher shared by all RMSNorm implementations.
///
/// The dispatcher is a function-local static, so it is constructed on first
/// use. Registration code that runs during static initialization of another
/// translation unit therefore always receives a constructed object.
common::OpDispatcher<RMSNorm::schema> &RMSNorm::dispatcher() {
    static common::OpDispatcher<RMSNorm::schema> dispatcher_;
    return dispatcher_;
}
9+
10+
/// Looks up the implementation registered for the current device type and
/// invokes it with the given arguments.
void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) {
    auto &table = dispatcher();
    const auto device_type = context::getDevice().getType();
    auto impl = table.lookup(device_type);
    impl(y, x, weight, epsilon);
}
13+
14+
/// Out-of-place RMS normalization: creates an output tensor with the shape,
/// dtype and device of `x`, fills it via `rms_norm_`, and returns it.
Tensor rms_norm(Tensor x, Tensor weight, float epsilon) {
    const auto shape = x->shape();
    const auto dtype = x->dtype();
    const auto device = x->device();

    auto output = Tensor::empty(shape, dtype, device);
    rms_norm_(output, x, weight, epsilon);
    return output;
}
19+
20+
/// Variant that writes into a caller-provided tensor `y`; forwards directly
/// to `RMSNorm::execute`.
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon) {
    RMSNorm::execute(y, x, weight, epsilon);
}
23+
24+
} // namespace infinicore::op
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/common/cache.hpp"
4+
#include "infinicore/ops/rms_norm.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::rms_norm_impl::infiniop {
8+
9+
// Thread-local cache of infiniop RMSNorm descriptors keyed by a size_t hash.
// The callback passed as the second argument destroys a descriptor and resets
// the handle to null.
// NOTE(review): presumably the cache invokes it on eviction/teardown — confirm
// against common::OpCache.
thread_local common::OpCache<size_t, infiniopRMSNormDescriptor_t> caches(
    100, // capacity
    [](infiniopRMSNormDescriptor_t &desc) {
        if (desc == nullptr) {
            return; // nothing to release
        }
        INFINICORE_CHECK_ERROR(infiniopDestroyRMSNormDescriptor(desc));
        desc = nullptr;
    });
17+
18+
/// infiniop-backed implementation of RMSNorm.
///
/// Fetches an infiniop RMSNorm descriptor from the cache for the current
/// device (creating and caching one on a miss), allocates the workspace size
/// the descriptor reports, and invokes `infiniopRMSNorm` with the context's
/// stream.
void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
    // Cache key derived from all four arguments.
    size_t key = hash_combine(y, x, weight, epsilon);

    auto device_type = context::getDevice().getType();
    auto device_index = context::getDevice().getIndex();
    auto &device_cache = caches.getCache(device_type, device_index);

    infiniopRMSNormDescriptor_t desc = nullptr;
    if (auto cached = device_cache.get(key)) {
        desc = *cached;
    } else {
        // Miss: build a descriptor and store it under this key.
        INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
            context::getInfiniopHandle(), &desc,
            y->desc(), x->desc(), weight->desc(), epsilon));
        device_cache.put(key, desc);
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopRMSNorm(
        desc, workspace->data(), workspace_size,
        y->data(), x->data(), weight->data(), context::getStream()));
}
46+
47+
// Registers `calculate` with the RMSNorm dispatcher while this translation
// unit's statics are initialized; the variable exists only to trigger the call.
// NOTE(review): the meaning of the `false` argument to registerAll is not
// visible here — confirm against common::OpDispatcher.
static bool registered = [] {
    RMSNorm::dispatcher().registerAll(&calculate, false);
    return true;
}();
51+
52+
} // namespace infinicore::op::rms_norm_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "ops/add.hpp"
66
#include "ops/matmul.hpp"
77
#include "ops/rearrange.hpp"
8+
#include "ops/rms_norm.hpp"
89

910
namespace py = pybind11;
1011

@@ -14,6 +15,7 @@ inline void bind(py::module &m) {
1415
bind_add(m);
1516
bind_matmul(m);
1617
bind_rearrange(m);
18+
bind_rms_norm(m);
1719
}
1820

1921
} // namespace infinicore::ops
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/rms_norm.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
/// Registers the RMS normalization bindings on the given pybind11 module:
/// `rms_norm` (bound to op::rms_norm) and `rms_norm_` (bound to op::rms_norm_).
/// Both expose `epsilon` as a keyword argument defaulting to 1e-5f.
inline void bind_rms_norm(py::module &m) {
12+
// Variant returning the result tensor: (x, weight, epsilon).
m.def("rms_norm",
13+
&op::rms_norm,
14+
py::arg("x"),
15+
py::arg("weight"),
16+
py::arg("epsilon") = 1e-5f,
17+
R"doc(Root Mean Square Normalization.
18+
19+
Args:
20+
x: Input tensor
21+
weight: Scale weights
22+
epsilon: Small constant for numerical stability, default is 1e-5
23+
24+
Returns:
25+
Normalized tensor with same shape as input
26+
)doc");
27+
28+
// Variant writing into a caller-provided tensor: (y, x, weight, epsilon).
// NOTE(review): the docstring below says "In-place", but the binding takes a
// separate output `y` and input `x` — confirm the intended wording.
m.def("rms_norm_",
29+
&op::rms_norm_,
30+
py::arg("y"),
31+
py::arg("x"),
32+
py::arg("weight"),
33+
py::arg("epsilon") = 1e-5f,
34+
R"doc(In-place Root Mean Square Normalization.
35+
36+
Args:
37+
y: Output tensor
38+
x: Input tensor
39+
weight: Scale weights
40+
epsilon: Small constant for numerical stability, default is 1e-5
41+
)doc");
42+
}
43+
44+
} // namespace infinicore::ops

0 commit comments

Comments
 (0)