2121from nncf .common .logging import nncf_logger
2222from nncf .common .quantization .structs import QuantizationPreset
2323from nncf .data import Dataset
24+ from nncf .openvino .engine import calibration_device_context
2425from nncf .openvino .graph .metatypes .groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
2526from nncf .openvino .graph .metatypes .openvino_metatypes import OVIfMetatype
2627from nncf .openvino .graph .metatypes .openvino_metatypes import get_node_metatype
@@ -119,9 +120,11 @@ def _extract_all_subgraphs(model: ov.Model, current_id: str) -> None:
119120 f"The model consists of { if_ops_number } If node(-s) with then and else bodies. \
120121 Main model and all If bodies will be quantized recursively."
121122 )
122- quantized_model , _ = apply_algorithm_if_bodies (
123- quantization_algorithm , model , graphs , main_model_graph_id , calibration_dataset , subset_size , 1
124- )
123+ calibration_device = advanced_parameters .calibration_device if advanced_parameters else None
124+ with calibration_device_context (calibration_device ):
125+ quantized_model , _ = apply_algorithm_if_bodies (
126+ quantization_algorithm , model , graphs , main_model_graph_id , calibration_dataset , subset_size , 1
127+ )
125128
126129 if is_weight_compression_needed (advanced_parameters ):
127130 compress_quantize_weights_transformation (quantized_model )
@@ -168,7 +171,9 @@ def native_quantize_impl(
168171 )
169172 graph = GraphConverter .create_nncf_graph (model )
170173 warning_model_no_batchwise_support (graph , advanced_parameters , model_type , OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS )
171- quantized_model = quantization_algorithm .apply (model , graph , dataset = calibration_dataset )
174+ calibration_device = advanced_parameters .calibration_device if advanced_parameters else None
175+ with calibration_device_context (calibration_device ):
176+ quantized_model = quantization_algorithm .apply (model , graph , dataset = calibration_dataset )
172177
173178 if is_weight_compression_needed (advanced_parameters ):
174179 compress_quantize_weights_transformation (quantized_model )
@@ -296,15 +301,19 @@ def quantize_with_accuracy_control_impl(
296301 advanced_accuracy_restorer_parameters .num_ranking_workers ,
297302 advanced_accuracy_restorer_parameters .restore_mode ,
298303 )
299- quantized_model = accuracy_restorer .apply (
300- model ,
301- initial_metric_results ,
302- quantized_model ,
303- quantized_metric_results ,
304- validation_dataset ,
305- validation_dataset_size ,
306- evaluator ,
304+ calibration_device = (
305+ advanced_quantization_parameters .calibration_device if advanced_quantization_parameters else None
307306 )
307+ with calibration_device_context (calibration_device ):
308+ quantized_model = accuracy_restorer .apply (
309+ model ,
310+ initial_metric_results ,
311+ quantized_model ,
312+ quantized_metric_results ,
313+ validation_dataset ,
314+ validation_dataset_size ,
315+ evaluator ,
316+ )
308317
309318 if compress_weights :
310319 compress_quantize_weights_transformation (quantized_model )
@@ -402,12 +411,15 @@ def compress_weights_impl(
402411 advanced_parameters ,
403412 )
404413
414+ calibration_device = advanced_parameters .calibration_device if advanced_parameters else None
415+
405416 statistics_points = None
406417 if advanced_parameters and advanced_parameters .statistics_path :
407418 # If there is no such directory, cache the statistics into it
408419 statistics_path = Path (advanced_parameters .statistics_path )
409420 if not statistics_path .exists ():
410- cache_weight_compression_statistics (model , graph , dataset , subset_size , statistics_path )
421+ with calibration_device_context (calibration_device ):
422+ cache_weight_compression_statistics (model , graph , dataset , subset_size , statistics_path )
411423 statistics_aggregator = StatisticsAggregatorFactory .create (model , dataset )
412424 compression_algorithm .set_backend_entity (model )
413425 _ , matmul_input_to_output_nodes_map = compression_algorithm .get_compression_nodes_info (graph )
@@ -421,4 +433,5 @@ def compress_weights_impl(
421433 statistics_aggregator .load_statistics_from_dir (statistics_path )
422434 statistics_points = statistics_aggregator .statistic_points
423435
424- return compression_algorithm .apply (model , graph , statistics_points , dataset )
436+ with calibration_device_context (calibration_device ):
437+ return compression_algorithm .apply (model , graph , statistics_points , dataset )
0 commit comments