Skip to content

Commit 7b843e4

Browse files
Add support for exporting 8-bit quantized Token2Wav model
Differential Revision: D93525295 Pull Request resolved: #17517
1 parent 0c92654 commit 7b843e4

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

examples/models/llama/source_transformation/quantize.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def quantize( # noqa C901
131131
if verbose:
132132
print("quantized model:", model)
133133
return model
134-
elif qmode == "8da4w":
134+
elif qmode in ("8da4w", "8da8w"):
135135
if group_size is None:
136136
# TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
137137
group_size = 128
@@ -169,11 +169,12 @@ def filter_fn(m, fqn):
169169
is_linear or is_lora_linear
170170
) and has_shape_compatible_with_group_size
171171

172+
weight_dtype = torch.int4 if qmode == "8da4w" else torch.int8
172173
quantize_(
173174
model,
174175
Int8DynamicActivationIntxWeightConfig(
175176
# pyre-ignore[16]
176-
weight_dtype=torch.int4,
177+
weight_dtype=weight_dtype,
177178
weight_granularity=(
178179
PerAxis(0) if group_size == 0 else PerGroup(group_size)
179180
),

0 commit comments

Comments (0)