Commit 9576316
[ExecuTorch][Llama][ez] Remove unwrap_tensor_subclass from 4w quantization path
D100066455 changed the Llama export pipeline to quantize weights in the checkpoint dtype (typically bfloat16) before casting to the computation dtype (fp32). This introduced a regression for Vulkan 4w export: `dequantize_affine` ops produced bfloat16 outputs, which Vulkan doesn't support, causing the graph to be split into multiple partitions. When `sym_constrain_range_for_size` constraint nodes were partitioned into a different delegate than the `_local_scalar_dense` + `slice_copy` ops they constrain, ExportPass re-tracing (in ConvertToLinearPass, SpecPropPass, etc.) would crash with `GuardOnDataDependentSymNode: Could not guard on data-dependent expression u539 < 0`.
The root cause is `unwrap_tensor_subclass()`. This function decomposes `IntxUnpackedToInt8Tensor` into plain tensors via `torch.nn.utils.parametrize`, capturing the subclass's metadata — including its `dtype` attribute (which controls `dequantize_affine`'s output dtype) — as a frozen snapshot in `UnwrapTensorSubclass.rebuild_stack`. A subsequent `model.to(dtype=fp32)` casts the plain tensors but cannot update the frozen metadata, so `dequantize_affine` continues to output bfloat16.
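To make the freezing mechanism concrete, here is a minimal, framework-free sketch of the failure mode. All names here are illustrative stand-ins, not the actual torchao code: the real logic lives in `unwrap_tensor_subclass` / `UnwrapTensorSubclass.rebuild_stack` and `IntxUnpackedToInt8Tensor`.

```python
# Toy illustration (hypothetical names) of how unwrapping freezes the
# subclass's dtype metadata, so a later cast to fp32 has no effect on
# the dequantize output dtype.

class SubclassTensor:
    """Stand-in for a quantized tensor subclass: carries a dtype
    attribute that controls the dequantize output dtype."""
    def __init__(self, data, dtype):
        self.data = data
        self.dtype = dtype  # e.g. "bfloat16" from the checkpoint

def unwrap(t):
    # Like unwrap_tensor_subclass: return the plain tensor plus a
    # frozen snapshot of the subclass metadata (the "rebuild stack").
    frozen_meta = {"dtype": t.dtype}  # snapshot taken at unwrap time
    return t.data, frozen_meta

def dequantize(plain, meta):
    # Output dtype comes from the frozen metadata, not from the
    # (possibly later-cast) plain tensor.
    return {"values": plain, "out_dtype": meta["dtype"]}

t = SubclassTensor(data=[1, 2, 3], dtype="bfloat16")
plain, meta = unwrap(t)

# model.to(dtype=fp32) casts only the *plain* tensor; the frozen
# snapshot still says bfloat16.
plain_fp32 = list(plain)  # pretend this is the fp32 cast
out = dequantize(plain_fp32, meta)
assert out["out_dtype"] == "bfloat16"  # regression: still bf16
```

This mirrors the description above: the cast reaches the parametrized plain tensors but can never reach the metadata captured at unwrap time.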
`unwrap_tensor_subclass()` was originally needed as a workaround because `torch.export` did not support tensor subclasses natively. This is no longer the case — the 8da4w path already works without it (using the same `IntxUnpackedToInt8Tensor` subclass), and `torch.export` traces through the subclass correctly. Removing it makes 4w consistent with 8da4w and avoids the metadata-freezing issue entirely.
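For contrast, a sketch of the behavior once the unwrap step is removed (again with illustrative stand-in names, not the real subclass): the live subclass carries its own dtype, so casting the model updates the metadata and dequantize produces the computation dtype.

```python
# Toy illustration: without unwrapping, dequantize reads the dtype
# from the live subclass, so a cast propagates correctly.

class SubclassTensor:
    def __init__(self, data, dtype):
        self.data = data
        self.dtype = dtype

    def to(self, dtype):
        # Casting the live subclass updates its metadata too.
        return SubclassTensor(self.data, dtype)

def dequantize(t):
    # Output dtype comes straight from the subclass, never a snapshot.
    return {"values": t.data, "out_dtype": t.dtype}

t = SubclassTensor([1, 2, 3], "bfloat16").to("fp32")
out = dequantize(t)
assert out["out_dtype"] == "fp32"  # matches the computation dtype
```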
This change was authored with Claude.
Differential Revision: [D101252037](https://our.internmc.facebook.com/intern/diff/D101252037/)
ghstack-source-id: 368594388
Pull Request resolved: pytorch#189571
1 file changed: 0 additions, 1 deletion (original line 193 removed).