@@ -43,20 +43,86 @@ Embedding::Embedding(size_t num_embeddings,
4343}
4444
4545Tensor Embedding::forward (const Tensor &indices) const {
46- // Ensure indices are on the same device as weight
47- // This avoids synchronous memcpy in ops layer which would hurt performance
48- Tensor indices_on_device = indices;
49- if (indices->device () != device_) {
50- indices_on_device = indices->to (device_);
46+ // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
47+ if (device_ == Device::Type::NVIDIA || device_ == Device::Type::ILUVATAR || device_ == Device::Type::METAX || device_ == Device::Type::MOORE) {
48+ // Use op::embedding which supports device-side input and batch dimension
49+ return op::embedding (indices->contiguous ()->to (device_), weight_);
5150 }
5251
53- // Ensure indices are contiguous for efficient access
54- // op::embedding now supports device-side input for graph recording
55- Tensor indices_contiguous = indices_on_device->is_contiguous () ? indices_on_device : indices_on_device->contiguous ();
52+ // Get the shape of indices
53+ auto indices_shape = indices->shape ();
5654
57- // Use op::embedding which now supports device-side input and batch dimension
58- // This enables full graph recording support without synchronization
59- return op::embedding (indices_contiguous, weight_);
55+ // Output shape: indices_shape + [embedding_dim]
56+ std::vector<size_t > output_shape = indices_shape;
57+ output_shape.push_back (embedding_dim_);
58+
59+ // Create output tensor on the same device as weight
60+ auto out = Tensor::empty (output_shape, weight_->dtype (), weight_->device ());
61+
62+ // Flatten indices for sequential row copies
63+ auto cpu_device = Device (Device::Type::CPU, 0 );
64+ auto indices_cpu = indices->to (cpu_device)->contiguous ();
65+
66+ // Calculate total number of lookups
67+ size_t num_lookups = 1 ;
68+ for (auto dim : indices_shape) {
69+ num_lookups *= dim;
70+ }
71+
72+ const size_t row_bytes = embedding_dim_ * dsize (weight_->dtype ());
73+
74+ // Source and destination base pointers
75+ auto *weight_base = weight_->data ();
76+ auto *out_base = out->data ();
77+
78+ // Helper lambda to read index based on dtype with bounds checking
79+ auto read_index = [&](size_t i) -> int64_t {
80+ auto dtype = indices_cpu->dtype ();
81+ if (dtype == DataType::I32) {
82+ const auto *data = reinterpret_cast <const int32_t *>(indices_cpu->data ());
83+ return static_cast <int64_t >(data[i]);
84+ } else if (dtype == DataType::I64) {
85+ const auto *data = reinterpret_cast <const int64_t *>(indices_cpu->data ());
86+ return data[i];
87+ } else if (dtype == DataType::U32) {
88+ const auto *data = reinterpret_cast <const uint32_t *>(indices_cpu->data ());
89+ return static_cast <int64_t >(data[i]);
90+ } else if (dtype == DataType::U64) {
91+ const auto *data = reinterpret_cast <const uint64_t *>(indices_cpu->data ());
92+ uint64_t val = data[i];
93+ // Check if value can fit in int64_t
94+ if (val > static_cast <uint64_t >(std::numeric_limits<int64_t >::max ())) {
95+ throw std::out_of_range (" Index value out of range for int64_t: " + std::to_string (val));
96+ }
97+ return static_cast <int64_t >(val);
98+ } else {
99+ throw std::runtime_error (" Embedding indices must be integer type, got dtype=" + std::to_string (static_cast <int >(dtype)));
100+ }
101+ };
102+
103+ if (weight_->device ().getType () == Device::Type::CPU) {
104+ // CPU path: memcpy row by row
105+ for (size_t i = 0 ; i < num_lookups; ++i) {
106+ int64_t idx = read_index (i);
107+ if (idx < 0 || idx >= static_cast <int64_t >(num_embeddings_)) {
108+ throw std::out_of_range (
109+ " Index out of range: " + std::to_string (idx) + " (num_embeddings=" + std::to_string (num_embeddings_) + " )" );
110+ }
111+ std::memcpy (out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
112+ }
113+ } else {
114+ // Device path: use stream-ordered D2D copies
115+ for (size_t i = 0 ; i < num_lookups; ++i) {
116+ int64_t idx = read_index (i);
117+ if (idx < 0 || idx >= static_cast <int64_t >(num_embeddings_)) {
118+ throw std::out_of_range (
119+ " Index out of range: " + std::to_string (idx) + " (num_embeddings=" + std::to_string (num_embeddings_) + " )" );
120+ }
121+ context::memcpyD2D (out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
122+ }
123+ }
124+
125+ return out;
60126}
61127
62128std::string Embedding::extra_repr () const {
0 commit comments