Skip to content

Commit 15ad846

Browse files
authored
[ET-VK] Add float16 to float32 fallback for devices without 16-bit buffer support (pytorch#16859)
Some Vulkan devices do not support 16-bit storage buffers (the VK_KHR_16bit_storage extension with storageBuffer16BitAccess). Previously, attempting to run models with fp16 buffer-backed tensors on such devices would crash because the shaders require 16-bit buffer access. This diff adds transparent fallback behavior: when a device lacks float16 buffer support, fp16 data is automatically converted to fp32 when copying between CPU memory and staging buffers, and the GPU tensors/buffers use fp32 storage internally. Key changes: 1. StagingBuffer dtype aliasing: - Added `get_staging_dtype()` which returns kFloat when kHalf is requested but the device lacks float16 buffer support - Moved StagingBuffer constructor to cpp file to apply this aliasing 2. Half<->Float conversion utilities: - Added `half_to_float()` and `float_to_half()` for IEEE 754 compliant conversion between 16-bit and 32-bit floats - Added `cast_half_to_float_and_copy_from()` and `cast_float_to_half_and_copy_to()` methods to StagingBuffer 3. vTensor dtype aliasing: - Modified `get_effective_scalar_type()` in Tensor.cpp to alias kHalf to kFloat when float16 buffers are not supported 4. Data transfer dtype handling: - Updated `create_staging_buffer()` in PrepackNode.cpp to handle dtype conversion when staging dtype differs from TensorRef dtype - Updated `maybe_cast_and_copy_into_staging()` and `maybe_cast_and_copy_from_staging()` in ComputeGraph.cpp to handle kHalf<->kFloat conversion 5. Shader updates for conv2d weight prepacking: - Modified conv2d_prepack_weights.glsl/yaml and conv2d_dw_prepack_weights.glsl/yaml to support separate buffer dtype (BUF_DTYPE) from texture dtype (DTYPE) - Added shader variant combinations: [float,float], [half,float], [half,half] - Updated Convolution.cpp and Staging.cpp to select correct shader variant based on staging dtype Differential Revision: [D91281381](https://our.internmc.facebook.com/intern/diff/D91281381/)
1 parent f7664d3 commit 15ad846

22 files changed

Lines changed: 473 additions & 122 deletions
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/api/containers/StagingBuffer.h>
10+
11+
namespace vkcompute {
12+
namespace api {
13+
14+
namespace {

//
// The following fp16<->fp32 conversion functions are adapted from:
// executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h
// (fp16_ieee_to_fp32_value and fp16_ieee_from_fp32_value)
//

/*
 * Reinterpret a 32-bit pattern as a single-precision float. memcpy is used so
 * that the type pun does not violate strict aliasing.
 */
inline float fp32_from_bits(uint32_t bits) {
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

/*
 * Return the 32-bit pattern that represents the single-precision float `f`.
 */
inline uint32_t fp32_to_bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return bits;
}

/*
 * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
 * representation, to a 32-bit floating-point number in IEEE single-precision
 * format.
 */
float half_to_float(uint16_t h) {
  /*
   * Extend the half-precision floating-point number to 32 bits and shift to the
   * upper part of the 32-bit word:
   *      +---+-----+------------+-------------------+
   *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   */
  const uint32_t w = static_cast<uint32_t>(h) << 16;

  // Sign of the input, already sitting in the fp32 sign position (bit 31).
  const uint32_t sign = w & UINT32_C(0x80000000);

  // Shift the sign bit out: exponent is now in bits 27-31 and mantissa in
  // bits 17-26.
  const uint32_t two_w = w + w;

  /*
   * Normalized path. (two_w >> 4) moves the half exponent into bits 23-27 and
   * the half mantissa into bits 13-22, i.e. into the fp32 exponent/mantissa
   * fields.
   *
   * The exponent must be corrected by the difference in exponent bias between
   * single-precision and half-precision (0x7F - 0xF = 0x70). Instead of adding
   * 0x70 directly, 0xE0 is added and the value is then scaled by 2^(-112):
   * a half exponent of 0x1F (Inf/NaN) becomes 0xFF (fp32 Inf/NaN), which the
   * multiplication leaves as Inf/NaN, while finite values end up with the net
   * correction of 0x70.
   */
  constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
  // Exponent field 15 with zero mantissa == 2^(15 - 127) == 2^(-112).
  const float exp_scale = fp32_from_bits(UINT32_C(15) << 23);
  const float normalized_value =
      fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;

  /*
   * Denormalized path (also handles zero). The half mantissa is placed in the
   * low bits of the fp32 mantissa of 0.5f (exponent field 126), so the value
   * reads as 0.5 + mantissa * 2^(-24); subtracting 0.5 leaves the exact value
   * of the half denormal, which is always a normalized fp32 number (or zero).
   */
  constexpr uint32_t magic_mask = UINT32_C(126) << 23;
  constexpr float magic_bias = 0.5f;
  const float denormalized_value =
      fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;

  /*
   * Choose either result depending on the input exponent: two_w below 1 << 27
   * means all five exponent bits of the input are zero (denormal or zero).
   */
  constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
  const uint32_t magnitude = two_w < denormalized_cutoff
      ? fp32_to_bits(denormalized_value)
      : fp32_to_bits(normalized_value);
  return fp32_from_bits(sign | magnitude);
}

/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to a
 * 16-bit floating-point number in IEEE half-precision format, in bit
 * representation.
 *
 * Magnitudes too large for half precision become infinity, and every NaN input
 * is returned as the quiet NaN pattern 0x7E00 (with the input's sign bit).
 */
uint16_t float_to_half(float f) {
  // 2^112 (exponent field 239) and 2^(-110) (exponent field 17).
  const float scale_to_inf = fp32_from_bits(UINT32_C(239) << 23);
  const float scale_to_zero = fp32_from_bits(UINT32_C(17) << 23);

  // The first multiplication overflows to infinity for inputs that cannot be
  // represented in half precision; the second brings the remaining values back
  // down. The two steps must stay separate for that overflow to happen.
  float base = (fabsf(f) * scale_to_inf) * scale_to_zero;

  const uint32_t w = fp32_to_bits(f);
  // Input bits with the sign shifted out: exponent in the top byte.
  const uint32_t shl1_w = w + w;
  const uint32_t sign = w & UINT32_C(0x80000000);

  // Exponent of the input, clamped from below so that very small inputs are
  // rounded at the granularity of half-precision denormals.
  uint32_t bias = shl1_w & UINT32_C(0xFF000000);
  if (bias < UINT32_C(0x71000000)) {
    bias = UINT32_C(0x71000000);
  }

  // Adding a value with a suitably larger exponent makes the floating-point
  // addition itself discard (and round) the mantissa bits that do not fit into
  // half precision.
  base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
  const uint32_t bits = fp32_to_bits(base);
  const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
  const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
  // Summed rather than OR-ed: the 12-bit low field overlaps the bottom of the
  // exponent field, so its upper bits carry into the exponent.
  const uint32_t nonsign = exp_bits + mantissa_bits;

  // shl1_w above 0xFF000000 means exponent all ones with a non-zero mantissa,
  // i.e. the input is NaN.
  const bool input_is_nan = shl1_w > UINT32_C(0xFF000000);
  return static_cast<uint16_t>(
      (sign >> 16) | (input_is_nan ? UINT32_C(0x7E00) : nonsign));
}

} // namespace
131+
132+
/*
 * Create a staging buffer with room for `numel` elements.
 *
 * The requested `dtype` is passed through get_staging_dtype() first, so the
 * dtype that is stored in the object, and that determines the allocation size,
 * may differ from the one requested by the caller.
 */
StagingBuffer::StagingBuffer(
    Context* context_p,
    const vkapi::ScalarType dtype,
    const size_t numel,
    const vkapi::CopyDirection direction)
    : context_p_(context_p),
      dtype_(get_staging_dtype(context_p, dtype)),
      // Sized from dtype_ (the resolved staging dtype), not from `dtype`.
      vulkan_buffer_(context_p->adapter_ptr()->vma().create_staging_buffer(
          numel * element_size(dtype_),
          direction)),
      mapped_data_(nullptr) {}
143+
144+
/*
 * Pick the dtype a staging buffer should use for data of type `dtype`.
 *
 * kHalf is replaced by kFloat when the adapter reports that it lacks full
 * float16 buffer support; every other dtype is returned unchanged. The context
 * is only consulted for kHalf.
 */
vkapi::ScalarType get_staging_dtype(
    Context* context_p,
    vkapi::ScalarType dtype) {
  if (dtype != vkapi::kHalf) {
    return dtype;
  }
  const bool fp16_buffers_supported =
      context_p->adapter_ptr()->has_full_float16_buffers_support();
  if (fp16_buffers_supported) {
    return dtype;
  }
  return vkapi::kFloat;
}
153+
154+
void StagingBuffer::cast_half_to_float_and_copy_from(
155+
const uint16_t* src,
156+
const size_t numel) {
157+
VK_CHECK_COND(numel <= this->numel());
158+
float* dst = reinterpret_cast<float*>(data());
159+
for (size_t i = 0; i < numel; ++i) {
160+
dst[i] = half_to_float(src[i]);
161+
}
162+
}
163+
164+
void StagingBuffer::cast_float_to_half_and_copy_to(
165+
uint16_t* dst,
166+
const size_t numel) {
167+
VK_CHECK_COND(numel <= this->numel());
168+
vmaInvalidateAllocation(
169+
vulkan_buffer_.vma_allocator(),
170+
vulkan_buffer_.allocation(),
171+
0u,
172+
VK_WHOLE_SIZE);
173+
const float* src = reinterpret_cast<const float*>(data());
174+
for (size_t i = 0; i < numel; ++i) {
175+
dst[i] = float_to_half(src[i]);
176+
}
177+
}
178+
179+
} // namespace api
180+
} // namespace vkcompute

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
namespace vkcompute {
2020
namespace api {
2121

22+
vkapi::ScalarType get_staging_dtype(
23+
Context* context_p,
24+
vkapi::ScalarType dtype);
25+
2226
class StagingBuffer final {
2327
private:
2428
Context* context_p_;
@@ -32,13 +36,7 @@ class StagingBuffer final {
3236
Context* context_p,
3337
const vkapi::ScalarType dtype,
3438
const size_t numel,
35-
const vkapi::CopyDirection direction)
36-
: context_p_(context_p),
37-
dtype_(dtype),
38-
vulkan_buffer_(context_p_->adapter_ptr()->vma().create_staging_buffer(
39-
element_size(dtype_) * numel,
40-
direction)),
41-
mapped_data_(nullptr) {}
39+
const vkapi::CopyDirection direction);
4240

4341
StagingBuffer(const StagingBuffer&) = delete;
4442
StagingBuffer& operator=(const StagingBuffer&) = delete;
@@ -92,6 +90,12 @@ class StagingBuffer final {
9290
}
9391
}
9492

93+
void cast_half_to_float_and_copy_from(
94+
const uint16_t* src,
95+
const size_t numel);
96+
97+
void cast_float_to_half_and_copy_to(uint16_t* dst, const size_t numel);
98+
9599
inline void copy_to(void* dst, const size_t nbytes) {
96100
VK_CHECK_COND(nbytes <= this->nbytes());
97101
vmaInvalidateAllocation(

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,25 @@ PackedDimInfo calculate_packed_dim_info(
6363

6464
/*
6565
* For PackedInt8 memory layouts, ensure that the scalar type used for the
66-
* tensor is kInt8x4. Otherwise, return the original scalar type.
66+
* tensor is kInt8x4.
67+
*
68+
* For kHalf dtype on devices that don't support float16 buffers, alias to
69+
* kFloat.
70+
*
71+
* Otherwise, return the original scalar type.
6772
*/
6873
vkapi::ScalarType get_effective_scalar_type(
    Context* const context,
    const vkapi::ScalarType dtype,
    const utils::GPUMemoryLayout memory_layout) {
  // Packed int8 layouts always store kInt8x4; only kInt8x4 or kChar may be
  // requested for them.
  if (utils::is_packed_int8_layout(memory_layout)) {
    VK_CHECK_COND(dtype == vkapi::kInt8x4 || dtype == vkapi::kChar);
    return vkapi::kInt8x4;
  }

  // Half precision falls back to fp32 when the adapter lacks full float16
  // buffer support. The context is only consulted for kHalf.
  if (dtype == vkapi::kHalf &&
      !context->adapter_ptr()->has_full_float16_buffers_support()) {
    return vkapi::kFloat;
  }

  return dtype;
}
@@ -726,7 +736,7 @@ vTensor::vTensor(
726736
const utils::GPUMemoryLayout memory_layout,
727737
const bool allocate_memory,
728738
const utils::AxisMapLayout axis_map_layout)
729-
: dtype_(get_effective_scalar_type(dtype, memory_layout)),
739+
: dtype_(get_effective_scalar_type(context, dtype, memory_layout)),
730740
packed_dim_info_(calculate_packed_dim_info(memory_layout, storage_type)),
731741
// Calculate tensor metadata
732742
sizes_(sizes.begin(), sizes.end()),

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1313

14+
#include <executorch/backends/vulkan/runtime/api/containers/StagingBuffer.h>
15+
1416
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1517

1618
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
@@ -350,6 +352,11 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
350352
VK_THROW("Could not get dtype of value with type ", val.type());
351353
}
352354

355+
// Returns the dtype a staging buffer would use for the value at `idx`: the
// value's dtype as reported by dtype_of(), mapped through
// api::get_staging_dtype() for this graph's context.
vkapi::ScalarType ComputeGraph::get_staging_dtype_for(
    const ValueRef idx) const {
  const vkapi::ScalarType value_dtype = dtype_of(idx);
  return api::get_staging_dtype(context_.get(), value_dtype);
}
359+
353360
bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const {
354361
if (!val_is_tensor(idx)) {
355362
return false;
@@ -923,6 +930,10 @@ void ComputeGraph::maybe_cast_and_copy_into_staging(
923930
src_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) {
924931
const double* casted_data = reinterpret_cast<const double*>(data);
925932
staging->cast_and_copy_from<double, float>(casted_data, numel);
933+
} else if (
934+
src_data_dtype == vkapi::kHalf && staging_dtype == vkapi::kFloat) {
935+
const uint16_t* casted_data = reinterpret_cast<const uint16_t*>(data);
936+
staging->cast_half_to_float_and_copy_from(casted_data, numel);
926937
} else {
927938
VK_THROW(
928939
"Unsupported type conversion from ",
@@ -962,6 +973,10 @@ void ComputeGraph::maybe_cast_and_copy_from_staging(
962973
dst_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) {
963974
double* casted_data = reinterpret_cast<double*>(data);
964975
staging->cast_and_copy_to<float, double>(casted_data, numel);
976+
} else if (
977+
dst_data_dtype == vkapi::kHalf && staging_dtype == vkapi::kFloat) {
978+
uint16_t* casted_data = reinterpret_cast<uint16_t*>(data);
979+
staging->cast_float_to_half_and_copy_to(casted_data, numel);
965980
} else {
966981
VK_THROW(
967982
"Unsupported type conversion from staging dtype ",

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ class ComputeGraph final {
352352

353353
vkapi::ScalarType dtype_of(const ValueRef idx) const;
354354

355+
vkapi::ScalarType get_staging_dtype_for(const ValueRef idx) const;
356+
355357
inline const utils::ivec3& logical_limits_of(const ValueRef idx) const {
356358
return values_.at(idx).toConstTensor().logical_limits();
357359
}
@@ -997,17 +999,19 @@ class ComputeGraph final {
997999
// Input/Output
9981000
//
9991001

1002+
private:
10001003
void
10011004
copy_into_staging(const ValueRef idx, const void* data, const size_t numel);
10021005

1006+
void copy_from_staging(const ValueRef idx, void* data, const size_t numel);
1007+
1008+
public:
10031009
void maybe_cast_and_copy_into_staging(
10041010
const ValueRef idx,
10051011
const void* data,
10061012
const size_t numel,
10071013
const vkapi::ScalarType src_data_dtype);
10081014

1009-
void copy_from_staging(const ValueRef idx, void* data, const size_t numel);
1010-
10111015
void maybe_cast_and_copy_from_staging(
10121016
const ValueRef idx,
10131017
void* data,
@@ -1110,6 +1114,10 @@ class ComputeGraph final {
11101114
return context_->adapter_ptr()->supports_int16_shader_types();
11111115
}
11121116

1117+
inline bool float16_buffers_enabled() const {
1118+
return context_->adapter_ptr()->has_full_float16_buffers_support();
1119+
}
1120+
11131121
inline size_t execute_count() const {
11141122
return execute_count_;
11151123
}

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,36 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
7070
vkapi::CopyDirection::HOST_TO_DEVICE);
7171
graph->update_staging_nbytes_in_cmd(staging.buffer().mem_size_as_size_t());
7272
size_t nbytes = numel * vkapi::element_size(tref->dtype);
73-
staging.copy_from(tref->data, nbytes);
73+
74+
// In some cases the staging dtype will diverge from the TensorRef dtype. The
75+
// most common case for this is when the tensor data is float16, but the GPU
76+
// does not support 16-bit storage buffers. In these cases, the tensor data
77+
// is manually casted to the staging dtype.
78+
vkapi::ScalarType staging_dtype = staging.dtype();
79+
vkapi::ScalarType tref_dtype = tref->dtype;
80+
if (staging_dtype == tref_dtype) {
81+
staging.copy_from(tref->data, nbytes);
82+
} else {
83+
// Hard-coded type conversion cases
84+
if (tref_dtype == vkapi::kHalf && staging_dtype == vkapi::kFloat) {
85+
const uint16_t* casted_data =
86+
reinterpret_cast<const uint16_t*>(tref->data);
87+
staging.cast_half_to_float_and_copy_from(casted_data, numel);
88+
} else if (tref_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) {
89+
const int64_t* casted_data = reinterpret_cast<const int64_t*>(tref->data);
90+
staging.cast_and_copy_from<int64_t, int32_t>(casted_data, numel);
91+
} else if (tref_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) {
92+
const double* casted_data = reinterpret_cast<const double*>(tref->data);
93+
staging.cast_and_copy_from<double, float>(casted_data, numel);
94+
} else {
95+
VK_THROW(
96+
"Unsupported type conversion from ",
97+
tref_dtype,
98+
" to staging dtype ",
99+
staging_dtype);
100+
}
101+
}
102+
74103
// Once the staging buffer is copied, if the TensorRef owns a FreeableBuffer,
75104
// it can be freed.
76105
tref->free_buffer();

0 commit comments

Comments
 (0)