Skip to content

Commit 18d6699

Browse files
issue/791 fix add_rmsnorm api and rmsnorm module
1 parent 3c8fb3c commit 18d6699

File tree

16 files changed

+225
-152
lines changed

16 files changed

+225
-152
lines changed

include/infinicore/nn/rmsnorm.hpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

3-
#include "module.hpp"
43
#include "../ops.hpp"
4+
#include "module.hpp"
55

66
namespace infinicore::nn {
77

@@ -57,6 +57,21 @@ class RMSNorm : public Module {
5757
*/
5858
Tensor forward(const Tensor &x) const;
5959

60+
/**
61+
* @brief Forward pass: apply RMSNorm in-place with residual
62+
*
63+
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions.
64+
* Will be modified in-place to the normalized output.
65+
* @param residual Residual tensor to add to input before normalization.
66+
* Will be modified in-place to the sum of input and residual.
67+
*
68+
* The normalization is applied over the last dimension.
69+
* For example:
70+
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
71+
* Input: [batch, hidden_size] -> normalize over hidden_size
72+
*/
73+
void forward_inplace(Tensor &x, Tensor &residual) const;
74+
6075
// Module information
6176
size_t normalized_shape() const { return normalized_shape_; }
6277
double eps() const { return eps_; }
@@ -73,9 +88,9 @@ class RMSNorm : public Module {
7388
INFINICORE_NN_PARAMETER(weight);
7489

7590
private:
76-
size_t normalized_shape_; // Size of the feature dimension
77-
double eps_; // Epsilon for numerical stability
78-
DataType dtype_; // Data type for weight
91+
size_t normalized_shape_; // Size of the feature dimension
92+
double eps_; // Epsilon for numerical stability
93+
DataType dtype_; // Data type for weight
7994
};
8095

8196
} // namespace infinicore::nn

include/infinicore/ops/add_rms_norm.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
#include <utility>
66

77
namespace infinicore::op {
8-
class AddRMSNorm {
9-
public:
10-
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
11-
static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
12-
static common::OpDispatcher<schema> &dispatcher();
13-
};
8+
INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float);
149

1510
// Fused Add and RMS Normalization
1611
// Returns: (normalized_result, add_result)
1712
// The add_result can be used as residual for subsequent layers
18-
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
19-
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
13+
std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
14+
void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
15+
// Fused Add and RMS Normalization (inplace)
16+
// normalized_result wil be stored in input, add_result will be stored in residual
17+
void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f);
2018
} // namespace infinicore::op

include/infiniop/ops/add_rms_norm.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,22 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
99
infiniopHandle_t handle,
1010
infiniopAddRMSNormDescriptor_t *desc_ptr,
1111
infiniopTensorDescriptor_t y_desc,
12+
infiniopTensorDescriptor_t residual_out_desc,
1213
infiniopTensorDescriptor_t a_desc,
1314
infiniopTensorDescriptor_t b_desc,
1415
infiniopTensorDescriptor_t weight_desc,
15-
float epsilon,
16-
infiniopTensorDescriptor_t residual_out_desc);
16+
float epsilon);
1717

1818
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
1919

2020
__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
2121
void *workspace,
2222
size_t workspace_size,
2323
void *y,
24+
void *residual_out,
2425
const void *a,
2526
const void *b,
2627
const void *weight,
27-
void *residual_out,
2828
void *stream);
2929

3030
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);

python/infinicore/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
uint8,
4444
)
4545
from infinicore.ops.add import add
46-
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
46+
from infinicore.ops.add_rms_norm import add_rms_norm
4747
from infinicore.ops.attention import attention
4848
from infinicore.ops.matmul import matmul
4949
from infinicore.ops.mul import mul
Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import infinicore.tensor as tensor
12
from infinicore.lib import _infinicore
2-
from infinicore.tensor import Tensor
33

44

5-
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
5+
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None, residual=None):
66
"""
77
Fused Add and RMS Normalization.
88
@@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
1818
The add_result can be used as residual for subsequent layers.
1919
"""
2020
if out is None:
21-
result = _infinicore.add_rms_norm(
22-
a._underlying, b._underlying, weight._underlying, epsilon
23-
)
24-
return (Tensor(result[0]), Tensor(result[1]))
21+
out = tensor.empty(a.shape, dtype=a.dtype, device=a.device)
22+
if residual is None:
23+
residual = tensor.empty(b.shape, dtype=b.dtype, device=b.device)
2524

26-
y, residual_out = out
2725
_infinicore.add_rms_norm_(
28-
y._underlying,
29-
residual_out._underlying,
26+
out._underlying,
27+
residual._underlying,
3028
a._underlying,
3129
b._underlying,
3230
weight._underlying,
3331
epsilon,
3432
)
35-
return (y, residual_out)
3633

37-
38-
def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
39-
"""In-place Fused Add and RMS Normalization."""
40-
_infinicore.add_rms_norm_(
41-
y._underlying,
42-
residual_out._underlying,
43-
a._underlying,
44-
b._underlying,
45-
weight._underlying,
46-
epsilon,
47-
)
34+
return out, residual

src/infinicore/nn/rmsnorm.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const {
2121
return op::rms_norm(x, weight_, static_cast<float>(eps_));
2222
}
2323

24+
void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
25+
if (!residual) {
26+
residual = x;
27+
x = op::rms_norm(x, weight_, static_cast<float>(eps_));
28+
} else {
29+
if (device_.getType() == Device::Type::CPU
30+
|| device_.getType() == Device::Type::NVIDIA
31+
|| device_.getType() == Device::Type::ILUVATAR
32+
|| device_.getType() == Device::Type::METAX
33+
|| device_.getType() == Device::Type::MOORE) {
34+
op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_));
35+
} else {
36+
op::add_(residual, x, residual);
37+
op::rms_norm_(x, residual, weight_, static_cast<float>(eps_));
38+
}
39+
}
40+
}
41+
2442
std::string RMSNorm::extra_repr() const {
2543
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
2644
}

src/infinicore/ops/add_rms_norm/add_rms_norm.cc

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,30 @@
44

55
namespace infinicore::op {
66

7-
common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
8-
static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
9-
return dispatcher_;
10-
};
7+
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm);
118

12-
void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
9+
AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
1310
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
14-
infinicore::context::setDevice(y->device());
15-
dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
11+
INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon);
1612
}
1713

18-
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
14+
void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
15+
INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon);
16+
}
17+
18+
std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
1919
auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
2020
auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
2121
add_rms_norm_(y, residual_out, a, b, weight, epsilon);
2222
return std::make_pair(y, residual_out);
2323
}
2424

25-
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
26-
AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
25+
void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
26+
AddRMSNorm::execute(out, residual, a, b, weight, epsilon);
27+
}
28+
29+
void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) {
30+
add_rms_norm_(input, residual, input, residual, weight, epsilon);
2731
}
2832

2933
} // namespace infinicore::op
Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,53 @@
1-
#include "../../utils.hpp"
2-
#include "infinicore/common/hash.hpp"
31
#include "infinicore/ops/add_rms_norm.hpp"
4-
#include "infinicore/ops/common/cache.hpp"
5-
#include <infiniop.h>
2+
3+
#include "../infiniop_impl.hpp"
64

75
namespace infinicore::op::add_rms_norm_impl::infiniop {
86

9-
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
10-
100, // capacity
11-
[](infiniopAddRMSNormDescriptor_t &desc) {
12-
if (desc != nullptr) {
13-
INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
14-
desc = nullptr;
15-
}
16-
});
7+
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100);
8+
9+
struct PlannedMeta {
10+
std::shared_ptr<Descriptor> descriptor;
11+
graph::GraphTensor workspace, out, residual, a, b, weight;
12+
float epsilon;
13+
};
1714

18-
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
15+
void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
1916
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
2017

21-
auto device = context::getDevice();
22-
auto &cache = caches.getCache(device);
18+
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
19+
Descriptor, descriptor, AddRMSNorm,
20+
seed, y->desc(), residual_out->desc(),
21+
a->desc(), b->desc(), weight->desc(), epsilon);
22+
23+
INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor);
2324

24-
auto desc_opt = cache.get(seed);
25-
infiniopAddRMSNormDescriptor_t desc = nullptr;
25+
auto planned = new PlannedMeta{
26+
descriptor,
27+
graph::GraphTensor(workspace),
28+
graph::GraphTensor(y),
29+
graph::GraphTensor(residual_out),
30+
graph::GraphTensor(a),
31+
graph::GraphTensor(b),
32+
graph::GraphTensor(weight),
33+
epsilon};
2634

27-
if (!desc_opt) {
28-
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
29-
context::getInfiniopHandle(device), &desc,
30-
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
31-
cache.put(seed, desc);
32-
} else {
33-
desc = *desc_opt;
34-
}
35+
return planned;
36+
}
3537

36-
size_t workspace_size = 0;
37-
INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
38-
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
38+
void run(void *planned_meta) {
39+
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
3940

4041
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
41-
desc, workspace->data(), workspace_size,
42-
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
42+
planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
43+
planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream()));
44+
}
45+
46+
void cleanup(void **planned_meta_ptr) {
47+
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
48+
*planned_meta_ptr = nullptr;
4349
}
4450

45-
static bool registered = []() {
46-
AddRMSNorm::dispatcher().registerAll(&calculate, false);
47-
return true;
48-
}();
51+
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup);
4952

5053
} // namespace infinicore::op::add_rms_norm_impl::infiniop

src/infiniop/ops/add_rms_norm/add_rms_norm.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,19 @@
3333
infiniopHandle_t handle, \
3434
Descriptor **desc_ptr, \
3535
infiniopTensorDescriptor_t y_desc, \
36+
infiniopTensorDescriptor_t residual_out_desc, \
3637
infiniopTensorDescriptor_t a_desc, \
3738
infiniopTensorDescriptor_t b_desc, \
3839
infiniopTensorDescriptor_t weight_desc, \
39-
float epsilon, \
40-
infiniopTensorDescriptor_t residual_out_desc); \
40+
float epsilon); \
4141
\
4242
infiniStatus_t calculate( \
4343
void *workspace, size_t workspace_size, \
4444
void *y, \
45+
void *residual_out, \
4546
const void *a, \
4647
const void *b, \
4748
const void *weight, \
48-
void *residual_out, \
4949
void *stream) const; \
5050
}; \
5151
}

0 commit comments

Comments
 (0)