[ET-VK][quantized] Store dq8ca per-token zero-point as fp32#20491
[ET-VK][quantized] Store dq8ca per-token zero-point as fp32#20491SS-JIA wants to merge 1 commit into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20491
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 4 New Failures, 4 Unrelated FailuresAs of commit e7160e7 with merge base 1621fa2 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This PR needs a
|
Stack from ghstack (oldest at bottom):
The per-token dynamic-activation-quant (
dq8ca) zero-point was corrupted by a tensor-allocation vs shader-access dtype mismatch. The per-token zero-point tensor is created with a float dtype -- fp32, or fp16 underUSE_VULKAN_FP16_INFERENCE-- so its backing image uses a float texel format (rgba32f/rgba16f). But the shader declared and accessed that image with an integer dtype (int8, an integer image formatrgba8i). Reading a float-format image through an integer-format binding is the bug. On ARM Mali (Valhall) GPUs this mismatch corrupted the per-token zero-points: negative zero-points came back as garbage (-kread as-2^23 - k), driving the quantized activation to the int8 floor, the per-group sums to-4096, and the GEMM output to garbage, producing garbled, runaway generation for 8da4w models (e.g. the Llama4-mini TISO TTS backbone on Mali-G715/G710). Adreno happened to tolerate the same mismatch and read correct values, so the corruption was Mali-specific even though the mismatch itself is general.The per-token zero-point is serialized as fp32 by torchao design:
Int8DynamicActivationIntxWeightConfig(8da4w) uses asymmetric per-token activation quant with an explicit fp32zero_point_dtype. Decoding the serialized.pteconfirms the zero-point tensor is FLOAT32, and (like the scale) it is stored in a texture as anrgba32ftexel -- neverrgba8i. The float allocation is the truth; the int8 shader access was the mismatched side.The fix is to declare, store, and read the per-token zero-point as fp32 across the dq8ca qparams shaders, so the shader access dtype matches the tensor's allocation dtype and the texture is read as the
rgba32fimage it actually is. The zero-point value is integer-valued (nudged to[-128, 127]), so fp32 represents it exactly and the consumer'sint(zp)conversion for the integer dequant-correction is lossless. This touches the dq8ca qparams shaders --choose_qparams_per_row,quantize_and_pack_4h4w_with_group_sums,linear_dq8ca_q4gsw_tiled, the sharedlinear_int8_input_scales_zps_loadhelper, and thelinear_q4gsw_coopvariant (whose zero-point binding only matches the descriptor-set layout and is never read) -- plus a documentation comment inChooseQParams.cpp.Because the per-token qparams remain in texture storage (unchanged from before) and only the zero-point dtype changes, this is a pure runtime shader fix: existing texture-qparams 8da4w
.ptefiles are corrected without re-export, since the texture already bakes the zero-point asrgba32fand the shader now reads it as such.Authored with Claude Code.
Differential Revision: D109595977