Commit 2604159

Replace CanonicalQuantizedWeight with torchao tensor subclasses
Delete the custom CanonicalQuantizedWeight dataclass and serialize.py format. Quantized weights are now stored as torchao's native Int4Tensor (4-bit) and IntxUnpackedToInt8Tensor (8-bit) subclasses, serialized via torchao's safetensors integration.

Key changes:
- quantize_weight returns Int4Tensor or IntxUnpackedToInt8Tensor
- quantize_model returns a single state_dict (not two dicts)
- 8-bit quantization done in float32 to avoid bf16 precision loss (manual quantize + direct IntxUnpackedToInt8Tensor construction)
- Sensitive recipe uses HQQ asymmetric INT4 (scale + zero optimization)
- pack_model takes a single state_dict, dispatches by isinstance
- pack.py uses TorchAOBaseTensor for quantized weight detection
- GGUF unpacker produces Int4Tensor/IntxUnpackedToInt8Tensor directly
- serialize.py dissolved — callers inline torchao safetensors directly

Breaking change: existing prequantized checkpoints (old format) must be regenerated with quantize_and_save.py.
1 parent e79f101 commit 2604159
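
The commit message mentions a manual 8-bit quantize step. As a rough illustration only, the sketch below shows group-wise symmetric int8 quantization in plain Python. The symmetric scheme, the `max_abs / 127` scale, and both function names are assumptions for this example and are not taken from the commit; Python floats are float64, whereas the commit works in float32.

```python
from typing import List, Tuple


def quantize_int8_groupwise(
    weights: List[float], group_size: int
) -> Tuple[List[int], List[float]]:
    """Quantize a flat weight list to int8 with one scale per group (hypothetical helper)."""
    if len(weights) % group_size != 0:
        raise ValueError("len(weights) must be a multiple of group_size")
    qdata: List[int] = []
    scales: List[float] = []
    for start in range(0, len(weights), group_size):
        group = weights[start : start + group_size]
        max_abs = max(abs(w) for w in group)
        # Map the largest magnitude in the group to 127; avoid a zero scale.
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        scales.append(scale)
        for w in group:
            q = round(w / scale)
            qdata.append(max(-128, min(127, q)))  # clamp to the int8 range
    return qdata, scales


def dequantize_int8_groupwise(
    qdata: List[int], scales: List[float], group_size: int
) -> List[float]:
    """Reverse the mapping: each value is multiplied by its group's scale."""
    return [q * scales[i // group_size] for i, q in enumerate(qdata)]
```

With this scheme the round-trip error of each weight is at most half of its group's scale, so any extra rounding applied to the scale itself adds directly to that budget.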

19 files changed

Lines changed: 857 additions & 1169 deletions

examples/models/gemma4_31b/README.md

Lines changed: 5 additions & 4 deletions
@@ -19,10 +19,11 @@ both export and eager inference:
 | `inference.py --gguf <file>` | GGUF file (Q4_K_M, etc.) → eager generation | ~24 GB GPU |
 | `export.py --model-dir <hf>` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing |
 
-The quantized checkpoint is a safetensors file with int values + per-group
-scales and a JSON header describing each weight's `QuantConfig`. No tensor
-subclass or backend-specific packing — packing for the target backend happens
-at load time via `quant.pack_model()`.
+The quantized checkpoint is a safetensors file containing torchao tensor
+subclasses (`Int4Tensor`, `IntxUnpackedToInt8Tensor`) and plain tensors.
+Metadata records each subclass's type and attributes. No backend-specific
+packing — packing for the target backend happens at load time via
+`quant.pack_model()`.
 
 ## Quantization recipes
 
examples/models/gemma4_31b/export.py

Lines changed: 2 additions & 2 deletions
@@ -82,12 +82,12 @@ def load_and_quantize(
     model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone())
 
     print(f"Quantizing with recipe '{recipe_name}'...")
-    quantized, unquantized = quantize_model(model, recipe)
+    state_dict = quantize_model(model, recipe)
 
     print(f"Packing for {backend}...")
     with torch.device("meta"):
         model = Gemma4_31B(config)
-    pack_model(model, quantized, unquantized, packers=_get_packers(backend))
+    pack_model(model, state_dict, packers=_get_packers(backend))
     model.eval()
 
     print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}")
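
One way to see why a single `state_dict` can replace the old `quantized` / `unquantized` pair: once quantized weights carry their own type, the two groups can be recovered with an `isinstance` check. The sketch below uses stand-in classes rather than torch or torchao; `Tensor`, `QuantizedBase`, `FakeInt4`, `FakeInt8`, and `split_by_type` are all hypothetical names for this illustration.

```python
from typing import Any, Dict, List, Tuple


class Tensor:
    """Stand-in for torch.Tensor (assumption for this sketch)."""


class QuantizedBase(Tensor):
    """Stand-in for a common quantized base class such as TorchAOBaseTensor."""


class FakeInt4(QuantizedBase):
    """Stand-in for a 4-bit weight subclass."""


class FakeInt8(QuantizedBase):
    """Stand-in for an 8-bit weight subclass."""


def split_by_type(state_dict: Dict[str, Any]) -> Tuple[List[str], List[str]]:
    """Return (quantized keys, plain keys) from one mixed state_dict."""
    quantized = [k for k, v in state_dict.items() if isinstance(v, QuantizedBase)]
    plain = [k for k, v in state_dict.items() if not isinstance(v, QuantizedBase)]
    return quantized, plain
```

The caller no longer has to keep two dictionaries in sync; the value's type carries the distinction.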

examples/models/gemma4_31b/gguf_loader.py

Lines changed: 10 additions & 14 deletions
@@ -64,15 +64,15 @@ def gguf_to_model_key(gguf_key: str) -> Optional[str]:
     return None
 
 
-def _resolve_tied_lm_head(model, embed_cw, packers):
+def _resolve_tied_lm_head(model, embed_quant, packers):
     """Handle tied embed/lm_head after streaming all tensors."""
     from executorch.examples.models.gemma4_31b.quant import pack_one
 
     lm_head = getattr(model.lm_head, "weight", None)
     if lm_head is None or lm_head.device.type != "meta":
         return
-    if embed_cw is not None:
-        pack_one(model, "lm_head.weight", embed_cw, packers)
+    if embed_quant is not None:
+        pack_one(model, "lm_head.weight", embed_quant, packers)
     else:
         pack_one(
             model,
@@ -114,9 +114,7 @@ def load_gguf_model(
     from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig
     from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one
     from executorch.examples.models.gemma4_31b.quant.gguf import iter_gguf_tensors
-    from executorch.examples.models.gemma4_31b.quant.serialize import (
-        CanonicalQuantizedWeight,
-    )
+    from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
 
     if backend == "cuda":
         from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS
@@ -131,7 +129,7 @@ def load_gguf_model(
     with torch.device("meta"):
         model = Gemma4_31B(config)
 
-    embed_cw = None
+    embed_quant = None
     n_processed = 0
 
     print(f"Streaming GGUF from {gguf_path}...")
@@ -140,13 +138,11 @@ def load_gguf_model(
         if model_key is None:
             continue
 
-        if isinstance(result, torch.Tensor) and result.dtype == torch.float32:
+        if type(result) is torch.Tensor and result.dtype == torch.float32:
             result = result.to(torch.bfloat16)
 
-        if model_key == "embed_tokens.weight" and isinstance(
-            result, CanonicalQuantizedWeight
-        ):
-            embed_cw = result
+        if model_key == "embed_tokens.weight" and isinstance(result, Int4Tensor):
+            embed_quant = result
             result = dequantize_weight(result, torch.bfloat16)
 
         pack_one(model, model_key, result, packers)
@@ -155,8 +151,8 @@ def load_gguf_model(
         if n_processed % 100 == 0:
             print(f" Processed {n_processed} tensors...")
 
-    _resolve_tied_lm_head(model, embed_cw, packers)
-    del embed_cw
+    _resolve_tied_lm_head(model, embed_quant, packers)
+    del embed_quant
 
     _validate_no_meta(model)
     model.eval()
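
The switch from `isinstance(result, torch.Tensor)` to `type(result) is torch.Tensor` matters once quantized weights are themselves tensor subclasses: `isinstance` is true for every subclass, while the exact-type check is true only for plain tensors. The sketch below shows the difference with stand-in classes; `Tensor`, `QuantTensor`, and the string dtypes are assumptions for this example, not torch objects.

```python
class Tensor:
    """Stand-in for torch.Tensor; dtype is a plain string here."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype


class QuantTensor(Tensor):
    """Stand-in for a quantized tensor subclass such as Int4Tensor."""


def should_downcast_isinstance(x: object) -> bool:
    # Old check: also matches subclasses that happen to report float32.
    return isinstance(x, Tensor) and x.dtype == "float32"


def should_downcast_exact(x: object) -> bool:
    # New check: only plain tensors qualify, subclasses are left untouched.
    return type(x) is Tensor and x.dtype == "float32"
```

Under the exact-type check a quantized weight is never passed through the float32 to bfloat16 conversion, whatever dtype it reports.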

examples/models/gemma4_31b/model.md

Lines changed: 7 additions & 6 deletions
@@ -126,9 +126,10 @@ Modules in `quant/`:
 - **Recipe** (`recipe.py`): `QuantConfig` + `QuantRule` + `QuantRecipe`.
   Declares what to quantize — says nothing about packing or backends.
 - **Quantize** (`quantize.py`): `quantize_weight` / `dequantize_weight` /
-  `quantize_model`. Produces `CanonicalQuantizedWeight` from fp weights.
-- **Serialize** (`serialize.py`): `CanonicalQuantizedWeight` (int8 qdata +
-  bf16 scale + optional zero). `save()` / `load()` persist to safetensors.
+  `quantize_model`. Produces torchao tensor subclasses (`Int4Tensor`,
+  `IntxUnpackedToInt8Tensor`) from fp weights.
+- **Serialization**: callers use torchao's safetensors integration
+  (`torchao.prototype.safetensors`) directly — no wrapper module needed.
 - **Pack** (`pack.py` + `pack_cuda.py`): `pack_model` groups weights by
   parent module, `pack_one` handles single weights. Per-module packers
   dispatch by module type (`nn.Linear`, `nn.Embedding`, extensible for MoE).
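
The Pack entry describes two steps: group weights by parent module, then pick a packer by the module's type. A minimal sketch of that dispatch follows. `Linear` and `Embedding` are stand-ins for the `nn` classes, and `group_by_parent` and `pack_all` are hypothetical names; only the `(module, weights)` packer signature is taken from the quant README.

```python
from collections import defaultdict
from typing import Any, Callable, Dict


class Linear:
    """Stand-in for nn.Linear."""


class Embedding:
    """Stand-in for nn.Embedding."""


def group_by_parent(state_dict: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Map 'a.b.weight' -> {'a.b': {'weight': value}}."""
    groups: Dict[str, Dict[str, Any]] = defaultdict(dict)
    for fqn, value in state_dict.items():
        parent, _, leaf = fqn.rpartition(".")
        groups[parent][leaf] = value
    return dict(groups)


def pack_all(
    modules: Dict[str, Any],
    state_dict: Dict[str, Any],
    packers: Dict[type, Callable[[Any, Dict[str, Any]], Any]],
) -> Dict[str, Any]:
    """Call the packer registered for each parent module's exact type."""
    packed = {}
    for parent, weights in group_by_parent(state_dict).items():
        module = modules[parent]
        packer = packers.get(type(module))
        if packer is None:
            raise TypeError(f"no packer registered for {type(module).__name__}")
        packed[parent] = packer(module, weights)
    return packed
```

Supporting a new module type then amounts to adding one entry to the `packers` dictionary.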
@@ -142,11 +143,11 @@ quantize_and_save.py export.py / inference.py
 |                               |
 bf16 weights                    quantized checkpoint (safetensors)
 |                               |
-quantize_weight()               load()
+quantize_weight()               load (torchao safetensors)
 |                               |
-CanonicalQuantizedWeight        CanonicalQuantizedWeight
+Int4Tensor / IntxUnpacked       Int4Tensor / IntxUnpacked
 |                               |
-save()                          pack_model()
+save (torchao safetensors)      pack_model()
 |                               |
 model.safetensors               Int4TilePackedTo4dTensor (runtime)
 ```
Lines changed: 24 additions & 57 deletions
@@ -1,27 +1,29 @@
 # quant/
 
-Packing-agnostic quantization framework: **recipe → quantize → serialize → pack**.
+Quantization framework: **recipe → quantize → pack**.
 
 ## Files
 
 | File | Concern | Depends on |
 |---|---|---|
 | `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing |
-| `quantize.py` | **Computation** — produces/dequantizes canonical weights | recipe, torchao |
-| `serialize.py` | **Data format** — saves/loads canonical weights to safetensors | recipe |
-| `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | serialize |
-| `pack_cuda.py` | **CUDA packing** — converts canonical to tinygemm/intx runtime format | pack, serialize |
-| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to canonical form | recipe, serialize |
+| `quantize.py` | **Computation** — produces torchao subclass tensors | recipe, torchao |
+| `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | |
+| `pack_cuda.py` | **CUDA packing** — converts Int4Tensor to tinygemm format | pack |
+| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to torchao subclasses | torchao |
 
 ## Data flow
 
 ```
-QuantRecipe → quantize_model() → CanonicalQuantizedWeight → save() → file → load() → CanonicalQuantizedWeight → pack_model() → runtime model
+QuantRecipe → quantize_model() → state_dict{Int4Tensor, IntxUnpackedToInt8Tensor, Tensor} → safetensors → state_dict → pack_model() → runtime model
 ```
 
-`CanonicalQuantizedWeight` is the interchange point — int8 qdata + bf16
-scale + optional zero + config. Everything left of it is backend-agnostic.
-Everything right is backend-specific.
+Quantized weights are stored as torchao tensor subclasses:
+- **Int4Tensor** — 4-bit weights (nibble-packed qdata + transposed scale/zero_point)
+- **IntxUnpackedToInt8Tensor** — 8-bit weights (int8 qdata + scale + zero_point)
+
+These are the canonical interchange formats from torchao. Everything left
+of `save()` is backend-agnostic. Everything right is backend-specific.
 
 ## Adding a new backend
 
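
"Nibble-packed qdata" means two 4-bit values share one byte. The sketch below illustrates the idea in plain Python. The nibble order (even index in the low nibble) and the signed range of -8 to 7 are assumptions for this example and may not match torchao's `Int4Tensor` layout; `pack_nibbles` and `unpack_nibbles` are hypothetical names.

```python
from typing import List


def pack_nibbles(values: List[int]) -> bytes:
    """Pack signed 4-bit values (-8..7) two per byte, even index in the low nibble."""
    if len(values) % 2 != 0:
        raise ValueError("need an even number of values")
    if any(v < -8 or v > 7 for v in values):
        raise ValueError("values must fit in a signed 4-bit range")
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        # Masking with 0xF keeps the two's-complement low 4 bits.
        packed.append(((hi & 0xF) << 4) | (lo & 0xF))
    return bytes(packed)


def unpack_nibbles(packed: bytes) -> List[int]:
    """Inverse of pack_nibbles, with sign extension of each nibble."""
    out: List[int] = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            out.append(nibble - 16 if nibble >= 8 else nibble)
    return out
```

Packing halves the storage of the quantized values relative to one int8 per weight, which is why the 4-bit and 8-bit subclasses hold their `qdata` differently.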
@@ -32,56 +34,21 @@ def pack_linear_for_metal(module, weights): ...
 DEFAULT_METAL_PACKERS = {nn.Linear: pack_linear_for_metal}
 ```
 
-Call `pack_model(model, quantized, unquantized, packers=DEFAULT_METAL_PACKERS)`.
-No changes to recipe, quantize, or serialize.
-
-Things to consider:
-
-- **Recipes may need to be backend-aware.** Each backend's kernels have
-  different constraints (e.g., Metal's `fpa4w` is INT4-only — no INT8 linear
-  kernel, so the sensitive recipe's 8-bit edge layers would need to be INT4
-  or dequantized to bf16). Define per-backend recipes or validate recipe
-  compatibility at pack time.
-- **Source transforms before packing.** Some backends replace model modules
-  (e.g., MLX swaps `FusedMoEExperts` → `SwitchMLP`, Metal swaps to
-  `MetalMoEExperts`). These transforms change the module types that
-  packers dispatch on, so they must run before `pack_model()`. For dense
-  models (no MoE) this is not needed.
-- **Embedding quantization.** Not all backends have a quantized embedding
-  gather kernel. The packer can dequantize to bf16 at load time — the
-  disk savings from the canonical format still apply.
-
-## Adding a new model
-
-1. Define a `QuantRecipe` with rules for the model's FQN patterns.
-2. If the model has custom module types (e.g., `FusedMoEExperts`), write a
-   per-module packer and extend the packers dict:
-   ```python
-   packers = {**DEFAULT_CUDA_PACKERS, FusedMoEExperts: pack_moe_experts}
-   ```
-3. No changes to the quant package itself.
+Call `pack_model(model, state_dict, packers=DEFAULT_METAL_PACKERS)`.
+No changes to recipe or quantize.
 
 ## On-disk format
 
-Safetensors with a `format_version` in the header. Per quantized weight:
-`{fqn}.qdata` (int8, nibble-packed for 4-bit), `{fqn}.scale` (bf16),
-optionally `{fqn}.zero` (bf16). Header JSON records bits, group_size,
-symmetric, and method per weight. Unquantized weights stored as-is.
+Uses torchao's safetensors integration (`torchao.prototype.safetensors`).
+Each tensor subclass is decomposed into its inner tensors
+(e.g., `layer._weight_qdata`, `layer._weight_scale`) plus JSON metadata
+recording the subclass type and attributes. Plain tensors are stored as-is.
+The format is compatible with torchao's `save_pretrained` / `load_pretrained`.
 
 ## TODO
 
-- `pack_metal.py` — Metal backend packer. Convert canonical INT4 to
-  `UIntxWeightOnlyConfig` subclass (torchao experimental) for the
-  `torchao::_linear_fp_act_4bit_weight` kernel. For MoE models, pack
-  expert weights into Metal's `gather_qmv` format (asymmetric, unsigned
-  INT4 with scale + bias buffers).
-
-- `pack_mlx.py` — MLX backend packer. Convert canonical INT4 to
-  `IntxWeightOnlyConfig` subclass for the `mlx::gather_qmm` kernel.
-  For MoE models, stack per-expert weights into `SwitchLinear` format.
-
-- `gguf.py` — extend with Q5_K, Q8_0, and other GGUF quant types.
-  Currently supports Q4_K and Q6_K. Some Q4_K_M files also contain
-  Q5_K or Q8_0 tensors (for sensitive layers on certain architectures)
-  which will raise — add support as needed. Q6_K is widened to 8-bit
-  for CUDA packing since there is no 6-bit CUDA kernel.
+- `pack_metal.py` — Metal backend packer.
+- `pack_mlx.py` — MLX backend packer.
+- `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types.
+- Upstream `Int4TilePackedTo4dTensor.from_int4_tensor()` to torchao
+  to replace the manual conversion in `pack_int4_for_cuda`.

examples/models/gemma4_31b/quant/__init__.py

Lines changed: 0 additions & 7 deletions
@@ -8,10 +8,3 @@
 from .pack_cuda import DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda  # noqa: F401
 from .quantize import dequantize_weight, quantize_model, quantize_weight  # noqa: F401
 from .recipe import QuantConfig, QuantRecipe, QuantRule  # noqa: F401
-from .serialize import (  # noqa: F401
-    CanonicalQuantizedWeight,
-    deserialize,
-    load,
-    save,
-    serialize,
-)
