Skip to content

Commit 2b887bf

Browse files
committed
Make float16/bfloat16 distinct types
1 parent 0872fda commit 2b887bf

18 files changed

Lines changed: 265 additions & 227 deletions

bench/EmbeddingQuantizeFloatToFloatOrHalfBenchmark.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ static void performance_test_bf16() {
9393
randFill<uint8_t>(inpVec, 0, 20);
9494

9595
int output_columns = colSize - 2 * sizeof(float);
96-
aligned_vector<float16> outVec(rowSize * output_columns);
96+
aligned_vector<bfloat16> outVec(rowSize * output_columns);
9797

9898
double duration = 0.0f;
9999

@@ -102,7 +102,7 @@ static void performance_test_bf16() {
102102
duration = measureWithWarmup(
103103
[&]() {
104104
for (int i = 0; i < kNumRepeats; ++i) {
105-
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float16, true>(
105+
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<bfloat16>(
106106
inpVec.data(), rowSize, colSize, outVec.data());
107107
}
108108
},

bench/EmbeddingQuantizeNBitToFloatOrHalfBenchmark.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ static void performance_test_bf16() {
107107
aligned_vector<uint8_t> inpVec(rowSize * bytes_per_row);
108108
randFill<uint8_t>(inpVec, 0, 20);
109109

110-
aligned_vector<float16> outVec(rowSize * colSize);
110+
aligned_vector<bfloat16> outVec(rowSize * colSize);
111111

112112
double duration = 0.0f;
113113

@@ -116,7 +116,7 @@ static void performance_test_bf16() {
116116
duration = measureWithWarmup(
117117
[&]() {
118118
for (int i = 0; i < kNumRepeats; ++i) {
119-
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float16, true>(
119+
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<bfloat16>(
120120
bit_rate,
121121
inpVec.data(),
122122
rowSize,

bench/EmbeddingSpMDMBenchmark.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ static void run_benchmark(
146146
batch_size,
147147
lengths_sum,
148148
num_rows,
149-
embedding_table_fp16.data(),
149+
reinterpret_cast<const uint16_t*>(embedding_table_fp16.data()),
150150
indices_32.data(),
151151
offsets.data(),
152152
has_weight ? weights.data() : nullptr,
@@ -158,7 +158,7 @@ static void run_benchmark(
158158
batch_size,
159159
lengths_sum,
160160
num_rows,
161-
embedding_table_fp16.data(),
161+
reinterpret_cast<const uint16_t*>(embedding_table_fp16.data()),
162162
indices.data(),
163163
offsets.data(),
164164
has_weight ? weights.data() : nullptr,
@@ -172,7 +172,7 @@ static void run_benchmark(
172172
batch_size,
173173
lengths_sum,
174174
num_rows,
175-
embedding_table_bf16.data(),
175+
reinterpret_cast<const uint16_t*>(embedding_table_bf16.data()),
176176
indices_32.data(),
177177
offsets.data(),
178178
has_weight ? weights.data() : nullptr,
@@ -184,7 +184,7 @@ static void run_benchmark(
184184
batch_size,
185185
lengths_sum,
186186
num_rows,
187-
embedding_table_bf16.data(),
187+
reinterpret_cast<const uint16_t*>(embedding_table_bf16.data()),
188188
indices.data(),
189189
offsets.data(),
190190
has_weight ? weights.data() : nullptr,
@@ -223,19 +223,19 @@ static void run_benchmark(
223223
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
224224
auto kernel_fp32_i64 = GenerateEmbeddingSpMDM<float, int64_t>(
225225
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
226-
auto kernel_fp16_i32 = GenerateEmbeddingSpMDM<float16, int32_t>(
226+
auto kernel_fp16_i32 = GenerateEmbeddingSpMDM<uint16_t, int32_t>(
227227
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
228-
auto kernel_fp16_i64 = GenerateEmbeddingSpMDM<float16, int64_t>(
228+
auto kernel_fp16_i64 = GenerateEmbeddingSpMDM<uint16_t, int64_t>(
229229
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
230-
auto kernel_bf16_i32 = GenerateEmbeddingSpMDM<bfloat16, int32_t>(
230+
auto kernel_bf16_i32 = GenerateEmbeddingSpMDM<uint16_t, int32_t>(
231231
embedding_dim,
232232
has_weight,
233233
normalize_by_lengths,
234234
prefetch ? 16 : 0,
235235
/*is_weight_positional=*/false,
236236
/*use_offsets=*/true,
237237
/*is_bf16_out=*/true);
238-
auto kernel_bf16_i64 = GenerateEmbeddingSpMDM<bfloat16, int64_t>(
238+
auto kernel_bf16_i64 = GenerateEmbeddingSpMDM<uint16_t, int64_t>(
239239
embedding_dim,
240240
has_weight,
241241
normalize_by_lengths,
@@ -254,7 +254,7 @@ static void run_benchmark(
254254
batch_size,
255255
lengths_sum,
256256
num_rows,
257-
embedding_table_fp16.data(),
257+
reinterpret_cast<const uint16_t*>(embedding_table_fp16.data()),
258258
indices_32.data(),
259259
offsets.data(),
260260
has_weight ? weights.data() : nullptr,
@@ -264,7 +264,7 @@ static void run_benchmark(
264264
batch_size,
265265
lengths_sum,
266266
num_rows,
267-
embedding_table_fp16.data(),
267+
reinterpret_cast<const uint16_t*>(embedding_table_fp16.data()),
268268
indices.data(),
269269
offsets.data(),
270270
has_weight ? weights.data() : nullptr,
@@ -276,7 +276,7 @@ static void run_benchmark(
276276
batch_size,
277277
lengths_sum,
278278
num_rows,
279-
embedding_table_bf16.data(),
279+
reinterpret_cast<const uint16_t*>(embedding_table_bf16.data()),
280280
indices_32.data(),
281281
offsets.data(),
282282
has_weight ? weights.data() : nullptr,
@@ -286,7 +286,7 @@ static void run_benchmark(
286286
batch_size,
287287
lengths_sum,
288288
num_rows,
289-
embedding_table_bf16.data(),
289+
reinterpret_cast<const uint16_t*>(embedding_table_bf16.data()),
290290
indices.data(),
291291
offsets.data(),
292292
has_weight ? weights.data() : nullptr,

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,23 @@ using Tensor = at::Tensor;
2525

2626
namespace fbgemm_gpu {
2727

28+
// Map at::Half/at::BFloat16 to the corresponding fbgemm strict types;
29+
// other types (e.g. float) pass through unchanged.
30+
template <typename T>
31+
struct to_fbgemm_type {
32+
using type = T;
33+
};
34+
template <>
35+
struct to_fbgemm_type<at::Half> {
36+
using type = fbgemm::float16;
37+
};
38+
template <>
39+
struct to_fbgemm_type<at::BFloat16> {
40+
using type = fbgemm::bfloat16;
41+
};
42+
template <typename T>
43+
using to_fbgemm_type_t = typename to_fbgemm_type<T>::type;
44+
2845
template <typename input_t>
2946
Tensor& _float_to_fused8bitrowwise_cpu_out_t(
3047
Tensor& output,
@@ -55,7 +72,7 @@ Tensor& _float_to_fused8bitrowwise_cpu_out_t(
5572
return output;
5673
}
5774

58-
template <typename output_t, bool is_uint16_t_of_type_bf16 = false>
75+
template <typename output_t>
5976
Tensor& _fused8bitrowwise_to_float_cpu_out_t(
6077
Tensor& output,
6178
const Tensor& input,
@@ -86,9 +103,7 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
86103
auto output_data = static_cast<output_t*>(
87104
output.mutable_data_ptr()); // output.mutable_data_ptr<output_t>(); ->
88105
// Yields unresolved data_ptr symbol.
89-
fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
90-
output_t,
91-
is_uint16_t_of_type_bf16>(
106+
fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<output_t>(
92107
input.const_data_ptr<uint8_t>(),
93108
nrows,
94109
ncols,
@@ -206,17 +221,12 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
206221
"Unsupported output dtype for _fusednbitrowwise_sbfront_to_float_or_half_cpu");
207222
}
208223

209-
using output_ty = std::
210-
conditional_t<std::is_same_v<output_t, float>, float, fbgemm::float16>;
224+
using output_ty = to_fbgemm_type_t<output_t>;
211225
output_ty* output_data = static_cast<output_ty*>(
212226
output.mutable_data_ptr()); // output.mutable_data_ptr<output_t>(); ->
213227
// Yields unresolved data_ptr symbol.
214228

215-
constexpr bool is_uint16_t_of_type_bf16 =
216-
std::is_same_v<output_t, at::BFloat16>;
217-
fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<
218-
output_ty,
219-
is_uint16_t_of_type_bf16>(
229+
fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<output_ty>(
220230
bit_rate,
221231
input.const_data_ptr<uint8_t>(),
222232
nrows,
@@ -234,7 +244,7 @@ Tensor& _fused8bitrowwise_to_float_cpu_out(
234244
const Tensor& input,
235245
const bool scale_bias_last,
236246
const bool quant_padding_float_type) {
237-
return _fused8bitrowwise_to_float_cpu_out_t<float, false>(
247+
return _fused8bitrowwise_to_float_cpu_out_t<float>(
238248
output, input, scale_bias_last, quant_padding_float_type);
239249
}
240250

@@ -243,7 +253,7 @@ Tensor& fused8bitrowwise_to_half_cpu_out(
243253
const Tensor& input,
244254
const bool scale_bias_last,
245255
const bool quant_padding_float_type) {
246-
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16, false>(
256+
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16>(
247257
output, input, scale_bias_last, quant_padding_float_type);
248258
}
249259

@@ -252,7 +262,7 @@ Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
252262
const Tensor& input,
253263
const bool scale_bias_last,
254264
const bool quant_padding_float_type) {
255-
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16, true>(
265+
return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16>(
256266
output, input, scale_bias_last, quant_padding_float_type);
257267
}
258268

include/fbgemm/FloatConversion.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,18 +210,18 @@ inline typename Tgt::value_type ieee754_trunc(typename Src::value_type value) {
210210

211211
inline float16 cpu_float2half_rn(float f) {
212212
uint32_t f_u32 = std::bit_cast<uint32_t>(f);
213-
return detail::ieee754_trunc<
213+
return {detail::ieee754_trunc<
214214
/*Src=*/detail::IEEE754Single,
215215
/*Tgt=*/detail::IEEE754Half,
216-
detail::RoundingMode::ToNearestTiesToEven>(f_u32);
216+
detail::RoundingMode::ToNearestTiesToEven>(f_u32)};
217217
}
218218

219219
inline float16 cpu_float2half_rz(float f) {
220220
uint32_t f_u32 = std::bit_cast<uint32_t>(f);
221-
return detail::ieee754_trunc<
221+
return {detail::ieee754_trunc<
222222
/*Src=*/detail::IEEE754Single,
223223
/*Tgt=*/detail::IEEE754Half,
224-
detail::RoundingMode::ToZero>(f_u32);
224+
detail::RoundingMode::ToZero>(f_u32)};
225225
}
226226

227227
// Converts a 16-bit unsigned integer representation of a IEEE754 half-precision
@@ -245,10 +245,10 @@ inline float cpu_half2float_ref(const float16 h) {
245245
constexpr uint32_t f32_most_significant_bit = 1u << 22;
246246

247247
// Get sign and exponent alone by themselves
248-
uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1;
249-
uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask;
248+
uint32_t sign_bit = (h.val >> f16_num_non_sign_bits) & 1;
249+
uint32_t exponent = (h.val >> f16_num_mantissa_bits) & f16_exponent_mask;
250250
// Shift mantissa so that it fills the most significant bits of a float32
251-
uint32_t mantissa = (h & f16_mantissa_mask)
251+
uint32_t mantissa = (h.val & f16_mantissa_mask)
252252
<< (f32_num_mantissa_bits - f16_num_mantissa_bits);
253253

254254
if (exponent == f16_exponent_mask) { // NaN or Inf
@@ -280,10 +280,10 @@ inline float cpu_half2float_ref(const float16 h) {
280280

281281
inline float cpu_half2float(const float16 h) {
282282
#ifdef HAS_NATIVE_FP16_TYPE
283-
return std::bit_cast<__fp16>(h);
283+
return std::bit_cast<__fp16>(h.val);
284284
#elif defined(HAS_F16C)
285285
// Use F16C VCVTPH2PS instruction
286-
__m128i v = _mm_cvtsi32_si128(static_cast<int>(h));
286+
__m128i v = _mm_cvtsi32_si128(static_cast<int>(h.val));
287287
return _mm_cvtss_f32(_mm_cvtph_ps(v));
288288
#else
289289
return cpu_half2float_ref(h);
@@ -293,25 +293,25 @@ inline float cpu_half2float(const float16 h) {
293293
inline float16 cpu_float2half(const float f) {
294294
#ifdef HAS_NATIVE_FP16_TYPE
295295
__fp16 h = f;
296-
return std::bit_cast<float16>(h);
296+
return {std::bit_cast<uint16_t>(h)};
297297
#elif defined(HAS_F16C)
298298
// Use F16C VCVTPS2PH instruction
299299
__m128 v = _mm_set_ss(f);
300-
return static_cast<float16>(
301-
_mm_extract_epi16(_mm_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), 0));
300+
return {static_cast<uint16_t>(
301+
_mm_extract_epi16(_mm_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), 0))};
302302
#else
303303
return cpu_float2half_rn(f);
304304
#endif
305305
}
306306

307307
inline float cpu_bf162float(bfloat16 src) {
308-
uint32_t val_fp32 = static_cast<uint32_t>(src) << 16;
308+
uint32_t val_fp32 = static_cast<uint32_t>(src.val) << 16;
309309
return std::bit_cast<float>(val_fp32);
310310
}
311311

312312
inline bfloat16 cpu_float2bfloat16(float src) {
313313
uint32_t temp = std::bit_cast<uint32_t>(src);
314-
return (temp + (1u << 15)) >> 16;
314+
return {static_cast<uint16_t>((temp + (1u << 15)) >> 16)};
315315
}
316316

317317
} // namespace fbgemm

include/fbgemm/QuantUtils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
324324
* This version intentionally supports only 8-bit because
325325
* the corresponding quantize version only supports 8-bit.
326326
*/
327-
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
327+
template <typename OutputType>
328328
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
329329
const uint8_t* input,
330330
size_t input_rows,
@@ -360,7 +360,7 @@ FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
360360
* Same as FusedNBitRowwiseQuantizedSBHalfToFloat but unoptimized.
361361
* This should not be called directly except in testing.
362362
*/
363-
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
363+
template <typename OutputType>
364364
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
365365
int bit_rate,
366366
const uint8_t* input,
@@ -373,7 +373,7 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
373373
* Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
374374
* This should not be called directly except in testing.
375375
*/
376-
template <typename OutputType, bool is_uint16_t_of_type_bf16 = false>
376+
template <typename OutputType>
377377
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
378378
const uint8_t* input,
379379
size_t input_rows,

include/fbgemm/Types.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,18 @@
1212

1313
namespace fbgemm {
1414

15-
using float16 = std::uint16_t;
16-
using bfloat16 = std::uint16_t;
15+
struct float16 {
16+
uint16_t val;
17+
bool operator==(const float16&) const = default;
18+
};
19+
20+
struct bfloat16 {
21+
uint16_t val;
22+
bool operator==(const bfloat16&) const = default;
23+
};
24+
25+
static_assert(sizeof(float16) == 2);
26+
static_assert(sizeof(bfloat16) == 2);
1727

1828
constexpr int64_t round_up(int64_t val, int64_t unit) {
1929
return (val + unit - 1) / unit * unit;

src/EmbeddingSpMDM.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,6 +1660,7 @@ GenerateEmbeddingSpMDMRowWiseSparse(
16601660
INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
16611661
#define INSTANTIATE_SPMDMFP8_BASE_float(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
16621662
#define INSTANTIATE_SPMDMFP8_BASE_uint16_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
1663+
#define INSTANTIATE_SPMDMFP8_BASE_float16(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
16631664

16641665
#define INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \
16651666
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
@@ -1695,6 +1696,7 @@ GenerateEmbeddingSpMDMRowWiseSparse(
16951696
INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, int64_t)
16961697

16971698
INSTANTIATE_SPMDM_INDEX_T(float)
1699+
INSTANTIATE_SPMDM_INDEX_T(float16)
16981700
INSTANTIATE_SPMDM_INDEX_T(uint16_t)
16991701
INSTANTIATE_SPMDM_INDEX_T(uint8_t)
17001702

0 commit comments

Comments
 (0)