Skip to content

Commit b24535b

Browse files
NXP backend: Add utilities for QAT testing (#17623)
### Summary Add utilities to QAT model testing. ### Test plan N/A --------- Co-authored-by: Simon Strycek <simon.strycek@nxp.com>
1 parent 8b30cfe commit b24535b

File tree

3 files changed

+289
-32
lines changed

3 files changed

+289
-32
lines changed

backends/nxp/tests_models/executors.py

Lines changed: 120 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import os.path
1010
import shutil
1111
import subprocess
12+
from copy import deepcopy
1213
from enum import Enum
1314
from os import mkdir
15+
from typing import Callable
1416

1517
import numpy as np
1618
import torch
@@ -65,6 +67,7 @@ def _run_delegated_executorch_program(
6567
npu_results_dir,
6668
mocker,
6769
use_qat: bool = False,
70+
train_fn: Callable[[torch.fx.GraphModule], None] | None = None,
6871
) -> ExportedProgram:
6972
if len(input_spec) == 1:
7073
# Single input, use --dataset
@@ -112,6 +115,7 @@ def wrapper(*args, **kwargs):
112115
calibration_dataset_dir,
113116
delegate_to_npu=True,
114117
use_qat=use_qat,
118+
train_fn=train_fn,
115119
)
116120
except RuntimeError as e:
117121
if "Model converted with neutron-converter has" in str(e):
@@ -139,6 +143,8 @@ def _run_non_delegated_executorch_program(
139143
testing_dataset_dir,
140144
input_spec,
141145
cpu_results_dir,
146+
use_qat: bool = False,
147+
train_fn: Callable[[torch.fx.GraphModule], None] | None = None,
142148
) -> ExportedProgram:
143149
if len(input_spec) == 1:
144150
# Single input, use --dataset
@@ -165,7 +171,12 @@ def _run_non_delegated_executorch_program(
165171
--output {cpu_results_dir} --firmware {NSYS_FIRMWARE_PATH} --nsys {NSYS_PATH} --nsys_config {NSYS_CONFIG_PATH}"
166172

167173
non_delegated_program = to_quantized_executorch_program(
168-
model, input_spec, calibration_dataset_dir, delegate_to_npu=False
174+
model,
175+
input_spec,
176+
calibration_dataset_dir,
177+
delegate_to_npu=False,
178+
use_qat=use_qat,
179+
train_fn=train_fn,
169180
)
170181

171182
nodes = list(non_delegated_program.exported_program().graph.nodes)
@@ -348,6 +359,12 @@ def _run_python_program(
348359
store_results(all_outputs, cpu_results_dir, npu_results_dir)
349360

350361

362+
def assert_NSYS():
363+
assert os.path.exists(NSYS_PATH)
364+
assert os.path.exists(NSYS_CONFIG_PATH)
365+
assert os.path.exists(NSYS_FIRMWARE_PATH)
366+
367+
351368
def convert_run_compare(
352369
model: torch.nn.Module,
353370
input_spec: list[ModelInputSpec] | tuple,
@@ -357,6 +374,7 @@ def convert_run_compare(
357374
mocker: MockerFixture = None,
358375
reference_model: ReferenceModel = ReferenceModel.QUANTIZED_EXECUTORCH_CPP,
359376
use_qat: bool = False,
377+
train_fn: Callable[[torch.fx.GraphModule], None] | None = None,
360378
):
361379
"""
362380
Run provided program twice with neutron-test and check if results correspond. At first,
@@ -372,16 +390,18 @@ def convert_run_compare(
372390
:param reference_model: Version of the model which will be run to obtain reference output data.
373391
:param mocker: Mocker instance used by visualizer.
374392
:param use_qat: If True, applies quantization-aware training before conversion (without the QAT training).
393+
:param train_fn: Train/finetune function for QAT training. Is used only when `use_qat=True`.
375394
"""
376-
assert os.path.exists(NSYS_PATH)
377-
assert os.path.exists(NSYS_CONFIG_PATH)
378-
assert os.path.exists(NSYS_FIRMWARE_PATH)
395+
assert_NSYS()
379396

380397
if not dataset_creator:
381398
dataset_creator = RandomDatasetCreator()
382399
if not output_comparator:
383400
output_comparator = AllCloseOutputComparator()
384401

402+
model_to_delegate = model
403+
model_to_not_delegate = deepcopy(model)
404+
385405
test_name = _get_caller_name()
386406
test_dir = os.path.join(OUTPUTS_DIR, test_name)
387407

@@ -401,7 +421,7 @@ def convert_run_compare(
401421
npu_results_dir = os.path.join(test_dir, "results_npu")
402422

403423
delegated_program = _run_delegated_executorch_program(
404-
model,
424+
model_to_delegate,
405425
test_dir,
406426
test_name,
407427
calibration_dataset_dir,
@@ -411,6 +431,7 @@ def convert_run_compare(
411431
npu_results_dir,
412432
mocker,
413433
use_qat=use_qat,
434+
train_fn=train_fn,
414435
)
415436

416437
output_spec = _get_program_output_spec(delegated_program)
@@ -420,24 +441,27 @@ def convert_run_compare(
420441
# Lower to quantized executorch program, export to `.pte` file and run in c++ using
421442
# examples/nxp/executor_runner/nxp_executor_runner.cpp
422443
_run_non_delegated_executorch_program(
423-
model,
444+
model_to_not_delegate,
424445
test_dir,
425446
test_name,
426447
calibration_dataset_dir,
427448
testing_dataset_dir,
428449
input_spec,
429450
cpu_results_dir,
451+
use_qat=use_qat,
452+
train_fn=train_fn,
430453
)
431454

432455
case ReferenceModel.QUANTIZED_EDGE_PYTHON:
433456
# Lower to quantized edge program and run in Python.
434457
non_delegated_edge_program = (
435458
to_quantized_edge_program(
436-
model,
459+
model_to_not_delegate,
437460
input_spec,
438461
calibration_dataset_dir,
439462
delegate_to_npu=False,
440463
use_qat=use_qat,
464+
train_fn=train_fn,
441465
)
442466
.exported_program()
443467
.module()
@@ -454,7 +478,7 @@ def convert_run_compare(
454478
case ReferenceModel.FLOAT_PYTORCH_PYTHON:
455479
# Run the PyTorch nn.Module directly in Python.
456480
_run_python_program(
457-
model,
481+
model_to_not_delegate,
458482
testing_dataset_dir,
459483
input_spec,
460484
output_spec,
@@ -474,9 +498,96 @@ def convert_run_compare(
474498
)
475499

476500

501+
def convert_run_compare_ptq_qat(
502+
model: torch.nn.Module,
503+
input_spec: list[ModelInputSpec] | tuple,
504+
dlg_model_verifier: GraphVerifier,
505+
train_fn: Callable[[torch.fx.GraphModule], None],
506+
dataset_creator=None,
507+
output_comparator=None,
508+
mocker: MockerFixture = None,
509+
):
510+
"""
511+
Run provided program twice and compare its results.
512+
The model is quantized once with PTQ and once with QAT.
513+
514+
:param model: Executed PyTorch model.
515+
:param input_spec: Model input specification. Can be either tuple - single float32 input model - or list
516+
of ModelInputSpec.
517+
:param dlg_model_verifier: Graph verifier instance.
518+
:param train_fn: Train/finetune function for QAT training.
519+
:param dataset_creator: Creator that should fill provided `dataset_dir` with model input samples.
520+
:param output_comparator: Comparator of results produced by NPU and CPU runs of the program.
521+
:param mocker: Mocker instance used by visualizer.
522+
"""
523+
assert_NSYS()
524+
525+
if not dataset_creator:
526+
dataset_creator = RandomDatasetCreator()
527+
if not output_comparator:
528+
output_comparator = AllCloseOutputComparator()
529+
530+
model_ptq = model
531+
model_qat = deepcopy(model)
532+
533+
test_name = _get_caller_name()
534+
test_dir = os.path.join(OUTPUTS_DIR, test_name)
535+
536+
shutil.rmtree(test_dir, ignore_errors=True)
537+
mkdir(test_dir)
538+
539+
dataset_dir = os.path.join(test_dir, "dataset")
540+
mkdir(dataset_dir)
541+
if isinstance(input_spec, tuple):
542+
input_spec = [ModelInputSpec(input_spec)]
543+
544+
(calibration_dataset_dir, testing_dataset_dir) = dataset_creator.generate_samples(
545+
dataset_dir, input_spec
546+
)
547+
548+
ptq_results_dir = os.path.join(test_dir, "results_ptq")
549+
qat_results_dir = os.path.join(test_dir, "results_qat")
550+
551+
delegated_program_ptq = _run_delegated_executorch_program(
552+
model_ptq,
553+
test_dir,
554+
test_name,
555+
calibration_dataset_dir,
556+
testing_dataset_dir,
557+
input_spec,
558+
dlg_model_verifier,
559+
ptq_results_dir,
560+
mocker,
561+
use_qat=False,
562+
)
563+
564+
_ = _run_delegated_executorch_program(
565+
model_qat,
566+
test_dir,
567+
test_name,
568+
calibration_dataset_dir,
569+
testing_dataset_dir,
570+
input_spec,
571+
dlg_model_verifier,
572+
qat_results_dir,
573+
mocker,
574+
use_qat=True,
575+
train_fn=train_fn,
576+
)
577+
578+
output_tensor_spec = _get_program_output_spec(delegated_program_ptq)
579+
580+
ptq_results_dir = os.path.join(test_dir, "results_ptq")
581+
qat_results_dir = os.path.join(test_dir, "results_qat")
582+
output_comparator.compare_results(
583+
ptq_results_dir, qat_results_dir, output_tensor_spec
584+
)
585+
586+
477587
def _get_caller_name():
588+
test_function_names = ["convert_run_compare", "convert_run_compare_ptq_qat"]
478589
for idx, frame in enumerate(inspect.stack()):
479-
if frame.function == "convert_run_compare":
590+
if frame.function in test_function_names:
480591
# Look one index above to get caller
481592
return inspect.stack()[idx + 1].function
482593

0 commit comments

Comments
 (0)