Skip to content

Commit 3981559

Browse files
committed
Add embedding/rmsnorm modules and add a TinyLlama hierarchy case
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent fbb9d25 commit 3981559

7 files changed

Lines changed: 1078 additions & 418 deletions

File tree

include/infinicore/nn.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
#pragma once
22

3+
#include "nn/embedding.hpp"
34
#include "nn/linear.hpp"
5+
#include "nn/rmsnorm.hpp"
include/infinicore/nn/embedding.hpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../ops.hpp"
5+
#include <optional>
6+
7+
namespace infinicore::nn {
8+
9+
/**
10+
* @brief Embedding layer that maps indices to dense vectors
11+
*
12+
* A simple lookup table that stores embeddings of a fixed dictionary and size.
13+
* This module is often used to store word embeddings and retrieve them using indices.
14+
* The input to the module is a tensor of indices, and the output is the corresponding
15+
* embedding vectors.
16+
*
17+
* Similar to PyTorch's nn.Embedding:
18+
* https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
19+
*
20+
* Example:
21+
* @code
22+
* // Create embedding: 10000 words, 300-dimensional embeddings
23+
* auto embedding = Embedding(10000, 300);
24+
*
25+
* // Input: tensor of indices [batch_size, seq_len]
26+
* auto indices = Tensor::from_data({2, 5}, {3, 5, 12, 8, 99, 0, 1, 45, 67, 23});
27+
*
28+
* // Output: [batch_size, seq_len, embedding_dim] = [2, 5, 300]
29+
* auto embeddings = embedding.forward(indices);
30+
* @endcode
31+
*/
32+
class Embedding : public Module {
33+
public:
34+
/**
35+
* @brief Construct an Embedding layer
36+
*
37+
* @param num_embeddings Size of the dictionary of embeddings (vocabulary size)
38+
* @param embedding_dim The size of each embedding vector
39+
* @param padding_idx If specified, the entries at padding_idx do not contribute to gradient
40+
* and the embedding vector at padding_idx is not updated during training
41+
* @param device Device to create the embedding weight on
42+
*/
43+
Embedding(size_t num_embeddings,
44+
size_t embedding_dim,
45+
std::optional<int64_t> padding_idx = std::nullopt,
46+
const Device &device = Device());
47+
48+
/**
49+
* @brief Forward pass: lookup embeddings for given indices
50+
*
51+
* @param indices Tensor containing indices into the embedding matrix.
52+
* Can be any shape (*), typically [batch_size] or [batch_size, seq_len]
53+
* @return Tensor containing the embedding vectors.
54+
* Shape: (*, embedding_dim) where * matches the input shape
55+
*
56+
* Example:
57+
* Input shape: [2, 3] -> Output shape: [2, 3, embedding_dim]
58+
* Input shape: [10] -> Output shape: [10, embedding_dim]
59+
*/
60+
Tensor forward(const Tensor &indices) const;
61+
62+
/**
63+
* @brief Create an Embedding from pretrained vectors
64+
*
65+
* @param embeddings Pretrained embedding matrix of shape [num_embeddings, embedding_dim]
66+
* @param freeze If true, embeddings will not be updated during training
67+
* @param padding_idx Optional padding index
68+
* @return Embedding module initialized with the pretrained vectors
69+
*/
70+
static Embedding from_pretrained(const Tensor &embeddings,
71+
bool freeze = true,
72+
std::optional<int64_t> padding_idx = std::nullopt);
73+
74+
// Accessors for parameters
75+
Tensor weight() const { return weight_; }
76+
77+
// Module information
78+
size_t num_embeddings() const { return num_embeddings_; }
79+
size_t embedding_dim() const { return embedding_dim_; }
80+
std::optional<int64_t> padding_idx() const { return padding_idx_; }
81+
82+
// String representation
83+
std::string extra_repr() const;
84+
85+
// Direct access to parameters as fields
86+
Parameter weight_;
87+
88+
private:
89+
size_t num_embeddings_; // Vocabulary size
90+
size_t embedding_dim_; // Embedding dimension
91+
std::optional<int64_t> padding_idx_; // Optional padding index
92+
};
93+
94+
} // namespace infinicore::nn

include/infinicore/nn/rmsnorm.hpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../ops.hpp"
5+
6+
namespace infinicore::nn {
7+
8+
/**
9+
* @brief Root Mean Square Layer Normalization (RMSNorm)
10+
*
11+
* Applies Root Mean Square Layer Normalization over the last dimension.
12+
* Unlike LayerNorm, RMSNorm doesn't subtract mean and doesn't use bias.
13+
*
14+
* Formula: y = (x / RMS(x)) * weight
15+
* where RMS(x) = sqrt(mean(x^2) + eps)
16+
*
17+
* Used in LLaMA, Galactica, and other modern language models as a
18+
* simpler and faster alternative to LayerNorm.
19+
*
20+
* Reference:
21+
* - "Root Mean Square Layer Normalization" (https://arxiv.org/abs/1910.07467)
22+
* - LLaMA implementation: https://github.com/facebookresearch/llama
23+
*
24+
* Example:
25+
* @code
26+
* // Create RMSNorm for hidden size 4096
27+
* auto norm = RMSNorm(4096);
28+
*
29+
* // Input: [batch, seq_len, hidden_size]
30+
* auto input = Tensor::randn({2, 10, 4096});
31+
*
32+
* // Output: [batch, seq_len, hidden_size]
33+
* auto output = norm.forward(input);
34+
* @endcode
35+
*/
36+
class RMSNorm : public Module {
37+
public:
38+
/**
39+
* @brief Construct a RMSNorm layer
40+
*
41+
* @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
42+
* @param eps Small constant for numerical stability (default: 1e-6)
43+
* @param device Device to create the weight on
44+
*/
45+
RMSNorm(size_t normalized_shape,
46+
double eps = 1e-6,
47+
const Device &device = Device());
48+
49+
/**
50+
* @brief Forward pass: apply RMSNorm
51+
*
52+
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions
53+
* @return Normalized tensor with same shape as input
54+
*
55+
* The normalization is applied over the last dimension.
56+
* For example:
57+
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
58+
* Input: [batch, hidden_size] -> normalize over hidden_size
59+
*/
60+
Tensor forward(const Tensor &x) const;
61+
62+
// Accessors for parameters
63+
Tensor weight() const { return weight_; }
64+
65+
// Module information
66+
size_t normalized_shape() const { return normalized_shape_; }
67+
double eps() const { return eps_; }
68+
69+
// String representation
70+
std::string extra_repr() const;
71+
72+
// Direct access to parameters as fields
73+
Parameter weight_;
74+
75+
private:
76+
size_t normalized_shape_; // Size of the feature dimension
77+
double eps_; // Epsilon for numerical stability
78+
};
79+
80+
} // namespace infinicore::nn

0 commit comments

Comments
 (0)