Commit 12c1195

Arm backend: add VGF PT2E linear quantization modes for LLM export (#19029)
- add vgf_16a8w/8a8w PT2E quantization modes
- add backend.vgf.quantize_scope for full vs linear VGF quantization
- wire the VGF config through the LLM export and quantizer selection path
- add coverage in export_llama_lib tests for the new VGF PT2E modes

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani

Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 920b493 commit 12c1195

4 files changed: 102 additions & 2 deletions

examples/models/llama/export_llama_lib.py

Lines changed: 15 additions & 0 deletions
@@ -229,6 +229,8 @@ def build_args_parser() -> argparse.ArgumentParser:
             "vulkan_8w",
             "tosa_8a8w",
             "ethosu_8a8w",
+            "vgf_8a8w",
+            "vgf_16a8w",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -456,6 +458,18 @@ def build_args_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("-V", "--vulkan", action="store_true")
     parser.add_argument("--vulkan-force-fp16", action="store_true")
+    parser.add_argument("--vgf", "--arm-vgf", dest="vgf", action="store_true")
+    parser.add_argument(
+        "--vgf-compile-spec",
+        default="TOSA-1.0+INT",
+        help="VGF compile spec, e.g. TOSA-1.0+INT or TOSA-1.0+INT+int16.",
+    )
+    parser.add_argument(
+        "--vgf-quantize-scope",
+        default="full",
+        choices=["full", "linear"],
+        help="VGF quantization scope. Use 'linear' to quantize only Linear modules.",
+    )
     parser.add_argument("--mps", action="store_true")
     parser.add_argument("--coreml", action="store_true")
     parser.add_argument(
@@ -847,6 +861,7 @@ def get_quantizer_and_quant_params(llm_config):
             llm_config.backend.vgf.compile_spec,
             llm_config.backend.vgf.compiler_flags,
             llm_config.quantization.pt2e_quantize.value,
+            llm_config.backend.vgf.quantize_scope.value,
         )
         quantizers.append(vgf_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
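To see how the new flags behave in isolation, here is a minimal sketch (not part of the commit) that feeds them through build_args_parser. The import path is an assumption based on this repository's layout, and a real export run may require additional model arguments.

# Sketch only: exercise the new VGF flags added to build_args_parser above.
from executorch.examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(
    [
        "--vgf",                                     # enable the Arm VGF backend
        "--vgf-compile-spec", "TOSA-1.0+INT+int16",  # INT16 extension, needed for vgf_16a8w
        "--vgf-quantize-scope", "linear",            # quantize only nn.Linear modules
    ]
)
print(args.vgf, args.vgf_compile_spec, args.vgf_quantize_scope)
# expected: True TOSA-1.0+INT+int16 linear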

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 53 additions & 1 deletion
@@ -7,6 +7,8 @@

 import unittest

+import torch
+
 from executorch.devtools.backend_debug import get_delegation_info

 try:
@@ -28,7 +30,11 @@
     build_args_parser,
     get_quantizer_and_quant_params,
 )
-from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize
+from executorch.extension.llm.export.config.llm_config import (
+    LlmConfig,
+    Pt2eQuantize,
+    VgfQuantizeScope,
+)

 UNWANTED_OPS = [
     "aten_permute_copy_default",
@@ -111,3 +117,49 @@ def test_get_quantizer_and_quant_params_returns_vgf_quantizer(self):
         self.assertIsNone(quant_dtype)
         self.assertEqual(len(quantizers), 1)
         self.assertIsInstance(quantizers[0], VgfQuantizer)
+        self.assertIsNotNone(quantizers[0].global_config)
+        self.assertEqual(quantizers[0].module_type_config, {})
+
+    @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
+    def test_get_quantizer_and_quant_params_returns_vgf_linear_quantizer(self):
+        llm_config = LlmConfig()
+        llm_config.backend.vgf.enabled = True
+        llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
+        llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w
+
+        _pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], VgfQuantizer)
+        self.assertIsNone(quantizers[0].global_config)
+        self.assertIn(torch.nn.Linear, quantizers[0].module_type_config)
+
+    @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
+    def test_vgf_16a8w_requires_int16_compile_spec_extension(self):
+        llm_config = LlmConfig()
+        llm_config.backend.vgf.enabled = True
+        llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
+        llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w
+
+        with self.assertRaisesRegex(ValueError, "INT16 support"):
+            get_quantizer_and_quant_params(llm_config)
+
+    @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
+    def test_vgf_16a8w_accepts_int16_compile_spec_extension(self):
+        llm_config = LlmConfig()
+        llm_config.backend.vgf.enabled = True
+        llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT+int16"
+        llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w
+
+        _pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], VgfQuantizer)
+        self.assertIn(torch.nn.Linear, quantizers[0].module_type_config)
extension/llm/export/config/llm_config.py

Lines changed: 17 additions & 0 deletions
@@ -377,6 +377,7 @@ class Pt2eQuantize(str, Enum):
     tosa_8a8w = "tosa_8a8w"
     ethosu_8a8w = "ethosu_8a8w"
     vgf_8a8w = "vgf_8a8w"
+    vgf_16a8w = "vgf_16a8w"


 class SpinQuant(str, Enum):
@@ -587,6 +588,11 @@ class EthosUConfig:
     system_config: str = "default"


+class VgfQuantizeScope(str, Enum):
+    full = "full"
+    linear = "linear"
+
+
 @dataclass
 class VgfConfig:
     """
@@ -596,6 +602,7 @@ class VgfConfig:
     enabled: bool = False
     compile_spec: Optional[str] = "TOSA-1.0+INT"
     compiler_flags: List[str] = field(default_factory=list)
+    quantize_scope: VgfQuantizeScope = VgfQuantizeScope.full


 @dataclass
@@ -815,6 +822,16 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         if hasattr(args, "group_size") and args.group_size:
             llm_config.backend.openvino.nncf_compression_group_size = args.group_size

+        # VGF
+        if hasattr(args, "vgf"):
+            llm_config.backend.vgf.enabled = args.vgf
+        if hasattr(args, "vgf_compile_spec"):
+            llm_config.backend.vgf.compile_spec = args.vgf_compile_spec
+        if hasattr(args, "vgf_quantize_scope") and args.vgf_quantize_scope:
+            llm_config.backend.vgf.quantize_scope = VgfQuantizeScope(
+                args.vgf_quantize_scope
+            )
+
         # TorchAoKernels
         if any(
             hasattr(args, a)
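As a small illustration of the CLI-to-config wiring this hunk adds, here is a sketch under the same assumptions as above (import path, no other required arguments). The PT2E mode is set directly on the config because this diff does not show the exact spelling of its CLI flag.

# Sketch only: map the new VGF flags into LlmConfig via from_args.
from executorch.examples.models.llama.export_llama_lib import build_args_parser
from executorch.extension.llm.export.config.llm_config import (
    LlmConfig,
    Pt2eQuantize,
    VgfQuantizeScope,
)

args = build_args_parser().parse_args(
    ["--vgf", "--vgf-compile-spec", "TOSA-1.0+INT+int16", "--vgf-quantize-scope", "linear"]
)
llm_config = LlmConfig.from_args(args)

assert llm_config.backend.vgf.enabled is True
assert llm_config.backend.vgf.compile_spec == "TOSA-1.0+INT+int16"
assert llm_config.backend.vgf.quantize_scope is VgfQuantizeScope.linear

# The PT2E mode is a separate quantization option; setting it here avoids
# guessing its flag spelling, which is not shown in this commit.
llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w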

extension/llm/export/quantizer_lib.py

Lines changed: 17 additions & 1 deletion
@@ -367,8 +367,10 @@ def get_vgf_quantizer(
     compile_spec: Optional[str],
     compiler_flags: Optional[List[str]],
     pt2e_quantize: str,
+    quantize_scope: str,
 ):
     from executorch.backends.arm.quantizer.arm_quantizer import (
+        get_symmetric_a16w8_quantization_config,
         get_symmetric_quantization_config,
         VgfQuantizer,
     )
@@ -379,8 +381,22 @@ def get_vgf_quantizer(
     quantizer = VgfQuantizer(compile_spec_obj)

     if pt2e_quantize == "vgf_8a8w":
-        quantizer.set_global(get_symmetric_quantization_config())
+        quantization_config = get_symmetric_quantization_config()
+    elif pt2e_quantize == "vgf_16a8w":
+        if not compile_spec_obj.tosa_spec.support_extension("int16"):
+            raise ValueError(
+                "vgf_16a8w requires a VGF compile spec with INT16 support, "
+                "for example TOSA-1.0+INT+int16."
+            )
+        quantization_config = get_symmetric_a16w8_quantization_config()
     else:
         raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")

+    if quantize_scope == "full":
+        quantizer.set_global(quantization_config)
+    elif quantize_scope == "linear":
+        quantizer.set_module_type(torch.nn.Linear, quantization_config)
+    else:
+        raise ValueError(f"Unsupported VGF quantization scope {quantize_scope}")
+
     return quantizer
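For context, a minimal sketch of how the scope-aware quantizer slots into the standard PT2E flow on a toy model. The prepare/convert imports reflect the usual torch.ao PT2E API and are assumptions rather than part of this commit; newer ExecuTorch/torchao stacks may expose the same entry points from torchao instead, and running this requires the Arm backend to be installed.

# Sketch only: apply a linear-scope VGF quantizer to a toy model with the
# standard PT2E prepare/convert flow.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.extension.llm.export.quantizer_lib import get_vgf_quantizer

quantizer = get_vgf_quantizer(
    compile_spec="TOSA-1.0+INT",
    compiler_flags=[],
    pt2e_quantize="vgf_8a8w",
    quantize_scope="linear",  # only torch.nn.Linear modules are annotated
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
example_inputs = (torch.randn(1, 16),)

exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # a single calibration pass with example data
quantized = convert_pt2e(prepared)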
