@@ -69,13 +69,17 @@ static inline void fill_output(
6969 } else if constexpr (std::is_same_v<OutType, uint16_t >) {
7070 if (is_bf16_out) {
7171 for (int j = 0 ; j < block_size; ++j) {
72- out[j] = cpu_float2bfloat16 (src[j]);
72+ out[j] = cpu_float2bfloat16 (src[j]). val ;
7373 }
7474 } else {
7575 for (int j = 0 ; j < block_size; ++j) {
76- out[j] = cpu_float2half (src[j]);
76+ out[j] = cpu_float2half (src[j]). val ;
7777 }
7878 }
79+ } else if constexpr (std::is_same_v<OutType, float16>) {
80+ for (int j = 0 ; j < block_size; ++j) {
81+ out[j] = cpu_float2half (src[j]);
82+ }
7983 }
8084}
8185
@@ -1053,18 +1057,24 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
10531057#ifdef FBGEMM_VECTOR_WIDTH
10541058 for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
10551059 const InType* inptr = input_row++;
1056- out[j] = std::fma (
1057- weight,
1058- std::is_same_v<InType, float16> ? cpu_half2float (*inptr) : *inptr,
1059- out[j]);
1060+ float in_val = 0 .f ;
1061+ if constexpr (std::is_same_v<InType, float16>) {
1062+ in_val = cpu_half2float (*inptr);
1063+ } else {
1064+ in_val = *inptr;
1065+ }
1066+ out[j] = std::fma (weight, in_val, out[j]);
10601067 }
10611068#endif
10621069 for (; j < block_size; ++j) {
10631070 const InType* inptr = input_row++;
1064- out[j] = std::fma (
1065- weight,
1066- std::is_same_v<InType, float16> ? cpu_half2float (*inptr) : *inptr,
1067- out[j]);
1071+ float in_val = 0 .f ;
1072+ if constexpr (std::is_same_v<InType, float16>) {
1073+ in_val = cpu_half2float (*inptr);
1074+ } else {
1075+ in_val = *inptr;
1076+ }
1077+ out[j] = std::fma (weight, in_val, out[j]);
10681078 }
10691079 }
10701080 if (normalize_by_lengths && len) {
@@ -2303,9 +2313,10 @@ GenerateEmbeddingSpMDMRowWiseSparse_autovec(
23032313 INSTANTIATE_SPMDM_NBIT_WITH_STRIDES (INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
23042314 INSTANTIATE_SPMDM_FP8 (INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
23052315
2306- #define INSTANTIATE_SPMDM_OUT_T (INDEX_TYPE, OFFSET_TYPE ) \
2307- INSTANTIATE_SPMDM_BASE (INDEX_TYPE, OFFSET_TYPE, float ) \
2308- INSTANTIATE_SPMDM_BASE (INDEX_TYPE, OFFSET_TYPE, float16) \
2316+ #define INSTANTIATE_SPMDM_OUT_T (INDEX_TYPE, OFFSET_TYPE ) \
2317+ INSTANTIATE_SPMDM_BASE (INDEX_TYPE, OFFSET_TYPE, float ) \
2318+ INSTANTIATE_SPMDM_BASE (INDEX_TYPE, OFFSET_TYPE, float16) \
2319+ INSTANTIATE_SPMDM_BASE (INDEX_TYPE, OFFSET_TYPE, uint16_t ) \
23092320 INSTANTIATE_SPMDM_BASE (INDEX_TYPE, OFFSET_TYPE, uint8_t )
23102321
23112322#define INSTANTIATE_SPMDM_OFFSET_T (INDEX_TYPE ) \
@@ -2356,10 +2367,11 @@ INSTANTIATE_SPMDM_OFFSET_T(int64_t)
23562367 bool is_bf16_out, \
23572368 bool is_bf16_in);
23582369
2359- #define INSTANTIATE_SPMDM_OUT_T (IN_TYPE, INDEX_TYPE, OFFSET_TYPE ) \
2360- INSTANTIATE_SPMDM_BASE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float ) \
2361- INSTANTIATE_SPMDM_BASE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \
2362- INSTANTIATE_SPMDM_BASE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint8_t ) \
2370+ #define INSTANTIATE_SPMDM_OUT_T (IN_TYPE, INDEX_TYPE, OFFSET_TYPE ) \
2371+ INSTANTIATE_SPMDM_BASE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float ) \
2372+ INSTANTIATE_SPMDM_BASE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \
2373+ INSTANTIATE_SPMDM_BASE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint16_t ) \
2374+ INSTANTIATE_SPMDM_BASE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint8_t ) \
23632375 INSTANTIATE_SPMDM_ROWWISE (IN_TYPE, INDEX_TYPE, OFFSET_TYPE)
23642376
23652377#define INSTANTIATE_SPMDM_OFFSET_T (IN_TYPE, INDEX_TYPE ) \
@@ -2372,6 +2384,7 @@ INSTANTIATE_SPMDM_OFFSET_T(int64_t)
23722384
23732385INSTANTIATE_SPMDM_INDEX_T (float )
23742386INSTANTIATE_SPMDM_INDEX_T (float16)
2387+ INSTANTIATE_SPMDM_INDEX_T (std::uint16_t )
23752388INSTANTIATE_SPMDM_INDEX_T (std::uint8_t )
23762389
23772390#undef INSTANTIATE_SPMDM_ROWWISE
0 commit comments