@@ -146,7 +146,7 @@ static void run_benchmark(
146146 batch_size,
147147 lengths_sum,
148148 num_rows,
149- embedding_table_fp16.data (),
149+ reinterpret_cast < const uint16_t *>( embedding_table_fp16.data () ),
150150 indices_32.data (),
151151 offsets.data (),
152152 has_weight ? weights.data () : nullptr ,
@@ -158,7 +158,7 @@ static void run_benchmark(
158158 batch_size,
159159 lengths_sum,
160160 num_rows,
161- embedding_table_fp16.data (),
161+ reinterpret_cast < const uint16_t *>( embedding_table_fp16.data () ),
162162 indices.data (),
163163 offsets.data (),
164164 has_weight ? weights.data () : nullptr ,
@@ -172,7 +172,7 @@ static void run_benchmark(
172172 batch_size,
173173 lengths_sum,
174174 num_rows,
175- embedding_table_bf16.data (),
175+ reinterpret_cast < const uint16_t *>( embedding_table_bf16.data () ),
176176 indices_32.data (),
177177 offsets.data (),
178178 has_weight ? weights.data () : nullptr ,
@@ -184,7 +184,7 @@ static void run_benchmark(
184184 batch_size,
185185 lengths_sum,
186186 num_rows,
187- embedding_table_bf16.data (),
187+ reinterpret_cast < const uint16_t *>( embedding_table_bf16.data () ),
188188 indices.data (),
189189 offsets.data (),
190190 has_weight ? weights.data () : nullptr ,
@@ -223,19 +223,19 @@ static void run_benchmark(
223223 embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
224224 auto kernel_fp32_i64 = GenerateEmbeddingSpMDM<float , int64_t >(
225225 embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
226- auto kernel_fp16_i32 = GenerateEmbeddingSpMDM<float16 , int32_t >(
226+ auto kernel_fp16_i32 = GenerateEmbeddingSpMDM<uint16_t , int32_t >(
227227 embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
228- auto kernel_fp16_i64 = GenerateEmbeddingSpMDM<float16 , int64_t >(
228+ auto kernel_fp16_i64 = GenerateEmbeddingSpMDM<uint16_t , int64_t >(
229229 embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
230- auto kernel_bf16_i32 = GenerateEmbeddingSpMDM<bfloat16 , int32_t >(
230+ auto kernel_bf16_i32 = GenerateEmbeddingSpMDM<uint16_t , int32_t >(
231231 embedding_dim,
232232 has_weight,
233233 normalize_by_lengths,
234234 prefetch ? 16 : 0 ,
235235 /* is_weight_positional=*/ false ,
236236 /* use_offsets=*/ true ,
237237 /* is_bf16_out=*/ true );
238- auto kernel_bf16_i64 = GenerateEmbeddingSpMDM<bfloat16 , int64_t >(
238+ auto kernel_bf16_i64 = GenerateEmbeddingSpMDM<uint16_t , int64_t >(
239239 embedding_dim,
240240 has_weight,
241241 normalize_by_lengths,
@@ -254,7 +254,7 @@ static void run_benchmark(
254254 batch_size,
255255 lengths_sum,
256256 num_rows,
257- embedding_table_fp16.data (),
257+ reinterpret_cast < const uint16_t *>( embedding_table_fp16.data () ),
258258 indices_32.data (),
259259 offsets.data (),
260260 has_weight ? weights.data () : nullptr ,
@@ -264,7 +264,7 @@ static void run_benchmark(
264264 batch_size,
265265 lengths_sum,
266266 num_rows,
267- embedding_table_fp16.data (),
267+ reinterpret_cast < const uint16_t *>( embedding_table_fp16.data () ),
268268 indices.data (),
269269 offsets.data (),
270270 has_weight ? weights.data () : nullptr ,
@@ -276,7 +276,7 @@ static void run_benchmark(
276276 batch_size,
277277 lengths_sum,
278278 num_rows,
279- embedding_table_bf16.data (),
279+ reinterpret_cast < const uint16_t *>( embedding_table_bf16.data () ),
280280 indices_32.data (),
281281 offsets.data (),
282282 has_weight ? weights.data () : nullptr ,
@@ -286,7 +286,7 @@ static void run_benchmark(
286286 batch_size,
287287 lengths_sum,
288288 num_rows,
289- embedding_table_bf16.data (),
289+ reinterpret_cast < const uint16_t *>( embedding_table_bf16.data () ),
290290 indices.data (),
291291 offsets.data (),
292292 has_weight ? weights.data () : nullptr ,
0 commit comments