Skip to content

Commit e706f4c

Browse files
committed
Make float16/bfloat16 distinct types
1 parent 0872fda commit e706f4c

15 files changed

Lines changed: 248 additions & 211 deletions

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,23 @@ using Tensor = at::Tensor;
2525

2626
namespace fbgemm_gpu {
2727

28+
// Map at::Half/at::BFloat16 to the corresponding fbgemm strict types;
29+
// other types (e.g. float) pass through unchanged.
30+
// Primary template: the identity mapping, used for every T that has no
// explicit specialization below.
template <typename T>
31+
struct to_fbgemm_type {
32+
using type = T;
33+
};
34+
// at::Half -> fbgemm::float16.
template <>
35+
struct to_fbgemm_type<at::Half> {
36+
using type = fbgemm::float16;
37+
};
38+
// at::BFloat16 -> fbgemm::bfloat16.
template <>
39+
struct to_fbgemm_type<at::BFloat16> {
40+
using type = fbgemm::bfloat16;
41+
};
42+
// Shorthand for `typename to_fbgemm_type<T>::type`.
template <typename T>
43+
using to_fbgemm_type_t = typename to_fbgemm_type<T>::type;
44+
2845
template <typename input_t>
2946
Tensor& _float_to_fused8bitrowwise_cpu_out_t(
3047
Tensor& output,
@@ -55,7 +72,7 @@ Tensor& _float_to_fused8bitrowwise_cpu_out_t(
5572
return output;
5673
}
5774

58-
template <typename output_t, bool is_uint16_t_of_type_bf16 = false>
75+
template <typename output_t>
5976
Tensor& _fused8bitrowwise_to_float_cpu_out_t(
6077
Tensor& output,
6178
const Tensor& input,
@@ -86,9 +103,7 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
86103
auto output_data = static_cast<output_t*>(
87104
output.mutable_data_ptr()); // output.mutable_data_ptr<output_t>(); ->
88105
// Yields unresolved data_ptr symbol.
89-
fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
90-
output_t,
91-
is_uint16_t_of_type_bf16>(
106+
fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<output_t>(
92107
input.const_data_ptr<uint8_t>(),
93108
nrows,
94109
ncols,
@@ -206,17 +221,12 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
206221
"Unsupported output dtype for _fusednbitrowwise_sbfront_to_float_or_half_cpu");
207222
}
208223

209-
using output_ty = std::
210-
conditional_t<std::is_same_v<output_t, float>, float, fbgemm::float16>;
224+
using output_ty = to_fbgemm_type_t<output_t>;
211225
output_ty* output_data = static_cast<output_ty*>(
212226
output.mutable_data_ptr()); // output.mutable_data_ptr<output_t>(); ->
213227
// Yields unresolved data_ptr symbol.
214228

215-
constexpr bool is_uint16_t_of_type_bf16 =
216-
std::is_same_v<output_t, at::BFloat16>;
217-
fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<
218-
output_ty,
219-
is_uint16_t_of_type_bf16>(
229+
fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<output_ty>(
220230
bit_rate,
221231
input.const_data_ptr<uint8_t>(),
222232
nrows,
@@ -234,7 +244,7 @@ Tensor& _fused8bitrowwise_to_float_cpu_out(
234244
const Tensor& input,
235245
const bool scale_bias_last,
236246
const bool quant_padding_float_type) {
237-
return _fused8bitrowwise_to_float_cpu_out_t<float, false>(
247+
return _fused8bitrowwise_to_float_cpu_out_t<float>(
238248
output, input, scale_bias_last, quant_padding_float_type);
239249
}
240250

@@ -243,7 +253,7 @@ Tensor& fused8bitrowwise_to_half_cpu_out(
243253
const Tensor& input,
244254
const bool scale_bias_last,
245255
const bool quant_padding_float_type) {
246-
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16, false>(
256+
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16>(
247257
output, input, scale_bias_last, quant_padding_float_type);
248258
}
249259

@@ -252,7 +262,7 @@ Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
252262
const Tensor& input,
253263
const bool scale_bias_last,
254264
const bool quant_padding_float_type) {
255-
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16, true>(
265+
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16>(
256266
output, input, scale_bias_last, quant_padding_float_type);
257267
}
258268

include/fbgemm/FloatConversion.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,18 +210,18 @@ inline typename Tgt::value_type ieee754_trunc(typename Src::value_type value) {
210210

211211
inline float16 cpu_float2half_rn(float f) {
212212
uint32_t f_u32 = std::bit_cast<uint32_t>(f);
213-
return detail::ieee754_trunc<
213+
return {detail::ieee754_trunc<
214214
/*Src=*/detail::IEEE754Single,
215215
/*Tgt=*/detail::IEEE754Half,
216-
detail::RoundingMode::ToNearestTiesToEven>(f_u32);
216+
detail::RoundingMode::ToNearestTiesToEven>(f_u32)};
217217
}
218218

219219
inline float16 cpu_float2half_rz(float f) {
220220
uint32_t f_u32 = std::bit_cast<uint32_t>(f);
221-
return detail::ieee754_trunc<
221+
return {detail::ieee754_trunc<
222222
/*Src=*/detail::IEEE754Single,
223223
/*Tgt=*/detail::IEEE754Half,
224-
detail::RoundingMode::ToZero>(f_u32);
224+
detail::RoundingMode::ToZero>(f_u32)};
225225
}
226226

227227
// Converts a 16-bit unsigned integer representation of an IEEE754 half-precision
@@ -245,10 +245,10 @@ inline float cpu_half2float_ref(const float16 h) {
245245
constexpr uint32_t f32_most_significant_bit = 1u << 22;
246246

247247
// Get sign and exponent alone by themselves
248-
uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1;
249-
uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask;
248+
uint32_t sign_bit = (h.val >> f16_num_non_sign_bits) & 1;
249+
uint32_t exponent = (h.val >> f16_num_mantissa_bits) & f16_exponent_mask;
250250
// Shift mantissa so that it fills the most significant bits of a float32
251-
uint32_t mantissa = (h & f16_mantissa_mask)
251+
uint32_t mantissa = (h.val & f16_mantissa_mask)
252252
<< (f32_num_mantissa_bits - f16_num_mantissa_bits);
253253

254254
if (exponent == f16_exponent_mask) { // NaN or Inf
@@ -280,10 +280,10 @@ inline float cpu_half2float_ref(const float16 h) {
280280

281281
inline float cpu_half2float(const float16 h) {
282282
#ifdef HAS_NATIVE_FP16_TYPE
283-
return std::bit_cast<__fp16>(h);
283+
return std::bit_cast<__fp16>(h.val);
284284
#elif defined(HAS_F16C)
285285
// Use F16C VCVTPH2PS instruction
286-
__m128i v = _mm_cvtsi32_si128(static_cast<int>(h));
286+
__m128i v = _mm_cvtsi32_si128(static_cast<int>(h.val));
287287
return _mm_cvtss_f32(_mm_cvtph_ps(v));
288288
#else
289289
return cpu_half2float_ref(h);
@@ -293,25 +293,25 @@ inline float cpu_half2float(const float16 h) {
293293
inline float16 cpu_float2half(const float f) {
294294
#ifdef HAS_NATIVE_FP16_TYPE
295295
__fp16 h = f;
296-
return std::bit_cast<float16>(h);
296+
return {std::bit_cast<uint16_t>(h)};
297297
#elif defined(HAS_F16C)
298298
// Use F16C VCVTPS2PH instruction
299299
__m128 v = _mm_set_ss(f);
300-
return static_cast<float16>(
301-
_mm_extract_epi16(_mm_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), 0));
300+
return {static_cast<uint16_t>(
301+
_mm_extract_epi16(_mm_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), 0))};
302302
#else
303303
return cpu_float2half_rn(f);
304304
#endif
305305
}
306306

307307
inline float cpu_bf162float(bfloat16 src) {
308-
uint32_t val_fp32 = static_cast<uint32_t>(src) << 16;
308+
uint32_t val_fp32 = static_cast<uint32_t>(src.val) << 16;
309309
return std::bit_cast<float>(val_fp32);
310310
}
311311

312312
inline bfloat16 cpu_float2bfloat16(float src) {
313313
uint32_t temp = std::bit_cast<uint32_t>(src);
314-
return (temp + (1u << 15)) >> 16;
314+
return {static_cast<uint16_t>((temp + (1u << 15)) >> 16)};
315315
}
316316

317317
} // namespace fbgemm

include/fbgemm/QuantUtils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
324324
* This version intentionally supports only 8-bit because
325325
* the corresponding quantize version only supports 8-bit.
326326
*/
327-
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
327+
template <typename OutputType>
328328
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
329329
const uint8_t* input,
330330
size_t input_rows,
@@ -360,7 +360,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
360360
* Same as FusedNBitRowwiseQuantizedSBHalfToFloat but unoptimized.
361361
* This should not be called directly except in testing.
362362
*/
363-
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
363+
template <typename OutputType>
364364
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
365365
int bit_rate,
366366
const uint8_t* input,
@@ -373,7 +373,7 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
373373
* Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
374374
* This should not be called directly except in testing.
375375
*/
376-
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
376+
template <typename OutputType>
377377
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
378378
const uint8_t* input,
379379
size_t input_rows,

include/fbgemm/Types.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,18 @@
1212

1313
namespace fbgemm {
1414

15-
using float16 = std::uint16_t;
16-
using bfloat16 = std::uint16_t;
15+
// Strongly-typed wrapper around the raw 16-bit pattern of a half-precision
// value. Being a struct rather than an alias of std::uint16_t makes it a type
// of its own, so templates and overloads can tell it apart from bfloat16 and
// from plain uint16_t.
struct float16 {
16+
// Raw bit pattern; no numeric interpretation is applied by this type.
uint16_t val;
17+
// Defaulted member-wise equality: compares `val` bit-for-bit, which is not
// floating-point equality (identical NaN patterns compare equal, while
// differing patterns such as +0/-0 compare unequal).
bool operator==(const float16&) const = default;
18+
};
19+
20+
// Strongly-typed wrapper around the raw 16-bit pattern of a bfloat16 value.
// Being a struct rather than an alias of std::uint16_t makes it a type of its
// own, so templates and overloads can tell it apart from float16 and from
// plain uint16_t.
struct bfloat16 {
21+
// Raw bit pattern; no numeric interpretation is applied by this type.
uint16_t val;
22+
// Defaulted member-wise equality: compares `val` bit-for-bit, which is not
// floating-point equality (identical NaN patterns compare equal, while
// differing patterns such as +0/-0 compare unequal).
bool operator==(const bfloat16&) const = default;
23+
};
24+
25+
static_assert(sizeof(float16) == 2);
26+
static_assert(sizeof(bfloat16) == 2);
1727

1828
constexpr int64_t round_up(int64_t val, int64_t unit) {
1929
return (val + unit - 1) / unit * unit;

src/EmbeddingSpMDM.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,6 +1660,7 @@ GenerateEmbeddingSpMDMRowWiseSparse(
16601660
INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
16611661
#define INSTANTIATE_SPMDMFP8_BASE_float(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
16621662
#define INSTANTIATE_SPMDMFP8_BASE_uint16_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
1663+
#define INSTANTIATE_SPMDMFP8_BASE_float16(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
16631664

16641665
#define INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \
16651666
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \

src/EmbeddingSpMDMAutovec.cc

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,17 @@ static inline void fill_output(
6969
} else if constexpr (std::is_same_v<OutType, uint16_t>) {
7070
if (is_bf16_out) {
7171
for (int j = 0; j < block_size; ++j) {
72-
out[j] = cpu_float2bfloat16(src[j]);
72+
out[j] = cpu_float2bfloat16(src[j]).val;
7373
}
7474
} else {
7575
for (int j = 0; j < block_size; ++j) {
76-
out[j] = cpu_float2half(src[j]);
76+
out[j] = cpu_float2half(src[j]).val;
7777
}
7878
}
79+
} else if constexpr (std::is_same_v<OutType, float16>) {
80+
for (int j = 0; j < block_size; ++j) {
81+
out[j] = cpu_float2half(src[j]);
82+
}
7983
}
8084
}
8185

@@ -1053,18 +1057,24 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
10531057
#ifdef FBGEMM_VECTOR_WIDTH
10541058
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
10551059
const InType* inptr = input_row++;
1056-
out[j] = std::fma(
1057-
weight,
1058-
std::is_same_v<InType, float16> ? cpu_half2float(*inptr) : *inptr,
1059-
out[j]);
1060+
float in_val = 0.f;
1061+
if constexpr (std::is_same_v<InType, float16>) {
1062+
in_val = cpu_half2float(*inptr);
1063+
} else {
1064+
in_val = *inptr;
1065+
}
1066+
out[j] = std::fma(weight, in_val, out[j]);
10601067
}
10611068
#endif
10621069
for (; j < block_size; ++j) {
10631070
const InType* inptr = input_row++;
1064-
out[j] = std::fma(
1065-
weight,
1066-
std::is_same_v<InType, float16> ? cpu_half2float(*inptr) : *inptr,
1067-
out[j]);
1071+
float in_val = 0.f;
1072+
if constexpr (std::is_same_v<InType, float16>) {
1073+
in_val = cpu_half2float(*inptr);
1074+
} else {
1075+
in_val = *inptr;
1076+
}
1077+
out[j] = std::fma(weight, in_val, out[j]);
10681078
}
10691079
}
10701080
if (normalize_by_lengths && len) {
@@ -2303,9 +2313,10 @@ GenerateEmbeddingSpMDMRowWiseSparse_autovec(
23032313
INSTANTIATE_SPMDM_NBIT_WITH_STRIDES(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
23042314
INSTANTIATE_SPMDM_FP8(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
23052315

2306-
#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
2307-
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
2308-
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
2316+
#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
2317+
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
2318+
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
2319+
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, uint16_t) \
23092320
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t)
23102321

23112322
#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \
@@ -2356,10 +2367,11 @@ INSTANTIATE_SPMDM_OFFSET_T(int64_t)
23562367
bool is_bf16_out, \
23572368
bool is_bf16_in);
23582369

2359-
#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
2360-
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
2361-
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \
2362-
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint8_t) \
2370+
#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
2371+
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
2372+
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \
2373+
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint16_t) \
2374+
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint8_t) \
23632375
INSTANTIATE_SPMDM_ROWWISE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE)
23642376

23652377
#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \
@@ -2372,6 +2384,7 @@ INSTANTIATE_SPMDM_OFFSET_T(int64_t)
23722384

23732385
INSTANTIATE_SPMDM_INDEX_T(float)
23742386
INSTANTIATE_SPMDM_INDEX_T(float16)
2387+
INSTANTIATE_SPMDM_INDEX_T(std::uint16_t)
23752388
INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
23762389

23772390
#undef INSTANTIATE_SPMDM_ROWWISE

src/EmbeddingSpMDMAvx2.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ bool EmbeddingSpMDMBlockSize1_(
151151

152152
INSTANTIATE_SPMDM_INDEX_T(float)
153153
INSTANTIATE_SPMDM_INDEX_T(float16)
154+
INSTANTIATE_SPMDM_INDEX_T(std::uint16_t)
154155
INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
155156

156157
#undef INSTANTIATE_SPMDM_INDEX_T

src/EmbeddingSpMDMNBit.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
13921392
#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
13931393
INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float) \
13941394
INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint16_t) \
1395+
INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float16) \
13951396
INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint8_t) \
13961397
template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
13971398
uint8_t, \

0 commit comments

Comments
 (0)