Skip to content

Commit 9ed8ef5

Browse files
cyyever and meta-codesync[bot]
authored and committed
Switch bf16 quantize ops to at::kBFloat16 (#5735)
Summary: Pull Request resolved: #5735
Reviewed By: spcyppt
Differential Revision: D103884190
Pulled By: q10
fbshipit-source-id: b062b27c778c21711af8f8b29de36adb5daefc2d
1 parent 3ca19e6 commit 9ed8ef5

3 files changed

Lines changed: 10 additions & 17 deletions

File tree

fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,20 +24,18 @@ namespace fbgemm_gpu {
2424
DLL_PUBLIC at::Tensor _float_to_bfloat16_gpu(const at::Tensor& input) {
2525
CUDA_DEVICE_GUARD(input);
2626

27-
// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
28-
// NCCL input.options().dtype(at::kBFloat16)); // at::kBFloat16
29-
auto output = at::empty({}, input.options().dtype(at::kHalf));
27+
auto output = at::empty({}, input.options().dtype(at::kBFloat16));
3028
output.resize_(0);
3129

3230
auto iter = at::TensorIteratorConfig()
3331
.check_all_same_dtype(false)
3432
.add_output(output)
3533
.add_input(input)
3634
.build();
37-
at::native::gpu_kernel(iter, [] GPU_LAMBDA(float in) -> at::Half {
35+
at::native::gpu_kernel(iter, [] GPU_LAMBDA(float in) -> at::BFloat16 {
3836
fbgemm_gpu::fint32 temp;
3937
temp.F = in;
40-
return at::Half((temp.I + (1 << 15)) >> 16, at::Half::from_bits());
38+
return at::BFloat16((temp.I + (1 << 15)) >> 16, at::BFloat16::from_bits());
4139
});
4240

4341
return output;
@@ -62,7 +60,7 @@ DLL_PUBLIC at::Tensor _bfloat16_to_float_gpu(const at::Tensor& input) {
6260
.add_input(input)
6361
.build();
6462

65-
at::native::gpu_kernel(iter, [] GPU_LAMBDA(at::Half in) -> float {
63+
at::native::gpu_kernel(iter, [] GPU_LAMBDA(at::BFloat16 in) -> float {
6664
fbgemm_gpu::fint32 temp;
6765
temp.I = in.x << 16;
6866
return temp.F;

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -593,24 +593,20 @@ void BFloat16QuantizedToFloat_ref(
593593
}
594594
}
595595

596-
// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia NCCL
597596
at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) {
598597
TENSOR_ON_CPU(input);
599598

600599
const auto input_sizes = input.sizes();
601-
auto output = at::empty(
602-
input_sizes,
603-
input.options().dtype(at::kHalf)); // at::kHalf
600+
auto output = at::empty(input_sizes, input.options().dtype(at::kBFloat16));
604601

605602
FloatToBFloat16Quantized_ref(
606603
input.const_data_ptr<float>(),
607604
input.numel(),
608-
reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>()));
605+
reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>()));
609606

610607
return output;
611608
}
612609

613-
// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia NCCL
614610
at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
615611
TENSOR_ON_CPU(input);
616612

@@ -619,7 +615,7 @@ at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
619615
auto output = at::empty(input_sizes, input.options().dtype(at::kFloat));
620616

621617
BFloat16QuantizedToFloat_ref(
622-
reinterpret_cast<const at::BFloat16*>(input.const_data_ptr<at::Half>()),
618+
input.const_data_ptr<at::BFloat16>(),
623619
input.numel(),
624620
output.mutable_data_ptr<float>());
625621

fbgemm_gpu/test/quantize/bfloat16_test.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ class SparseNNOperatorsGPUTest(unittest.TestCase):
2727
k=st.integers(min_value=2, max_value=2),
2828
n=st.integers(min_value=2, max_value=2),
2929
)
30+
@settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
3031
def test_dense_mlp_quantize_ops(
3132
self, precision: str, batch_size: int, k: int, n: int
3233
) -> None:
@@ -69,17 +70,15 @@ def test_quantize_op(self, nrows: int, ncols: int) -> None:
6970
return
7071
f = np.vectorize(lambda x: bfloat_quantize(x))
7172
reference = f(input_data.numpy())
72-
quantized_data_uint16 = quantized_data.numpy()
73-
quantized_data_uint16.dtype = np.uint16
73+
quantized_data_uint16 = quantized_data.view(torch.uint16).numpy()
7474
np.testing.assert_array_almost_equal(quantized_data_uint16, reference)
7575

7676
if torch.cuda.is_available():
7777
input_data_gpu = input_data.cuda()
7878
quantized_data_gpu = torch.ops.fbgemm.FloatToBfloat16Quantized(
7979
input_data_gpu
8080
)
81-
quantized_data_numpy = quantized_data_gpu.cpu().numpy()
82-
quantized_data_numpy.dtype = np.uint16
81+
quantized_data_numpy = quantized_data_gpu.view(torch.uint16).cpu().numpy()
8382
np.testing.assert_allclose(quantized_data_numpy, reference)
8483

8584
# pyre-fixme[56]: Pyre was not able to infer the type of argument

0 commit comments

Comments
 (0)