Commit 9576316

SS-JIA authored and committed
[ExecuTorch][Llama][ez] Remove unwrap_tensor_subclass from 4w quantization path
D100066455 changed the Llama export pipeline to quantize weights in the checkpoint dtype (typically bfloat16) before casting to the computation dtype (fp32). This introduced a regression for Vulkan 4w export: `dequantize_affine` ops produced bfloat16 outputs, which Vulkan doesn't support, causing the graph to be split into multiple partitions. When `sym_constrain_range_for_size` constraint nodes were partitioned into a different delegate than the `_local_scalar_dense` + `slice_copy` ops they constrain, ExportPass re-tracing (in ConvertToLinearPass, SpecPropPass, etc.) would crash with `GuardOnDataDependentSymNode: Could not guard on data-dependent expression u539 < 0`.

The root cause is `unwrap_tensor_subclass()`. This function decomposes `IntxUnpackedToInt8Tensor` into plain tensors via `torch.nn.utils.parametrize`, capturing the subclass's metadata — including its `dtype` attribute (which controls `dequantize_affine`'s output dtype) — as a frozen snapshot in `UnwrapTensorSubclass.rebuild_stack`. A subsequent `model.to(dtype=fp32)` casts the plain tensors but cannot update the frozen metadata, so `dequantize_affine` continues to output bfloat16.

`unwrap_tensor_subclass()` was originally needed as a workaround because `torch.export` did not support tensor subclasses natively. This is no longer the case — the 8da4w path already works without it (using the same `IntxUnpackedToInt8Tensor` subclass), and `torch.export` traces through the subclass correctly. Removing it makes 4w consistent with 8da4w and avoids the metadata-freezing issue entirely.

This change was authored with Claude.

Differential Revision: [D101252037](https://our.internmc.facebook.com/intern/diff/D101252037/)
ghstack-source-id: 368594388
Pull Request resolved: pytorch#18957
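The metadata-freezing failure mode can be illustrated with a small, self-contained sketch. This is not the ExecuTorch or torchao code — `FakeSubclassTensor` and `unwrap` are hypothetical stand-ins that mimic how `unwrap_tensor_subclass()` snapshots the subclass's `dtype` at decomposition time, so a later cast of the plain payload never reaches the snapshot:

```python
class FakeSubclassTensor:
    """Stand-in for IntxUnpackedToInt8Tensor: carries a dtype attribute
    that controls the output dtype of its dequantize step."""

    def __init__(self, data, dtype):
        self.data = data
        self.dtype = dtype  # analogous to the attribute that drives dequantize_affine's output dtype


def unwrap(subclass):
    """Mimic unwrap_tensor_subclass(): decompose into a plain payload plus
    a rebuild closure that freezes the *current* metadata (like
    UnwrapTensorSubclass.rebuild_stack)."""
    frozen_dtype = subclass.dtype  # frozen snapshot taken here

    def rebuild(plain_data):
        return FakeSubclassTensor(plain_data, frozen_dtype)

    return subclass.data, rebuild


w = FakeSubclassTensor(data=[1, 2, 3], dtype="bfloat16")
plain, rebuild = unwrap(w)

# Analogous to model.to(dtype=torch.float32): the cast can transform the
# plain payload, but the rebuild closure still holds the old dtype.
plain = [float(x) for x in plain]
rebuilt = rebuild(plain)
print(rebuilt.dtype)  # still "bfloat16", not "float32"
```

Skipping the unwrap step keeps the live `dtype` attribute on the subclass, so a later cast is observed — which is why removing the call fixes the Vulkan partition split.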
1 parent 1154d34 commit 9576316

1 file changed

Lines changed: 0 additions & 1 deletion

examples/models/llama/source_transformation/quantize.py

@@ -190,7 +190,6 @@ def filter_fn(m, fqn):
             ),
         )
         quantize_(model, q_config)
-        model = unwrap_tensor_subclass(model)

         return model
     else:
