Skip to content

Commit 3239753

Browse files
issue/1135 layernorm module
1 parent b32dcb5 commit 3239753

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@ class LayerNorm : public Module {
1414
* @param eps Small constant for numerical stability (default: 1e-6)
1515
* @param dtype Data type for the weight (default: DataType::F32)
1616
* @param device Device to create the weight on
17-
* @param elementwise_affine Whether to include learnable affine weight and bias parameters (default: true)
1817
*/
1918
LayerNorm(size_t normalized_shape,
2019
double eps = 1e-6,
2120
const DataType &dtype = DataType::F32,
22-
const Device &device = Device(),
23-
bool elementwise_affine = true);
21+
const Device &device = Device());
2422

2523
/**
2624
* @brief Forward pass: apply LayerNorm
@@ -35,6 +33,8 @@ class LayerNorm : public Module {
3533
*/
3634
Tensor forward(const Tensor &x) const;
3735

36+
// Module information
37+
size_t normalized_shape() const { return normalized_shape_; }
3838
double eps() const { return eps_; }
3939
DataType dtype() const { return dtype_; }
4040

@@ -54,6 +54,6 @@ class LayerNorm : public Module {
5454
size_t normalized_shape_; // Size of the feature dimension
5555
double eps_; // Epsilon for numerical stability
5656
DataType dtype_; // Data type for weight
57-
bool elementwise_affine_; // Whether to use learnable affine parameters
5857
};
58+
5959
} // namespace infinicore::nn

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "ops/hardswish.hpp"
2323
#include "ops/hardtanh.hpp"
2424
#include "ops/kv_caching.hpp"
25+
#include "ops/layer_norm.hpp"
2526
#include "ops/matmul.hpp"
2627
#include "ops/ones.hpp"
2728
#include "ops/paged_attention.hpp"

src/infinicore/nn/layer_norm.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "infinicore/nn/layer_norm.hpp"
2+
#include "infinicore/ops.hpp"
3+
#include <cmath>
4+
#include <stdexcept>
5+
6+
namespace infinicore::nn {
7+
8+
/// @brief Construct a LayerNorm module over a single feature dimension.
/// @param normalized_shape Size of the feature (last) dimension to normalize.
/// @param eps    Small constant added for numerical stability (default 1e-6).
/// @param dtype  Data type used for the learnable parameters (default F32).
/// @param device Device on which the parameters are allocated.
LayerNorm::LayerNorm(size_t normalized_shape, double eps, const DataType &dtype, const Device &device)
    : normalized_shape_(normalized_shape),
      eps_(eps),
      dtype_(dtype) {
    // device_ is assigned in the body rather than the initializer list —
    // presumably it is declared in the Module base class (TODO confirm).
    device_ = device;

    // Allocate the learnable affine parameters (weight_ / bias_), each a
    // 1-D tensor of length `normalized_shape`, via the project macro.
    INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
    INFINICORE_NN_PARAMETER_INIT(bias, ({normalized_shape}, dtype_, device));
}
18+
19+
/// @brief Forward pass: apply layer normalization to the input tensor.
/// @param x Input tensor (expected to have `normalized_shape_` as its
///          feature dimension — assumed, not checked here; TODO confirm).
/// @return Normalized tensor produced by op::layer_norm with the module's
///         learnable weight and bias.
Tensor LayerNorm::forward(const Tensor &x) const {
    // eps is stored as double but the op expects float; cast explicitly.
    const auto epsilon = static_cast<float>(eps_);
    return op::layer_norm(x, weight_, bias_, epsilon);
}
22+
23+
/// @brief Human-readable summary of the module configuration.
/// @return A string of the form "LayerNorm(normalized_shape=N, eps=E, dtype=D)"
///         where D is the dtype's underlying integer value.
std::string LayerNorm::extra_repr() const {
    // Build incrementally with += so intermediate temporaries from a long
    // operator+ chain are avoided; literal fragments match the original exactly.
    std::string repr = "LayerNorm(normalized_shape=";
    repr += std::to_string(normalized_shape_);
    repr += ", eps=";
    repr += std::to_string(eps_);
    repr += ", dtype=";
    repr += std::to_string(static_cast<int>(dtype_));
    repr += ")";
    return repr;
}
26+
27+
} // namespace infinicore::nn

0 commit comments

Comments
 (0)