Skip to content

Commit 3c65570

Browse files
committed
issue/900 - support embedding on iluvatar, metax, and moore
1 parent 8f7f447 commit 3c65570

8 files changed

Lines changed: 678 additions & 82 deletions

File tree

src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh renamed to src/infiniop/ops/embedding/cuda/embedding_kernel.cuh

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
#ifndef __EMBEDDING_CUDA_KERNEL_CUH__
22
#define __EMBEDDING_CUDA_KERNEL_CUH__
33

4-
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
5-
#include <cuda_fp16.h>
6-
#include <cuda_runtime.h>
74
#include <type_traits>
85

9-
namespace op::embedding::nvidia {
10-
116
// Helper function to check memory alignment
127
__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) {
138
// Use size_t for pointer arithmetic in device code (more compatible)
@@ -118,61 +113,4 @@ __forceinline__ __device__ void copyScalar(
118113
}
119114
}
120115

121-
template <typename T, typename IndexType>
122-
INFINIOP_CUDA_KERNEL embeddingKernel(
123-
T *__restrict__ output,
124-
const IndexType *__restrict__ indices,
125-
const T *__restrict__ weight,
126-
size_t num_indices,
127-
size_t embedding_dim,
128-
size_t vocab_size) {
129-
// Calculate global thread index
130-
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
131-
132-
if (idx < num_indices) {
133-
// Get the index value
134-
IndexType index_val = __ldg(&indices[idx]);
135-
136-
// Bounds check - handle negative indices gracefully
137-
if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
138-
// Copy embedding vector from weight to output
139-
const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
140-
T *dst = output + idx * embedding_dim;
141-
142-
// Choose optimal copy strategy based on type and alignment
143-
if constexpr (std::is_same_v<T, float>) {
144-
// Check alignment for float4 (16 bytes)
145-
bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
146-
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
147-
copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
148-
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
149-
// Try float2 if not aligned to 16 bytes
150-
copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
151-
} else {
152-
copyScalar<T, IndexType>(dst, src, embedding_dim);
153-
}
154-
} else if constexpr (std::is_same_v<T, half>) {
155-
// Use half2 for vectorized access
156-
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
157-
copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
158-
} else {
159-
copyScalar<T, IndexType>(dst, src, embedding_dim);
160-
}
161-
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
162-
// Use bfloat162 for vectorized access
163-
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
164-
copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
165-
} else {
166-
copyScalar<T, IndexType>(dst, src, embedding_dim);
167-
}
168-
} else {
169-
// Fallback to scalar copy with __ldg
170-
copyScalar<T, IndexType>(dst, src, embedding_dim);
171-
}
172-
}
173-
}
174-
}
175-
176-
} // namespace op::embedding::nvidia
177-
178116
#endif // __EMBEDDING_CUDA_KERNEL_CUH__
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __EMBEDDING_METAX_H__
2+
#define __EMBEDDING_METAX_H__
3+
4+
#include "../embedding.h"
5+
6+
DESCRIPTOR(metax)
7+
8+
#endif // __EMBEDDING_METAX_H__
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
#include "../../../../utils.h"
2+
#include "../../../devices/metax/metax_common.h"
3+
#include "../../../devices/metax/metax_kernel_common.h"
4+
#include "../../../tensor.h"
5+
#include "../cuda/embedding_kernel.cuh"
6+
#include "embedding_metax.cuh"
7+
8+
template <typename T, typename IndexType>
9+
INFINIOP_METAX_KERNEL embeddingKernel(
10+
T *__restrict__ output,
11+
const IndexType *__restrict__ indices,
12+
const T *__restrict__ weight,
13+
size_t num_indices,
14+
size_t embedding_dim,
15+
size_t vocab_size) {
16+
// Calculate global thread index
17+
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
18+
19+
if (idx < num_indices) {
20+
// Get the index value
21+
IndexType index_val = __ldg(&indices[idx]);
22+
23+
// Bounds check - handle negative indices gracefully
24+
if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
25+
// Copy embedding vector from weight to output
26+
const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
27+
T *dst = output + idx * embedding_dim;
28+
29+
// Choose optimal copy strategy based on type and alignment
30+
if constexpr (std::is_same_v<T, float>) {
31+
// Check alignment for float4 (16 bytes)
32+
bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
33+
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
34+
copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
35+
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
36+
// Try float2 if not aligned to 16 bytes
37+
copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
38+
} else {
39+
copyScalar<T, IndexType>(dst, src, embedding_dim);
40+
}
41+
} else if constexpr (std::is_same_v<T, half>) {
42+
// Use half2 for vectorized access
43+
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
44+
copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
45+
} else {
46+
copyScalar<T, IndexType>(dst, src, embedding_dim);
47+
}
48+
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
49+
// Use bfloat162 for vectorized access
50+
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
51+
copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
52+
} else {
53+
copyScalar<T, IndexType>(dst, src, embedding_dim);
54+
}
55+
} else {
56+
// Fallback to scalar copy with __ldg
57+
copyScalar<T, IndexType>(dst, src, embedding_dim);
58+
}
59+
}
60+
}
61+
}
62+
63+
namespace op::embedding::metax {
64+
65+
struct Descriptor::Opaque {
66+
std::shared_ptr<device::metax::Handle::Internal> internal;
67+
};
68+
69+
Descriptor::~Descriptor() {
70+
delete _opaque;
71+
}
72+
73+
infiniStatus_t Descriptor::create(
74+
infiniopHandle_t handle,
75+
Descriptor **desc_ptr,
76+
infiniopTensorDescriptor_t output_desc,
77+
infiniopTensorDescriptor_t input_desc,
78+
infiniopTensorDescriptor_t weight_desc) {
79+
80+
auto input_shape = input_desc->shape();
81+
auto weight_shape = weight_desc->shape();
82+
83+
// Validate shapes
84+
CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
85+
CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
86+
87+
// Check output shape matches input shape + embedding_dim
88+
auto output_shape = output_desc->shape();
89+
size_t embedding_dim = weight_shape[1];
90+
CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);
91+
92+
for (size_t i = 0; i < input_shape.size(); ++i) {
93+
CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
94+
}
95+
96+
// Validate dtypes
97+
auto input_dtype = input_desc->dtype();
98+
auto weight_dtype = weight_desc->dtype();
99+
CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
100+
INFINI_STATUS_BAD_TENSOR_DTYPE);
101+
CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 ||
102+
weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE);
103+
CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
104+
105+
// Calculate number of indices (supporting batch dimension)
106+
size_t num_indices = 1;
107+
for (auto dim : input_shape) {
108+
num_indices *= dim;
109+
}
110+
111+
size_t vocab_size = weight_shape[0];
112+
113+
*desc_ptr = new Descriptor(
114+
num_indices,
115+
embedding_dim,
116+
vocab_size,
117+
input_dtype,
118+
weight_dtype,
119+
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
120+
handle->device,
121+
handle->device_id);
122+
123+
return INFINI_STATUS_SUCCESS;
124+
}
125+
126+
infiniStatus_t Descriptor::calculate(
127+
void *output,
128+
const void *input,
129+
const void *weight,
130+
void *stream) const {
131+
132+
if (_num_indices == 0) {
133+
return INFINI_STATUS_SUCCESS;
134+
}
135+
136+
auto hc_stream = reinterpret_cast<hcStream_t>(stream);
137+
138+
// Dynamic block size optimization based on embedding_dim for Metax platform
139+
size_t block_size = 256; // Default block size for Metax
140+
if (_embedding_dim <= 64) {
141+
block_size = 512; // Small embedding_dim: use larger block for better occupancy
142+
} else if (_embedding_dim >= 1024) {
143+
block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure
144+
}
145+
146+
size_t grid_size = (_num_indices + block_size - 1) / block_size;
147+
148+
// Launch kernel based on dtypes for Metax platform
149+
if (_input_dtype == INFINI_DTYPE_I32) {
150+
const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);
151+
152+
if (_weight_dtype == INFINI_DTYPE_F32) {
153+
embeddingKernel<float, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
154+
reinterpret_cast<float *>(output),
155+
indices_ptr,
156+
reinterpret_cast<const float *>(weight),
157+
_num_indices,
158+
_embedding_dim,
159+
_vocab_size);
160+
} else if (_weight_dtype == INFINI_DTYPE_F16) {
161+
embeddingKernel<half, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
162+
reinterpret_cast<half *>(output),
163+
indices_ptr,
164+
reinterpret_cast<const half *>(weight),
165+
_num_indices,
166+
_embedding_dim,
167+
_vocab_size);
168+
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
169+
// Use Metax's bfloat16 type
170+
embeddingKernel<__hpcc_bfloat16, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
171+
reinterpret_cast<__hpcc_bfloat16 *>(output),
172+
indices_ptr,
173+
reinterpret_cast<const __hpcc_bfloat16 *>(weight),
174+
_num_indices,
175+
_embedding_dim,
176+
_vocab_size);
177+
} else {
178+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
179+
}
180+
} else if (_input_dtype == INFINI_DTYPE_I64) {
181+
const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);
182+
183+
if (_weight_dtype == INFINI_DTYPE_F32) {
184+
embeddingKernel<float, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
185+
reinterpret_cast<float *>(output),
186+
indices_ptr,
187+
reinterpret_cast<const float *>(weight),
188+
_num_indices,
189+
_embedding_dim,
190+
_vocab_size);
191+
} else if (_weight_dtype == INFINI_DTYPE_F16) {
192+
embeddingKernel<half, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
193+
reinterpret_cast<half *>(output),
194+
indices_ptr,
195+
reinterpret_cast<const half *>(weight),
196+
_num_indices,
197+
_embedding_dim,
198+
_vocab_size);
199+
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
200+
embeddingKernel<__hpcc_bfloat16, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
201+
reinterpret_cast<__hpcc_bfloat16 *>(output),
202+
indices_ptr,
203+
reinterpret_cast<const __hpcc_bfloat16 *>(weight),
204+
_num_indices,
205+
_embedding_dim,
206+
_vocab_size);
207+
} else {
208+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
209+
}
210+
} else {
211+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
212+
}
213+
214+
return INFINI_STATUS_SUCCESS;
215+
}
216+
217+
} // namespace op::embedding::metax
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __EMBEDDING_MOORE_H__
2+
#define __EMBEDDING_MOORE_H__
3+
4+
#include "../embedding.h"
5+
6+
DESCRIPTOR(moore)
7+
8+
#endif // __EMBEDDING_MOORE_H__

0 commit comments

Comments
 (0)