Skip to content

Commit 23aa856

Browse files
issue/1135 layernorm module
1 parent b32dcb5 commit 23aa856

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#pragma once

#include "../ops.hpp"
#include "module.hpp"

// Include-what-you-use: this header names size_t and std::string directly,
// so it must not rely on transitive includes for them.
#include <cstddef>
#include <string>

namespace infinicore::nn {

/**
 * @brief Layer-normalization module.
 *
 * Holds a learnable `weight` (scale) and `bias` (shift), each of shape
 * {normalized_shape}, and normalizes its input over the last dimension.
 */
class LayerNorm : public Module {
public:
    /**
     * @brief Construct a LayerNorm layer
     *
     * @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
     * @param eps Small constant for numerical stability (default: 1e-6)
     * @param dtype Data type for the weight (default: DataType::F32)
     * @param device Device to create the weight on
     */
    LayerNorm(size_t normalized_shape,
              double eps = 1e-6,
              const DataType &dtype = DataType::F32,
              const Device &device = Device());

    /**
     * @brief Forward pass: apply LayerNorm
     *
     * @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions
     * @return Normalized tensor with same shape as input
     *
     * The normalization is applied over the last dimension.
     * For example:
     *   Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
     *   Input: [batch, hidden_size]          -> normalize over hidden_size
     */
    Tensor forward(const Tensor &x) const;

    // Module information
    size_t normalized_shape() const { return normalized_shape_; }
    double eps() const { return eps_; }
    DataType dtype() const { return dtype_; }

    // String representation of the layer's configuration.
    std::string extra_repr() const;

    // Accessors for the learnable parameters (returned by value as Tensor handles).
    Tensor weight() const { return weight_; }
    Tensor bias() const { return bias_; }

protected:
    // Parameters — declared via the project macro, which presumably defines
    // the weight_/bias_ members referenced by the accessors above.
    INFINICORE_NN_PARAMETER(weight);
    INFINICORE_NN_PARAMETER(bias);

private:
    size_t normalized_shape_; // Size of the feature dimension
    double eps_;              // Epsilon for numerical stability
    DataType dtype_;          // Data type for weight
};

} // namespace infinicore::nn

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "ops/hardswish.hpp"
2323
#include "ops/hardtanh.hpp"
2424
#include "ops/kv_caching.hpp"
25+
#include "ops/layer_norm.hpp"
2526
#include "ops/matmul.hpp"
2627
#include "ops/ones.hpp"
2728
#include "ops/paged_attention.hpp"

src/infinicore/nn/layer_norm.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "infinicore/nn/layer_norm.hpp"

#include "infinicore/ops.hpp"

#include <cmath>
#include <sstream>
#include <stdexcept>
5+
6+
namespace infinicore::nn {
7+
8+
// Construct a LayerNorm module: record the configuration, then allocate the
// learnable scale (weight) and shift (bias) parameters.
LayerNorm::LayerNorm(size_t normalized_shape, double eps, const DataType &dtype, const Device &device)
    : normalized_shape_(normalized_shape),
      eps_(eps),
      dtype_(dtype) {

    // device_ is presumably a member inherited from Module — TODO confirm.
    device_ = device;

    // Create the 1-D parameters of shape {normalized_shape} on the target
    // device. The macro presumably also registers them with the Module base
    // so they show up in parameter traversal — verify against module.hpp.
    INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
    INFINICORE_NN_PARAMETER_INIT(bias, ({normalized_shape}, dtype_, device));
}
18+
19+
// Apply layer normalization over the last dimension of x, scaling by weight_
// and shifting by bias_. Delegates to the fused op::layer_norm kernel.
Tensor LayerNorm::forward(const Tensor &x) const {
    const auto epsilon = static_cast<float>(eps_); // op takes float epsilon
    return op::layer_norm(x, weight_, bias_, epsilon);
}
22+
23+
std::string LayerNorm::extra_repr() const {
24+
return "LayerNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
25+
}
26+
27+
} // namespace infinicore::nn

0 commit comments

Comments
 (0)