|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "module.hpp" |
| 4 | +#include "../ops.hpp" |
| 5 | +#include <optional> |
| 6 | + |
| 7 | +namespace infinicore::nn { |
| 8 | + |
| 9 | +/** |
| 10 | + * @brief Embedding layer that maps indices to dense vectors |
| 11 | + * |
| 12 | + * A simple lookup table that stores embeddings of a fixed dictionary and size. |
| 13 | + * This module is often used to store word embeddings and retrieve them using indices. |
| 14 | + * The input to the module is a tensor of indices, and the output is the corresponding |
| 15 | + * embedding vectors. |
| 16 | + * |
| 17 | + * Similar to PyTorch's nn.Embedding: |
| 18 | + * https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html |
| 19 | + * |
| 20 | + * Example: |
| 21 | + * @code |
| 22 | + * // Create embedding: 10000 words, 300-dimensional embeddings |
| 23 | + * auto embedding = Embedding(10000, 300); |
| 24 | + * |
| 25 | + * // Input: tensor of indices [batch_size, seq_len] |
| 26 | + * auto indices = Tensor::from_data({2, 5}, {3, 5, 12, 8, 99, 0, 1, 45, 67, 23}); |
| 27 | + * |
| 28 | + * // Output: [batch_size, seq_len, embedding_dim] = [2, 5, 300] |
| 29 | + * auto embeddings = embedding.forward(indices); |
| 30 | + * @endcode |
| 31 | + */ |
| 32 | +class Embedding : public Module { |
| 33 | +public: |
| 34 | + /** |
| 35 | + * @brief Construct an Embedding layer |
| 36 | + * |
| 37 | + * @param num_embeddings Size of the dictionary of embeddings (vocabulary size) |
| 38 | + * @param embedding_dim The size of each embedding vector |
| 39 | + * @param padding_idx If specified, the entries at padding_idx do not contribute to gradient |
| 40 | + * and the embedding vector at padding_idx is not updated during training |
| 41 | + * @param device Device to create the embedding weight on |
| 42 | + */ |
| 43 | + Embedding(size_t num_embeddings, |
| 44 | + size_t embedding_dim, |
| 45 | + std::optional<int64_t> padding_idx = std::nullopt, |
| 46 | + const Device &device = Device()); |
| 47 | + |
| 48 | + /** |
| 49 | + * @brief Forward pass: lookup embeddings for given indices |
| 50 | + * |
| 51 | + * @param indices Tensor containing indices into the embedding matrix. |
| 52 | + * Can be any shape (*), typically [batch_size] or [batch_size, seq_len] |
| 53 | + * @return Tensor containing the embedding vectors. |
| 54 | + * Shape: (*, embedding_dim) where * matches the input shape |
| 55 | + * |
| 56 | + * Example: |
| 57 | + * Input shape: [2, 3] -> Output shape: [2, 3, embedding_dim] |
| 58 | + * Input shape: [10] -> Output shape: [10, embedding_dim] |
| 59 | + */ |
| 60 | + Tensor forward(const Tensor &indices) const; |
| 61 | + |
| 62 | + /** |
| 63 | + * @brief Create an Embedding from pretrained vectors |
| 64 | + * |
| 65 | + * @param embeddings Pretrained embedding matrix of shape [num_embeddings, embedding_dim] |
| 66 | + * @param freeze If true, embeddings will not be updated during training |
| 67 | + * @param padding_idx Optional padding index |
| 68 | + * @return Embedding module initialized with the pretrained vectors |
| 69 | + */ |
| 70 | + static Embedding from_pretrained(const Tensor &embeddings, |
| 71 | + bool freeze = true, |
| 72 | + std::optional<int64_t> padding_idx = std::nullopt); |
| 73 | + |
| 74 | + // Accessors for parameters |
| 75 | + Tensor weight() const { return weight_; } |
| 76 | + |
| 77 | + // Module information |
| 78 | + size_t num_embeddings() const { return num_embeddings_; } |
| 79 | + size_t embedding_dim() const { return embedding_dim_; } |
| 80 | + std::optional<int64_t> padding_idx() const { return padding_idx_; } |
| 81 | + |
| 82 | + // String representation |
| 83 | + std::string extra_repr() const; |
| 84 | + |
| 85 | + // Direct access to parameters as fields |
| 86 | + Parameter weight_; |
| 87 | + |
| 88 | +private: |
| 89 | + size_t num_embeddings_; // Vocabulary size |
| 90 | + size_t embedding_dim_; // Embedding dimension |
| 91 | + std::optional<int64_t> padding_idx_; // Optional padding index |
| 92 | +}; |
| 93 | + |
| 94 | +} // namespace infinicore::nn |
0 commit comments