99import os .path
1010import shutil
1111import subprocess
12+ from copy import deepcopy
1213from enum import Enum
1314from os import mkdir
15+ from typing import Callable
1416
1517import numpy as np
1618import torch
@@ -65,6 +67,7 @@ def _run_delegated_executorch_program(
6567 npu_results_dir ,
6668 mocker ,
6769 use_qat : bool = False ,
70+ train_fn : Callable [[torch .fx .GraphModule ], None ] | None = None ,
6871) -> ExportedProgram :
6972 if len (input_spec ) == 1 :
7073 # Single input, use --dataset
@@ -112,6 +115,7 @@ def wrapper(*args, **kwargs):
112115 calibration_dataset_dir ,
113116 delegate_to_npu = True ,
114117 use_qat = use_qat ,
118+ train_fn = train_fn ,
115119 )
116120 except RuntimeError as e :
117121 if "Model converted with neutron-converter has" in str (e ):
@@ -139,6 +143,8 @@ def _run_non_delegated_executorch_program(
139143 testing_dataset_dir ,
140144 input_spec ,
141145 cpu_results_dir ,
146+ use_qat : bool = False ,
147+ train_fn : Callable [[torch .fx .GraphModule ], None ] | None = None ,
142148) -> ExportedProgram :
143149 if len (input_spec ) == 1 :
144150 # Single input, use --dataset
@@ -165,7 +171,12 @@ def _run_non_delegated_executorch_program(
165171 --output { cpu_results_dir } --firmware { NSYS_FIRMWARE_PATH } --nsys { NSYS_PATH } --nsys_config { NSYS_CONFIG_PATH } "
166172
167173 non_delegated_program = to_quantized_executorch_program (
168- model , input_spec , calibration_dataset_dir , delegate_to_npu = False
174+ model ,
175+ input_spec ,
176+ calibration_dataset_dir ,
177+ delegate_to_npu = False ,
178+ use_qat = use_qat ,
179+ train_fn = train_fn ,
169180 )
170181
171182 nodes = list (non_delegated_program .exported_program ().graph .nodes )
@@ -348,6 +359,12 @@ def _run_python_program(
348359 store_results (all_outputs , cpu_results_dir , npu_results_dir )
349360
350361
def assert_NSYS():
    """
    Check that every NSYS resource path used by the test runners exists.

    The paths are read from the module-level constants `NSYS_PATH`, `NSYS_CONFIG_PATH` and
    `NSYS_FIRMWARE_PATH` and are checked in that order.

    :raises AssertionError: If any of the paths does not exist. The message names the constant
                            and the path that is missing.
    """
    required_paths = (
        ("NSYS_PATH", NSYS_PATH),
        ("NSYS_CONFIG_PATH", NSYS_CONFIG_PATH),
        ("NSYS_FIRMWARE_PATH", NSYS_FIRMWARE_PATH),
    )
    for constant_name, path in required_paths:
        # Keep `assert` (callers/tests see the same AssertionError as before), but say which
        # path is missing instead of failing with an empty message.
        assert os.path.exists(path), f"{constant_name} does not exist: {path!r}"
366+
367+
351368def convert_run_compare (
352369 model : torch .nn .Module ,
353370 input_spec : list [ModelInputSpec ] | tuple ,
@@ -357,6 +374,7 @@ def convert_run_compare(
357374 mocker : MockerFixture = None ,
358375 reference_model : ReferenceModel = ReferenceModel .QUANTIZED_EXECUTORCH_CPP ,
359376 use_qat : bool = False ,
377+ train_fn : Callable [[torch .fx .GraphModule ], None ] | None = None ,
360378):
361379 """
362380 Run provided program twice with neutron-test and check if results correspond. At first,
@@ -372,16 +390,18 @@ def convert_run_compare(
372390 :param reference_model: Version of the model which will be run to obtain reference output data.
373391 :param mocker: Mocker instance used by visualizer.
374392 :param use_qat: If True, applies quantization-aware training before conversion (without the QAT training).
393+ :param train_fn: Train/finetune function for QAT training. Is used only when `use_qat=True`.
375394 """
376- assert os .path .exists (NSYS_PATH )
377- assert os .path .exists (NSYS_CONFIG_PATH )
378- assert os .path .exists (NSYS_FIRMWARE_PATH )
395+ assert_NSYS ()
379396
380397 if not dataset_creator :
381398 dataset_creator = RandomDatasetCreator ()
382399 if not output_comparator :
383400 output_comparator = AllCloseOutputComparator ()
384401
402+ model_to_delegate = model
403+ model_to_not_delegate = deepcopy (model )
404+
385405 test_name = _get_caller_name ()
386406 test_dir = os .path .join (OUTPUTS_DIR , test_name )
387407
@@ -401,7 +421,7 @@ def convert_run_compare(
401421 npu_results_dir = os .path .join (test_dir , "results_npu" )
402422
403423 delegated_program = _run_delegated_executorch_program (
404- model ,
424+ model_to_delegate ,
405425 test_dir ,
406426 test_name ,
407427 calibration_dataset_dir ,
@@ -411,6 +431,7 @@ def convert_run_compare(
411431 npu_results_dir ,
412432 mocker ,
413433 use_qat = use_qat ,
434+ train_fn = train_fn ,
414435 )
415436
416437 output_spec = _get_program_output_spec (delegated_program )
@@ -420,24 +441,27 @@ def convert_run_compare(
420441 # Lower to quantized executorch program, export to `.pte` file and run in c++ using
421442 # examples/nxp/executor_runner/nxp_executor_runner.cpp
422443 _run_non_delegated_executorch_program (
423- model ,
444+ model_to_not_delegate ,
424445 test_dir ,
425446 test_name ,
426447 calibration_dataset_dir ,
427448 testing_dataset_dir ,
428449 input_spec ,
429450 cpu_results_dir ,
451+ use_qat = use_qat ,
452+ train_fn = train_fn ,
430453 )
431454
432455 case ReferenceModel .QUANTIZED_EDGE_PYTHON :
433456 # Lower to quantized edge program and run in Python.
434457 non_delegated_edge_program = (
435458 to_quantized_edge_program (
436- model ,
459+ model_to_not_delegate ,
437460 input_spec ,
438461 calibration_dataset_dir ,
439462 delegate_to_npu = False ,
440463 use_qat = use_qat ,
464+ train_fn = train_fn ,
441465 )
442466 .exported_program ()
443467 .module ()
@@ -454,7 +478,7 @@ def convert_run_compare(
454478 case ReferenceModel .FLOAT_PYTORCH_PYTHON :
455479 # Run the PyTorch nn.Module directly in Python.
456480 _run_python_program (
457- model ,
481+ model_to_not_delegate ,
458482 testing_dataset_dir ,
459483 input_spec ,
460484 output_spec ,
@@ -474,9 +498,96 @@ def convert_run_compare(
474498 )
475499
476500
def convert_run_compare_ptq_qat(
    model: torch.nn.Module,
    input_spec: list[ModelInputSpec] | tuple,
    dlg_model_verifier: GraphVerifier,
    train_fn: Callable[[torch.fx.GraphModule], None],
    dataset_creator=None,
    output_comparator=None,
    mocker: MockerFixture = None,
):
    """
    Run the provided model twice as a delegated program and compare the results.
    The first run is done with `use_qat=False` (PTQ), the second one with `use_qat=True` (QAT)
    and the provided `train_fn`.

    :param model: Executed PyTorch model.
    :param input_spec: Model input specification. Can be either tuple - single float32 input model - or list
                       of ModelInputSpec.
    :param dlg_model_verifier: Graph verifier instance.
    :param train_fn: Train/finetune function for QAT training.
    :param dataset_creator: Creator that should fill provided `dataset_dir` with model input samples.
                            Defaults to `RandomDatasetCreator` when not provided.
    :param output_comparator: Comparator of results produced by the PTQ and QAT runs of the program.
                              Defaults to `AllCloseOutputComparator` when not provided.
    :param mocker: Mocker instance used by visualizer.
    """
    assert_NSYS()

    if not dataset_creator:
        dataset_creator = RandomDatasetCreator()
    if not output_comparator:
        output_comparator = AllCloseOutputComparator()

    # The QAT run gets its own deep copy, so the two runs do not share module state.
    model_ptq = model
    model_qat = deepcopy(model)

    test_name = _get_caller_name()
    test_dir = os.path.join(OUTPUTS_DIR, test_name)

    # Always start from an empty output directory for this test.
    shutil.rmtree(test_dir, ignore_errors=True)
    mkdir(test_dir)

    dataset_dir = os.path.join(test_dir, "dataset")
    mkdir(dataset_dir)
    if isinstance(input_spec, tuple):
        # A bare tuple is wrapped into a single `ModelInputSpec`.
        input_spec = [ModelInputSpec(input_spec)]

    (calibration_dataset_dir, testing_dataset_dir) = dataset_creator.generate_samples(
        dataset_dir, input_spec
    )

    ptq_results_dir = os.path.join(test_dir, "results_ptq")
    qat_results_dir = os.path.join(test_dir, "results_qat")

    # NOTE(review): both runs share `test_dir` and `test_name`; confirm that nothing relies on
    # artifacts of the first run which the second run may overwrite.
    delegated_program_ptq = _run_delegated_executorch_program(
        model_ptq,
        test_dir,
        test_name,
        calibration_dataset_dir,
        testing_dataset_dir,
        input_spec,
        dlg_model_verifier,
        ptq_results_dir,
        mocker,
        use_qat=False,
    )

    # Only the results written to `qat_results_dir` are needed; the returned program is not used.
    _run_delegated_executorch_program(
        model_qat,
        test_dir,
        test_name,
        calibration_dataset_dir,
        testing_dataset_dir,
        input_spec,
        dlg_model_verifier,
        qat_results_dir,
        mocker,
        use_qat=True,
        train_fn=train_fn,
    )

    # The output specification of the PTQ program is used to compare both result sets.
    output_tensor_spec = _get_program_output_spec(delegated_program_ptq)

    output_comparator.compare_results(
        ptq_results_dir, qat_results_dir, output_tensor_spec
    )
585+
586+
477587def _get_caller_name ():
588+ test_function_names = ["convert_run_compare" , "convert_run_compare_ptq_qat" ]
478589 for idx , frame in enumerate (inspect .stack ()):
479- if frame .function == "convert_run_compare" :
590+ if frame .function in test_function_names :
480591 # Look one index above to get caller
481592 return inspect .stack ()[idx + 1 ].function
482593
0 commit comments