Skip to content

Commit 8c49f1c

Browse files
committed
issue/900 - maintains classic embedding for devices yet to be worked on
1 parent 84d73d8 commit 8c49f1c

2 files changed

Lines changed: 78 additions & 14 deletions

File tree

src/infinicore/nn/embedding.cc

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,86 @@ Embedding::Embedding(size_t num_embeddings,
4343
}
4444

4545
Tensor Embedding::forward(const Tensor &indices) const {
46-
// Ensure indices are on the same device as weight
47-
// This avoids synchronous memcpy in ops layer which would hurt performance
48-
Tensor indices_on_device = indices;
49-
if (indices->device() != device_) {
50-
indices_on_device = indices->to(device_);
46+
// TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
47+
if (device_ == Device::Type::NVIDIA || device_ == Device::Type::ILUVATAR || device_ == Device::Type::METAX || device_ == Device::Type::MOORE) {
48+
// Use op::embedding which supports device-side input and batch dimension
49+
return op::embedding(indices->contiguous()->to(device_), weight_);
5150
}
5251

53-
// Ensure indices are contiguous for efficient access
54-
// op::embedding now supports device-side input for graph recording
55-
Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous();
52+
// Get the shape of indices
53+
auto indices_shape = indices->shape();
5654

57-
// Use op::embedding which now supports device-side input and batch dimension
58-
// This enables full graph recording support without synchronization
59-
return op::embedding(indices_contiguous, weight_);
55+
// Output shape: indices_shape + [embedding_dim]
56+
std::vector<size_t> output_shape = indices_shape;
57+
output_shape.push_back(embedding_dim_);
58+
59+
// Create output tensor on the same device as weight
60+
auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device());
61+
62+
// Bring indices to the CPU as a contiguous buffer so they can be read linearly for the row copies
63+
auto cpu_device = Device(Device::Type::CPU, 0);
64+
auto indices_cpu = indices->to(cpu_device)->contiguous();
65+
66+
// Calculate total number of lookups
67+
size_t num_lookups = 1;
68+
for (auto dim : indices_shape) {
69+
num_lookups *= dim;
70+
}
71+
72+
const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype());
73+
74+
// Source and destination base pointers
75+
auto *weight_base = weight_->data();
76+
auto *out_base = out->data();
77+
78+
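// Helper lambda to read the i-th index as int64_t based on dtype (rejects U64 values that do not fit in int64_t; the check against num_embeddings_ is done in the copy loops below)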
79+
auto read_index = [&](size_t i) -> int64_t {
80+
auto dtype = indices_cpu->dtype();
81+
if (dtype == DataType::I32) {
82+
const auto *data = reinterpret_cast<const int32_t *>(indices_cpu->data());
83+
return static_cast<int64_t>(data[i]);
84+
} else if (dtype == DataType::I64) {
85+
const auto *data = reinterpret_cast<const int64_t *>(indices_cpu->data());
86+
return data[i];
87+
} else if (dtype == DataType::U32) {
88+
const auto *data = reinterpret_cast<const uint32_t *>(indices_cpu->data());
89+
return static_cast<int64_t>(data[i]);
90+
} else if (dtype == DataType::U64) {
91+
const auto *data = reinterpret_cast<const uint64_t *>(indices_cpu->data());
92+
uint64_t val = data[i];
93+
// Check if value can fit in int64_t
94+
if (val > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
95+
throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val));
96+
}
97+
return static_cast<int64_t>(val);
98+
} else {
99+
throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast<int>(dtype)));
100+
}
101+
};
102+
103+
if (weight_->device().getType() == Device::Type::CPU) {
104+
// CPU path: memcpy row by row
105+
for (size_t i = 0; i < num_lookups; ++i) {
106+
int64_t idx = read_index(i);
107+
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
108+
throw std::out_of_range(
109+
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
110+
}
111+
std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
112+
}
113+
} else {
114+
// Device path: use stream-ordered D2D copies
115+
for (size_t i = 0; i < num_lookups; ++i) {
116+
int64_t idx = read_index(i);
117+
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
118+
throw std::out_of_range(
119+
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
120+
}
121+
context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
122+
}
123+
}
124+
125+
return out;
60126
}
61127

62128
std::string Embedding::extra_repr() const {

src/infinicore/ops/embedding/embedding.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
#include "infinicore/ops/embedding.hpp"
2+
23
#include "../../utils.hpp"
3-
#include "infinicore/context/context.hpp"
4-
#include <cstring>
5-
#include <stdexcept>
64

75
namespace infinicore::op {
86
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding);

0 commit comments

Comments
 (0)