Skip to content

Commit ddb5aee

Browse files
author
ssjia
committed
[ET-VK][quantized] Store dq8ca per-token zero-point as fp32
The per-token dynamic-activation-quant (`dq8ca`) zero-point was corrupted by a tensor-allocation vs shader-access dtype mismatch. The per-token zero-point tensor is created with a float dtype -- fp32, or fp16 under `USE_VULKAN_FP16_INFERENCE` -- so its backing image uses a float texel format (`rgba32f` / `rgba16f`). But the shader declared and accessed that image with an integer dtype (`int8`, an integer image format `rgba8i`). Reading a float-format image through an integer-format binding is the bug. On ARM Mali (Valhall) GPUs this mismatch corrupted the per-token zero-points: negative zero-points came back as garbage (`-k` read as `-2^23 - k`), driving the quantized activation to the int8 floor, the per-group sums to `-4096`, and the GEMM output to garbage, producing garbled, runaway generation for 8da4w models (e.g. the Llama4-mini TISO TTS backbone on Mali-G715/G710). Adreno happened to tolerate the same mismatch and read correct values, so the corruption was Mali-specific even though the mismatch itself is general. The per-token zero-point is serialized as fp32 by torchao design: `Int8DynamicActivationIntxWeightConfig` (8da4w) uses asymmetric per-token activation quant with an explicit fp32 `zero_point_dtype`. Decoding the serialized `.pte` confirms the zero-point tensor is FLOAT32, and (like the scale) it is stored in a texture as an `rgba32f` texel -- never `rgba8i`. The float allocation is the truth; the int8 shader access was the mismatched side. The fix is to declare, store, and read the per-token zero-point as fp32 across the dq8ca qparams shaders, so the shader access dtype matches the tensor's allocation dtype and the texture is read as the `rgba32f` image it actually is. The zero-point value is integer-valued (nudged to `[-128, 127]`), so fp32 represents it exactly and the consumer's `int(zp)` conversion for the integer dequant-correction is lossless. 
This touches the dq8ca qparams shaders -- `choose_qparams_per_row`, `quantize_and_pack_4h4w_with_group_sums`, `linear_dq8ca_q4gsw_tiled`, the shared `linear_int8_input_scales_zps_load` helper, and the `linear_q4gsw_coop` variant (whose zero-point binding exists only to match the descriptor-set layout and is never read) -- plus a documentation comment in `ChooseQParams.cpp` and the two custom-op tests (`test_choose_qparams_per_row.cpp`, `test_q4gsw_linear.cpp`), whose zero-point tensor specs and reference data change from int8 to float to match. Because the per-token qparams remain in texture storage (unchanged from before) and only the zero-point dtype changes, this is a pure runtime shader fix: existing texture-qparams 8da4w `.pte` files are corrected without re-export, since the texture already bakes the zero-point as `rgba32f` and the shader now reads it as such. Authored with Claude Code. Differential Revision: [D109595977](https://our.internmc.facebook.com/intern/diff/D109595977/) ghstack-source-id: 396618146 Pull-Request: #20491
1 parent 1621fa2 commit ddb5aee

8 files changed

Lines changed: 24 additions & 16 deletions

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ layout(std430) buffer;
3030
#include "common.glslh"
3131

3232
${layout_declare_tensor(B, "w", "t_scales", DTYPE, "texture3d")}
33-
${layout_declare_tensor(B, "w", "t_zps", "int8", "texture3d")}
33+
${layout_declare_tensor(B, "w", "t_zps", "float", "texture3d")}
3434
${layout_declare_tensor(B, "r", "t_input", DTYPE, STORAGE, is_scalar_array=False)}
3535

3636
${layout_declare_ubo(B, "ivec4", "input_sizes")}
@@ -196,7 +196,7 @@ void main() {
196196

197197
if (worker_id == 0) {
198198
imageStore(t_scales, ivec3(output_y4, 0, 0), scales_out);
199-
imageStore(t_zps, ivec3(output_y4, 0, 0), zps_out);
199+
imageStore(t_zps, ivec3(output_y4, 0, 0), vec4(zps_out));
200200
}
201201

202202
}

backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ ${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=Fa
4646
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)}
4747
${layout_declare_tensor(B, "r", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)}
4848
${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")}
49-
${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")}
49+
${layout_declare_tensor(B, "r", "t_int8_input_zps", "float", "texture3d")}
5050
${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
5151
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
5252
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}

backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ void load_int8_input_scales_and_zps(
2020
[[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) {
2121
scales.data[m4] =
2222
VEC4_T(texelFetch(t_int8_input_scales, ivec3(m4_start + m4, 0, 0), 0));
23-
zps.data[m4] = texelFetch(t_int8_input_zps, ivec3(m4_start + m4, 0, 0), 0);
23+
zps.data[m4] =
24+
ivec4(texelFetch(t_int8_input_zps, ivec3(m4_start + m4, 0, 0), 0));
2425
}
2526
}
2627

backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ $if DYNAMIC_QUANT_VARIANT:
4040
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INPUT_STORAGE, is_scalar_array=False)}
4141
${layout_declare_tensor(B, "r", "t_int_input_sums", "int", "buffer", is_scalar_array=False)}
4242
${layout_declare_tensor(B, "r", "t_input_scale", DTYPE, "texture3d")}
43-
${layout_declare_tensor(B, "r", "t_input_zp", "int", "texture3d")}
43+
${layout_declare_tensor(B, "r", "t_input_zp", "float", "texture3d")}
4444
${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)}
4545
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
4646
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ ${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is
3333
${layout_declare_tensor(B, "w", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)}
3434
${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)}
3535
${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")}
36-
${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")}
36+
${layout_declare_tensor(B, "r", "t_int8_input_zps", "float", "texture3d")}
3737

3838
${layout_declare_ubo(B, "ivec4", "input_sizes")}
3939

backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ utils::uvec3 pick_choose_qparams_per_row_local_wg_size(
7878
return {workers_per_output, outputs_per_wg, 1u};
7979
}
8080

81+
// The per-token zero-point tensor is fp32-typed (matching torchao's serialized
82+
// asymmetric per-token zero_point_dtype=fp32), even though its values are
83+
// integer-valued in [-128, 127]. The shaders read it as a float texel and
84+
// convert to int for the integer dequant-correction. Declaring the shader
85+
// binding fp32 to match the tensor's allocation avoids the
86+
// float-image-read-through-an-integer-binding format mismatch that corrupted
87+
// negative zero-points on Mali.
8188
void add_choose_qparams_per_row_node(
8289
ComputeGraph& graph,
8390
const ValueRef& input,

backends/vulkan/test/custom_ops/test_choose_qparams_per_row.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ TestCase create_test_case_from_config(
4444
config.channel_size < kRefDimSizeLimit);
4545
std::string prefix = is_perf ? "PERF" : "ACCU";
4646
std::string in_dtype = dtype_short(input_dtype);
47-
std::string out_dtype = "f32,i8"; // pair: (scale, zero_point)
47+
std::string out_dtype = "f32,f32"; // pair: (scale, zero_point)
4848
std::string shape_str = "[" + std::to_string(config.num_channels) + "," +
4949
std::to_string(config.channel_size) + "]";
5050
std::string storage_str = repr_str(storage_type, utils::kWidthPacked);
@@ -81,10 +81,10 @@ TestCase create_test_case_from_config(
8181
utils::kWidthPacked,
8282
DataGenType::ZEROS);
8383

84-
// Output zero_point tensor (int8) - [num_channels]
84+
// Output zero_point tensor (float) - [num_channels]
8585
ValueSpec zero_point_out(
8686
{config.num_channels},
87-
vkapi::kChar, // int8 for quantized zero point
87+
vkapi::kFloat,
8888
utils::kTexture3D, // Always texture3d storage, matching the shader's t_zps binding
8989
utils::kWidthPacked,
9090
DataGenType::ZEROS);
@@ -289,7 +289,7 @@ void choose_qparams_per_channel_reference_impl(TestCase& test_case) {
289289

290290
// Prepare output data
291291
auto& scale_ref_data = scale_out_spec.get_ref_float_data();
292-
auto& zero_point_ref_data = zero_point_out_spec.get_ref_int8_data();
292+
auto& zero_point_ref_data = zero_point_out_spec.get_ref_float_data();
293293
scale_ref_data.resize(num_channels);
294294
zero_point_ref_data.resize(num_channels);
295295

@@ -312,9 +312,9 @@ void choose_qparams_per_channel_reference_impl(TestCase& test_case) {
312312
calculate_scale_and_zero_point_reference(
313313
min_val, max_val, quant_min, quant_max, scale, zero_point);
314314

315-
// Store results (cast zero_point to int8)
315+
// Store results
316316
scale_ref_data[channel] = scale;
317-
zero_point_ref_data[channel] = static_cast<int8_t>(zero_point);
317+
zero_point_ref_data[channel] = static_cast<float>(zero_point);
318318
}
319319
}
320320

backends/vulkan/test/custom_ops/test_q4gsw_linear.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ TestCase create_test_case_from_config(
9797

9898
ValueSpec input_zero_point(
9999
{1, config.M}, // Per-input channel tensor
100-
vkapi::kChar,
100+
vkapi::kFloat,
101101
storage_type,
102102
utils::kWidthPacked,
103103
DataGenType::RANDINT);
@@ -428,7 +428,7 @@ void linear_dq8ca_q4gsw_reference_impl(TestCase& test_case) {
428428
auto& input_scale_data =
429429
input_scale_spec.get_float_data(); // Per-input channel tensor
430430
auto& input_zero_point_data =
431-
input_zeros_spec.get_int8_data(); // Per-input channel tensor
431+
input_zeros_spec.get_float_data(); // Per-input channel tensor
432432

433433
auto& weight_data = weight_spec.get_uint8_data();
434434
auto& weight_sums_data = weight_sums_spec.get_int32_data();
@@ -462,8 +462,8 @@ void linear_dq8ca_q4gsw_reference_impl(TestCase& test_case) {
462462

463463
// Use per-input channel scale and zero point - index by batch dimension
464464
float input_scale = input_scale_data[b]; // {1, M} -> index by batch
465-
int8_t input_zero_point =
466-
input_zero_point_data[b]; // {1, M} -> index by batch
465+
int8_t input_zero_point = static_cast<int8_t>(
466+
input_zero_point_data[b]); // {1, M} -> index by batch
467467

468468
float quant_input_f =
469469
std::round(input_data[input_idx] / input_scale) + input_zero_point;

0 commit comments

Comments
 (0)