diff --git a/.github/workflows/pre_commit.yml b/.github/workflows/pre_commit.yml index 889dd0d9..14b46acf 100644 --- a/.github/workflows/pre_commit.yml +++ b/.github/workflows/pre_commit.yml @@ -152,9 +152,12 @@ jobs: - *install-dependencies - - name: Run python unit tests + - name: Run model_api unit tests run: uv --directory model_api run pytest tests/unit --cov + - name: Run model_converter unit tests + run: uv --directory model_converter run --group tests pytest tests/unit --cov + - &prepare-test-data name: Prepare test data run: | diff --git a/model_converter/examples/config.json b/model_converter/examples/config.json index a585cc26..e63d6f8d 100644 --- a/model_converter/examples/config.json +++ b/model_converter/examples/config.json @@ -618,6 +618,35 @@ "scale_values": "58.395 57.12 57.375", "license": "apache-2.0", "license_link": "https://spdx.org/licenses/Apache-2.0.html" + }, + { + "model_short_name": "maskrcnn_resnet50_fpn", + "model_class_name": "torchvision.models.detection.maskrcnn_resnet50_fpn", + "model_library": "torchvision", + "model_full_name": "Mask R-CNN ResNet-50 FPN", + "description": "Mask R-CNN with a ResNet-50-FPN backbone trained on COCO for object detection and instance segmentation", + "docs": "https://docs.pytorch.org/vision/main/models/generated/torchvision.models.detection.maskrcnn_resnet50_fpn.html", + "weights_url": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", + "input_shape": [1, 3, 800, 800], + "input_names": ["image"], + "output_names": ["boxes", "labels", "masks"], + "model_params": null, + "model_type": "MaskRCNN", + "reverse_input_channels": true, + "mean_values": "0 0 0", + "scale_values": "255 255 255", + "resize_type": "fit_to_window_letterbox", + "pad_value": 0, + "input_dtype": "u8", + "confidence_threshold": 0.5, + "postprocess_semantic_masks": true, + "nms_execute": false, + "iou_threshold": 0.5, + "agnostic_nms": false, + "nms_max_predictions": 200, + "license": "bsd-3-clause", + 
"license_link": "https://spdx.org/licenses/BSD-3-Clause.html", + "labels": "COCO_V1" } ] } diff --git a/model_converter/pyproject.toml b/model_converter/pyproject.toml index 0870d65f..ffebb6d1 100644 --- a/model_converter/pyproject.toml +++ b/model_converter/pyproject.toml @@ -178,6 +178,9 @@ fixable = ["ALL"] unfixable = [] dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +[tool.ruff.lint.per-file-ignores] +"**/tests/**/*.py" = ["SLF001", "FBT003"] + [tool.ruff.lint.mccabe] max-complexity = 15 @@ -193,6 +196,16 @@ notice-rgx = """ [tool.bandit] skips = ["B101", "B310"] +[tool.pytest.ini_options] +pythonpath = ["src"] + +[tool.coverage.run] +source = ["model_converter"] + [tool.coverage.report] -fail_under = 45 +fail_under = 100 show_missing = true +exclude_lines = [ + "if __name__ == .__main__.", + "if TYPE_CHECKING:", +] diff --git a/model_converter/src/model_converter/__init__.py b/model_converter/src/model_converter/__init__.py index bf6d76c7..05acc4b9 100644 --- a/model_converter/src/model_converter/__init__.py +++ b/model_converter/src/model_converter/__init__.py @@ -5,6 +5,6 @@ """Tools for converting models to OpenVINO IR.""" -from .model_converter import ModelConverter, list_models, main +from .cli import ModelConverter, list_models, main __all__ = ["ModelConverter", "list_models", "main"] diff --git a/model_converter/src/model_converter/__main__.py b/model_converter/src/model_converter/__main__.py index 1334cddb..674523e2 100644 --- a/model_converter/src/model_converter/__main__.py +++ b/model_converter/src/model_converter/__main__.py @@ -5,7 +5,8 @@ """Run the model converter with ``python -m model_converter``.""" -from .model_converter import main +import sys -if __name__ == "__main__": - raise SystemExit(main()) +from model_converter.cli import main + +sys.exit(main()) diff --git a/model_converter/src/model_converter/adapters/__init__.py b/model_converter/src/model_converter/adapters/__init__.py new file mode 100644 index 00000000..2c94d058 
--- /dev/null +++ b/model_converter/src/model_converter/adapters/__init__.py @@ -0,0 +1,42 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Export adapters for different model types.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from model_converter.adapters.base import ExportAdapter +from model_converter.adapters.maskrcnn import TorchvisionMaskRCNNExportAdapter + +if TYPE_CHECKING: + import torch + +_ADAPTER_REGISTRY: dict[str, type[ExportAdapter]] = { + "maskrcnn": TorchvisionMaskRCNNExportAdapter, +} + + +def get_adapter(model_type: str, model: torch.nn.Module) -> torch.nn.Module: + """ + Get the appropriate export adapter for a model type. + + If no adapter is registered for the model type, returns the model unchanged. + + Args: + model_type: Model type string (e.g., "MaskRCNN") + model: The PyTorch model to adapt + + Returns: + Adapted model (or original model if no adapter needed) + """ + adapter_class = _ADAPTER_REGISTRY.get(model_type.lower()) + if adapter_class is not None: + return adapter_class(model) + return model + + +__all__ = ["ExportAdapter", "TorchvisionMaskRCNNExportAdapter", "get_adapter"] diff --git a/model_converter/src/model_converter/adapters/base.py b/model_converter/src/model_converter/adapters/base.py new file mode 100644 index 00000000..82f8005c --- /dev/null +++ b/model_converter/src/model_converter/adapters/base.py @@ -0,0 +1,16 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Base export adapter interface.""" + +import torch.nn as nn + + +class ExportAdapter(nn.Module): + """Base class for export adapters that reshape model outputs for Model API.""" + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model diff --git a/model_converter/src/model_converter/adapters/maskrcnn.py b/model_converter/src/model_converter/adapters/maskrcnn.py new file mode 100644 index 00000000..2a0e16c1 --- 
/dev/null
+++ b/model_converter/src/model_converter/adapters/maskrcnn.py
@@ -0,0 +1,31 @@
+#
+# Copyright (C) 2026 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+"""Mask R-CNN export adapter for TorchVision models."""
+
+from collections import OrderedDict
+
+import torch
+
+from model_converter.adapters.base import ExportAdapter
+
+
+class TorchvisionMaskRCNNExportAdapter(ExportAdapter):
+    """Adapt TorchVision Mask R-CNN to the Model API MaskRCNN output contract."""
+
+    def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Return boxes-with-scores, shifted labels, and raw masks for one image."""
+        image_list = [images[0]]
+        transformed_images, _ = self.model.transform(image_list, None)
+        features = self.model.backbone(transformed_images.tensors)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        proposals, _ = self.model.rpn(transformed_images, features, None)
+        predictions, _ = self.model.roi_heads(features, proposals, transformed_images.image_sizes, None)
+        prediction = predictions[0]
+        boxes = torch.cat((prediction["boxes"], prediction["scores"].unsqueeze(1)), dim=1)
+        labels = prediction["labels"] - 1
+        masks = prediction["masks"].squeeze(1)
+        return boxes, labels, masks
diff --git a/model_converter/src/model_converter/cli.py b/model_converter/src/model_converter/cli.py
new file mode 100644
index 00000000..41d18de1
--- /dev/null
+++ b/model_converter/src/model_converter/cli.py
@@ -0,0 +1,1177 @@
+#!/usr/bin/env -S uv run --script
+#
+# Copyright (C) 2026 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+"""
+PyTorch to OpenVINO Model Converter
+
+Usage:
+    uv run python -m model_converter config.json -o ./output_models
+
+"""
+
+import argparse
+import importlib
+import json
+import logging
+import shutil
+import sys
+from pathlib import Path
+from typing import Any
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+
+from 
model_converter.adapters import get_adapter +from model_converter.downloaders import URLDownloader + +_MODEL_API_METADATA_FIELDS = ( + "resize_type", + "pad_value", + "input_dtype", + "confidence_threshold", + "postprocess_semantic_masks", + "nms_execute", + "iou_threshold", + "agnostic_nms", + "nms_max_predictions", +) + + +class ModelConverter: + """Handles conversion of PyTorch models to OpenVINO format.""" + + def __init__( + self, + output_dir: Path, + cache_dir: Path, + verbose: bool = False, + dataset_path: Path | None = None, + ): + """ + Initialize the ModelConverter. + + Args: + output_dir: Directory to save converted models + cache_dir: Directory to cache downloaded weights + verbose: Enable verbose logging + dataset_path: Path to calibration dataset for quantization + """ + self.output_dir = Path(output_dir) + self.cache_dir = Path(cache_dir) + self.dataset_path = Path(dataset_path) if dataset_path else None + self.output_dir.mkdir(parents=True, exist_ok=True) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Setup logging + log_level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + self.logger = logging.getLogger(__name__) + + # Initialize downloaders + self._url_downloader = URLDownloader(cache_dir=self.cache_dir) + + def get_labels(self, label_set: str) -> str | None: + """ + Get label list for a given label set. 
+
+        Args:
+            label_set: Name of the label set (e.g., "IMAGENET1K_V1")
+
+        Returns:
+            Space-separated string of labels, or None if not found
+        """
+        if label_set == "IMAGENET1K_V1":
+            from torchvision.models._meta import _IMAGENET_CATEGORIES
+
+            categories = _IMAGENET_CATEGORIES
+            categories = [label.replace(" ", "_") for label in categories]
+            return " ".join(categories)
+
+        if label_set == "IMAGENET21K":
+            from timm.data import ImageNetInfo
+
+            info = ImageNetInfo("imagenet21k")
+            categories = info.label_descriptions()
+            categories = [desc.split(",")[0].strip().replace(" ", "_") for desc in categories]
+            return " ".join(categories)
+
+        if label_set == "COCO_V1":
+            from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
+
+            categories = MaskRCNN_ResNet50_FPN_Weights.COCO_V1.meta["categories"]
+            categories = [label.replace(" ", "_") for label in categories]
+            return " ".join(categories)
+
+        return None
+
+    def load_model_class(
+        self,
+        class_path: str,
+    ) -> type:
+        """
+        Dynamically load a model class from a Python path.
+
+        Args:
+            class_path: Full Python path to the class (e.g., 'torchvision.models.resnet.resnet50')
+
+        Returns:
+            The model class
+
+        Raises:
+            ValueError: If class_path is not of the form '<module>.<name>'
+            Exception: Any error raised while importing the module or resolving
+                the attribute is logged and re-raised unchanged
+        """
+        # rpartition always binds all three names, so the error handler below never
+        # references an unbound variable when class_path contains no dot.
+        module_path, _, class_name = class_path.rpartition(".")
+        try:
+            if not module_path or not class_name:
+                error_msg = f"Invalid class path '{class_path}': expected '<module>.<name>'"
+                raise ValueError(error_msg)
+            self.logger.debug(f"Importing module: {module_path}")
+            # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
+            module = importlib.import_module(
+                module_path,
+            )
+            model_class = getattr(module, class_name)
+            self.logger.debug(f"Loaded class: {class_name}")
+            return model_class
+        except Exception as e:
+            # Covers both import failures and a missing attribute, hence the full path.
+            self.logger.error(f"Failed to load model class '{class_path}': {e}")
+            raise
+
+    def load_checkpoint(
+        self,
+        checkpoint_path: Path,
+    ) -> dict[str, Any]:
+        """
+        Load PyTorch checkpoint file.
+ + Args: + checkpoint_path: Path to checkpoint file + + Returns: + Checkpoint dictionary + """ + try: + checkpoint = torch.load( # nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch + checkpoint_path, + map_location="cpu", + weights_only=True, + ) + self.logger.debug(f"Loaded checkpoint from: {checkpoint_path}") + return checkpoint + except Exception as e: + self.logger.error(f"Failed to load checkpoint: {e}") + raise + + def load_huggingface_model( + self, + repo_id: str, + revision: str, + model_library: str = "timm", + model_params: dict[str, Any] | None = None, + ) -> nn.Module: + """ + Load a model from Hugging Face Hub. + + Args: + repo_id: Hugging Face repository ID + revision: Immutable revision/commit SHA for the Hugging Face repository + model_library: Library to use ('timm', 'transformers', etc.) + model_params: Optional parameters for model loading + + Returns: + Loaded model instance + """ + try: + if model_library == "timm": + import timm + + repo_ref = f"hf-hub:{repo_id}@{revision}" + self.logger.info(f"Loading timm model: {repo_ref}") + model = timm.create_model( + repo_ref, + pretrained=True, + cache_dir=self.cache_dir, + **(model_params or {}), + ) + elif model_library == "transformers": + from transformers import AutoModel + + self.logger.info(f"Loading transformers model: {repo_id}@{revision}") + model = AutoModel.from_pretrained( + repo_id, + revision=revision, + cache_dir=self.cache_dir, + **(model_params or {}), + ) + else: + error_msg = f"Unsupported model library: {model_library}" + raise ValueError(error_msg) + + model.eval() + self.logger.info("✓ Hugging Face model loaded successfully") + return model + + except Exception as e: + self.logger.error(f"Failed to load Hugging Face model: {e}") + raise + + def create_model( + self, + model_class: type, + checkpoint: dict[str, Any], + model_params: dict[str, Any] | None = None, + ) -> nn.Module: + """ + Create and initialize model instance. 
+ + Args: + model_class: Model class to instantiate + checkpoint: Checkpoint containing model weights + model_params: Optional parameters for model initialization + + Returns: + Initialized model instance + """ + try: + # Handle torch.nn.Module case (checkpoint contains full model) + if model_class == torch.nn.Module: + if "model" in checkpoint: + model = checkpoint["model"] + elif "state_dict" in checkpoint: + # Cannot reconstruct architecture from state_dict alone + error_msg = ( + "Checkpoint contains only state_dict. Please specify the model class instead of torch.nn.Module" + ) + raise ValueError(error_msg) + else: + # Assume checkpoint is the model itself + model = checkpoint + + if not isinstance(model, nn.Module): + error_msg = "Checkpoint does not contain a valid model" + raise ValueError(error_msg) + else: + # Instantiate model class + model = model_class(**model_params) if model_params else model_class() + + # Load weights + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + if isinstance(checkpoint["model"], nn.Module): + return checkpoint["model"] + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + model.load_state_dict(state_dict, strict=False) + + model.eval() + self.logger.info("✓ Model created and loaded successfully") + return model + + except Exception as e: + self.logger.error(f"Failed to create model: {e}") + raise + + def copy_readme( + self, + model_config: dict[str, Any], + output_folder: Path, + variant: str = "fp16", + ) -> None: + """ + Copy README template to model folder and replace placeholders. 
+ + Args: + model_config: Model configuration used to fill template placeholders + output_folder: Folder where the model is saved + variant: Model variant ('fp16' or 'int8') + """ + try: + model_short_name = str(model_config.get("model_short_name", "")).strip() + model_library = str(model_config.get("model_library", "timm")).strip() + model_license = str(model_config.get("license", "")).strip() + model_license_link = str(model_config.get("license_link", "")).strip() + docs = str(model_config.get("docs", "")).strip() + + def template_placeholder(name: str) -> str: + return f"<<{name}>>" + + if not model_short_name: + error_msg = "Model config must define a non-empty model_short_name" + raise ValueError(error_msg) + + if not model_license_link: + error_msg = f"Model '{model_short_name}' must define a non-empty license_link" + raise ValueError(error_msg) + + if not model_license: + error_msg = f"Model '{model_short_name}' must define a non-empty license" + raise ValueError(error_msg) + + if not docs: + self.logger.warning( + f"Model '{model_short_name}' does not define 'docs' field. 
Placeholder will be empty.", + ) + + # Determine which README template to use based on model library + template_name = f"README-{model_library}-{variant}.md" + template_path = Path(__file__).parent.parent / "templates" / template_name + + if not template_path.exists(): + self.logger.warning(f"README template not found: {template_path}") + return + + # Read template + readme_content = template_path.read_text() + + placeholders = { + template_placeholder("license"): model_license, + template_placeholder("license_link"): model_license_link, + template_placeholder("model_name"): model_short_name, + template_placeholder("model_short_name"): model_short_name, + template_placeholder("variant"): variant, + template_placeholder("docs"): docs, + } + + for key, value in model_config.items(): + if value is None: + continue + if isinstance(value, (str, int, float, bool)): + placeholders[template_placeholder(key)] = str(value) + + for placeholder, value in placeholders.items(): + readme_content = readme_content.replace(placeholder, value) + + # Write to model folder + output_readme = output_folder / "README.md" + output_readme.write_text(readme_content) + self.logger.debug(f"Copied README to: {output_readme}") + + except (OSError, UnicodeError, ValueError) as e: + self.logger.warning(f"Failed to copy README: {e}") + + def _collect_dataset_entries(self, image_dir: Path) -> list[tuple[Path, int]]: + """Collect dataset image paths with their class labels.""" + image_entries: list[tuple[Path, int]] = [] + for class_dir in sorted(image_dir.iterdir()): + if class_dir.is_dir(): + class_label = int(class_dir.name) + for pattern in ["*.JPEG", "*.jpg", "*.png"]: + for img_path in class_dir.glob(pattern): + image_entries.append((img_path, class_label)) + return image_entries + + def _preprocess_calibration_image( + self, + img_path: Path, + width: int, + height: int, + mean: np.ndarray, + scale: np.ndarray, + reverse_input_channels: bool, + ) -> np.ndarray | None: + """Load and preprocess 
a single calibration image."""
+        img = cv2.imread(str(img_path))
+        if img is None:
+            return None
+
+        img = cv2.resize(img, (width, height))
+        img = img.astype(np.float32)
+
+        if reverse_input_channels:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        img = (img - mean) / scale
+        img = img.transpose(2, 0, 1)
+        return np.expand_dims(img, axis=0)
+
+    def create_calibration_dataset(
+        self,
+        input_shape: list[int],
+        mean_values: str | None = None,
+        scale_values: str | None = None,
+        reverse_input_channels: bool = True,
+        subset_size: int = 5000,
+        return_labels: bool = False,
+    ) -> tuple[list[np.ndarray], list[int]]:
+        """
+        Create calibration dataset from sample validation images.
+
+        Args:
+            input_shape: Target input shape [batch, channels, height, width]
+            mean_values: Space-separated mean values for normalization
+            scale_values: Space-separated scale values for normalization
+            reverse_input_channels: Whether to convert the BGR image loaded by OpenCV to RGB
+            subset_size: Maximum number of images to use for calibration
+            return_labels: Whether to return labels along with images
+
+        Returns:
+            Tuple of (images, labels). ``labels`` is empty unless ``return_labels``
+            is set; both are empty when the dataset is missing or has no images.
+        """
+        if not self.dataset_path or not self.dataset_path.exists():
+            self.logger.warning("Dataset path not provided or doesn't exist. Skipping quantization.")
+            # Always a 2-tuple so callers can unpack the result on every path.
+            return [], []
+
+        # Parse mean and scale values
+        mean = np.array([float(x) for x in mean_values.split()]) if mean_values else np.array([0, 0, 0])
+        scale = np.array([float(x) for x in scale_values.split()]) if scale_values else np.array([1, 1, 1])
+
+        _, _, height, width = input_shape
+
+        image_entries = self._collect_dataset_entries(self.dataset_path)
+        if not image_entries:
+            self.logger.error("No images found in dataset")
+            return [], []
+
+        selected_entries = image_entries[:subset_size]
+        total = len(selected_entries)
+        self.logger.info(f"Found {len(image_entries)} images in dataset")
+        self.logger.info(f"Using {total} images for calibration")
+
+        calibration_data: list[np.ndarray] = []
+        labels: list[int] = []
+        for index, (img_path, class_label) in enumerate(selected_entries, start=1):
+            try:
+                img = self._preprocess_calibration_image(
+                    img_path=img_path,
+                    width=width,
+                    height=height,
+                    mean=mean,
+                    scale=scale,
+                    reverse_input_channels=reverse_input_channels,
+                )
+            except (cv2.error, OSError, TypeError, ValueError) as e:
+                self.logger.warning(f"Failed to process {img_path}: {e}")
+                continue
+
+            # Unreadable images are skipped; the label is skipped with them so
+            # images and labels stay index-aligned.
+            if img is None:
+                continue
+
+            calibration_data.append(img)
+            if return_labels:
+                labels.append(class_label)
+
+            if index % 50 == 0:
+                self.logger.debug(f"Processed {index}/{total} images")
+
+        self.logger.info(f"✓ Created calibration dataset with {len(calibration_data)} images")
+        return calibration_data, labels
+
+    def validate_model(
+        self,
+        model_path: Path,
+        validation_data: list[np.ndarray],
+        labels: list[int],
+    ) -> float:
+        """
+        Validate OpenVINO model and compute top-1 accuracy.
+
+        Args:
+            model_path: Path to the OpenVINO model (.xml)
+            validation_data: List of validation images
+            labels: List of ground truth labels
+
+        Returns:
+            Top-1 accuracy (0.0 to 1.0); 0.0 when no labels are given or validation fails
+        """
+        if not labels:
+            # Avoids an uncaught ZeroDivisionError in the accuracy computation below.
+            self.logger.error("Failed to validate model: no validation labels provided")
+            return 0.0
+
+        try:
+            import openvino as ov
+
+            core = ov.Core()
+            model = core.read_model(model_path)
+            compiled_model = core.compile_model(model, device_name="CPU")
+            output_layer = compiled_model.outputs[0]
+
+            predictions: list[int] = []
+            for img in validation_data:
+                result = compiled_model(img)[output_layer]
+                pred_class = np.argmax(result, axis=1)[0]
+                predictions.append(pred_class)
+
+            # Compute accuracy
+            correct = sum(predicted == label for predicted, label in zip(predictions, labels))
+            return correct / len(labels)
+
+        except (ImportError, OSError, RuntimeError, TypeError, ValueError) as e:
+            self.logger.error(f"Failed to validate model: {e}")
+            return 0.0
+
+    def quantize_model(
+        self,
+        model_path: Path,
+        calibration_data: list[np.ndarray],
+        model_config: dict[str, Any],
+        preset: str = "accuracy",
+        validation_data: list[np.ndarray] | None = None,
+        validation_labels: list[int] | None = None,
+    ) -> Path:
+        """
+        Quantize OpenVINO model to INT8 using NNCF.
+ + Args: + model_path: Path to the FP32 OpenVINO model (.xml) + calibration_data: List of calibration images + model_config: Model configuration used for README rendering + preset: Quantization preset ('accuracy', 'performance', 'mixed') + validation_data: Optional validation images for accuracy measurement + validation_labels: Optional validation labels for accuracy measurement + + Returns: + Path to the quantized model + """ + if not calibration_data: + self.logger.warning("No calibration data provided. Skipping quantization.") + return model_path + + try: + import nncf + import openvino as ov + + self.logger.info(f"Quantizing model with {len(calibration_data)} calibration samples") + self.logger.info(f"Using preset: {preset}") + + # Load the model + core = ov.Core() + model = core.read_model(model_path) + + # Create calibration dataset generator + def calibration_dataset(): + for data in calibration_data: + yield data + + # Map preset string to NNCF enum + preset_map = { + "performance": nncf.QuantizationPreset.PERFORMANCE, + "mixed": nncf.QuantizationPreset.MIXED, + } + nncf_preset = preset_map.get(preset.lower(), nncf.QuantizationPreset.MIXED) + + # Quantize the model + quantized_model = nncf.quantize( + model, + calibration_dataset=nncf.Dataset(calibration_dataset()), + preset=nncf_preset, + subset_size=len(calibration_data), + ) + + # Extract model name from the FP32 model path + # The FP32 path is like: output_dir/model_name-fp16-ov/model_name_fp32.xml + model_name = model_path.stem # Gets model_name_fp32 from model_name_fp32.xml + # Remove _fp32 suffix if present + if model_name.endswith("_fp32"): + model_name = model_name[:-5] + + # Create output folder with -int8-ov suffix + output_folder = model_path.parent.parent / f"{model_name}-int8-ov" + output_folder.mkdir(parents=True, exist_ok=True) + + # Save quantized model with model name inside the folder + output_path = output_folder / f"{model_name}.xml" + ov.save_model(quantized_model, output_path, 
compress_to_fp16=True) + self.logger.info(f"✓ Quantized model saved: {output_path}") + + # Save model_info as config.json to track downloads + with (output_folder / "config.json").open("w") as f: + json.dump(quantized_model.get_rt_info(["model_info"]).value, f, indent=4) + + # Validate accuracy if validation data provided + if validation_data and validation_labels: + self.logger.info("Validating FP32 model accuracy...") + fp32_accuracy = self.validate_model(model_path, validation_data, validation_labels) + self.logger.info(f"FP32 Top-1 Accuracy: {fp32_accuracy * 100:.2f}%") + + self.logger.info("Validating INT8 model accuracy...") + int8_accuracy = self.validate_model(output_path, validation_data, validation_labels) + self.logger.info(f"INT8 Top-1 Accuracy: {int8_accuracy * 100:.2f}%") + + accuracy_drop = (fp32_accuracy - int8_accuracy) * 100 + self.logger.info(f"Accuracy Drop: {accuracy_drop:.2f}%") + + # Copy .gitattributes file + gitattributes_template = Path(__file__).parent.parent / "templates" / ".gitattributes" + if gitattributes_template.exists(): + shutil.copy2(gitattributes_template, output_folder / ".gitattributes") + self.logger.debug(f"Copied .gitattributes to: {output_folder}") + + # Copy README for INT8 model + self.copy_readme( + model_config, + output_folder, + variant="int8", + ) + + return output_path + + except ImportError: + self.logger.error("NNCF not installed. 
Install with: pip install nncf") + return model_path + except (OSError, RuntimeError, TypeError, ValueError) as e: + self.logger.error(f"Failed to quantize model: {e}") + import traceback + + self.logger.debug(traceback.format_exc()) + return model_path + + def export_to_openvino( + self, + model: nn.Module, + input_shape: list[int], + output_path: Path, + model_config: dict[str, Any], + input_names: list[str] | None = None, + output_names: list[str] | None = None, + metadata: dict[tuple[str, str], str] | None = None, + ) -> tuple[Path, Path]: + """ + Export PyTorch model to OpenVINO format. + + Args: + model: PyTorch model to export + input_shape: Input tensor shape [batch, channels, height, width] + output_path: Path to save the model (without extension) + model_config: Model configuration used for README rendering + input_names: Names for input tensors + output_names: Names for output tensors + metadata: Metadata to embed in the model + + Returns: + Tuple of (fp16_model_path, fp32_model_path) - FP16 for final use, FP32 for quantization + """ + import openvino as ov + + try: + model = self._prepare_model_for_export(model, model_config) + model.eval() + dummy_input = self._create_example_input(input_shape, model_config) + self.logger.info("Direct PyTorch to OpenVINO conversion") + ov_model = ov.convert_model(model, example_input=dummy_input) + self.logger.info("✓ PyTorch to OpenVINO conversion complete") + + # Reshape model to fixed input shape (remove dynamic dimensions) + first_input = ov_model.input(0) + input_name_for_reshape = next(iter(first_input.get_names())) if first_input.get_names() else 0 + + self.logger.debug(f"Setting fixed input shape: {input_shape}") + ov_model.reshape({input_name_for_reshape: input_shape}) + + # Post-process the model + ov_model = self._postprocess_openvino_model( + ov_model, + input_names=input_names, + output_names=output_names, + metadata=metadata, + ) + + # Create output folder with -fp16-ov suffix + model_name = 
output_path.name + output_folder = output_path.parent / f"{model_name}-fp16-ov" + output_folder.mkdir(parents=True, exist_ok=True) + + # Save FP32 model for quantization (temporary) + fp32_xml_path = output_folder / f"{model_name}_fp32.xml" + ov.save_model(ov_model, fp32_xml_path, compress_to_fp16=False) + self.logger.debug(f"Saved FP32 model for quantization: {fp32_xml_path}") + + # Save the FP16 model (final) + xml_path = output_folder / f"{model_name}.xml" + ov.save_model(ov_model, xml_path, compress_to_fp16=True) + self.logger.info(f"✓ Model saved: {xml_path}") + + # Save model_info as config.json to track downloads + with (output_folder / "config.json").open("w") as f: + json.dump(ov_model.get_rt_info(["model_info"]).value, f, indent=4) + + # Copy .gitattributes file + gitattributes_template = Path(__file__).parent.parent / "templates" / ".gitattributes" + if gitattributes_template.exists(): + shutil.copy2(gitattributes_template, output_folder / ".gitattributes") + self.logger.debug(f"Copied .gitattributes to: {output_folder}") + + # Copy README for FP16 model + self.copy_readme( + model_config, + output_folder, + variant="fp16", + ) + + return xml_path, fp32_xml_path + + except Exception as e: + self.logger.error(f"Failed to export model: {e}") + raise + + def _prepare_model_for_export(self, model: nn.Module, model_config: dict[str, Any]) -> nn.Module: + """Prepare model for OpenVINO conversion.""" + model_type = str(model_config.get("model_type", "")) + adapted = get_adapter(model_type, model) + if adapted is not model: + self.logger.info(f"Applied export adapter for model type: {model_type}") + return adapted + + def _create_example_input(self, input_shape: list[int], model_config: dict[str, Any]) -> torch.Tensor: + """Create example input suitable for the configured model type.""" + if str(model_config.get("model_type", "")).lower() == "maskrcnn": + return torch.rand(*input_shape) + return torch.randn(*input_shape) + + def _postprocess_openvino_model( + 
self, + model: Any, + input_names: list[str] | None = None, + output_names: list[str] | None = None, + metadata: dict[tuple[str, str], str] | None = None, + ) -> Any: + """ + Post-process OpenVINO model (set names, add metadata). + + Args: + model: OpenVINO model + input_names: Names for input tensors + output_names: Names for output tensors + metadata: Metadata to embed + + Returns: + Post-processed model + """ + # Set input names + if input_names: + for i, name in enumerate(input_names): + if i < len(model.inputs): + model.input(i).set_names({name}) + self.logger.debug(f"Set input {i} name to: {name}") + + # Set output names + if output_names: + for i, name in enumerate(output_names): + if i < len(model.outputs): + model.output(i).set_names({name}) + self.logger.debug(f"Set output {i} name to: {name}") + + # Add metadata + if metadata: + for key, value in metadata.items(): + model.set_rt_info(value, list(key)) + self.logger.debug(f"Set metadata {key}: {value}") + + return model + + def _load_model_from_config(self, config: dict[str, Any]) -> Any: + """Load a PyTorch model based on configuration (HuggingFace or traditional weights).""" + huggingface_repo = config.get("huggingface_repo") + if huggingface_repo: + huggingface_revision = config.get("huggingface_revision") + if not huggingface_revision: + error_msg = "Hugging Face models must define 'huggingface_revision' with an immutable commit SHA" + raise ValueError(error_msg) + + model_library = config.get("model_library", "timm") + model_params = config.get("model_params") + return self.load_huggingface_model( + repo_id=huggingface_repo, + revision=huggingface_revision, + model_library=model_library, + model_params=model_params, + ) + + # Traditional PyTorch model workflow + weights_url = config["weights_url"] + weights_path = self._url_downloader.download(url=weights_url) + + model_class_name = config.get("model_class_name", "torch.nn.Module") + model_class = self.load_model_class(model_class_name) + + 
checkpoint = self.load_checkpoint(weights_path) + + model_params = config.get("model_params") + return self.create_model(model_class, checkpoint, model_params) + + def _quantize_and_cleanup(self, config: dict[str, Any], fp32_model_path: Path, **kwargs: Any) -> None: + """Run INT8 quantization and clean up temporary FP32 model files.""" + model_type = kwargs["model_type"] + self.logger.info("Creating calibration dataset for INT8 quantization") + return_validation_labels = model_type == "Classification" and bool(config.get("labels")) + + if return_validation_labels: + self.logger.info("Creating validation dataset for accuracy measurement") + validation_data, validation_labels = self.create_calibration_dataset( + input_shape=kwargs["input_shape"], + mean_values=kwargs["mean_values"], + scale_values=kwargs["scale_values"], + reverse_input_channels=kwargs["reverse_input_channels"], + subset_size=300, + return_labels=return_validation_labels, + ) + + if validation_data: + self.quantize_model( + model_path=fp32_model_path, + calibration_data=validation_data, + model_config=config, + preset="mixed", + validation_data=validation_data if validation_labels else None, + validation_labels=validation_labels or None, + ) + + # Clean up temporary FP32 model after quantization + try: + if fp32_model_path.exists(): + fp32_model_path.unlink() + self.logger.debug(f"Removed temporary FP32 model: {fp32_model_path}") + fp32_bin_path = fp32_model_path.with_suffix(".bin") + if fp32_bin_path.exists(): + fp32_bin_path.unlink() + self.logger.debug(f"Removed temporary FP32 weights: {fp32_bin_path}") + except OSError as e: + self.logger.warning(f"Failed to remove temporary FP32 files: {e}") + + def process_model_config(self, config: dict[str, Any]) -> bool: + """ + Process a single model configuration. 
+ + Args: + config: Model configuration dictionary + + Returns: + True if successful, False otherwise + """ + model_short_name = config.get("model_short_name", "unknown") + model_license = config.get("license") + model_license_link = config.get("license_link") + + # Check if both FP16 and INT8 models already exist + fp16_model_path = self.output_dir / f"{model_short_name}-fp16-ov" / f"{model_short_name}.xml" + int8_model_path = self.output_dir / f"{model_short_name}-int8-ov" / f"{model_short_name}.xml" + + if fp16_model_path.exists() and int8_model_path.exists(): + self.logger.info(f"Skipping {model_short_name}: FP16 and INT8 models already exist") + return True + + try: + if not model_license: + error_msg = f"Model '{model_short_name}' must define 'license' in configuration" + raise ValueError(error_msg) + if not model_license_link: + error_msg = f"Model '{model_short_name}' must define 'license_link' in configuration" + raise ValueError(error_msg) + + self.logger.info("=" * 80) + self.logger.info(f"Processing model: {config.get('model_full_name', model_short_name)}") + self.logger.info(f"Short name: {model_short_name}") + if "description" in config: + self.logger.info(f"Description: {config['description']}") + self.logger.info("=" * 80) + + model = self._load_model_from_config(config) + + # Prepare export parameters + input_shape = config.get("input_shape", [1, 3, 224, 224]) + input_names = config.get("input_names", ["input"]) + output_names = config.get("output_names", ["result"]) + + # Prepare metadata from config (with defaults for normalization) + reverse_input_channels = config.get("reverse_input_channels", True) + mean_values = config.get("mean_values", "123.675 116.28 103.53") + scale_values = config.get("scale_values", "58.395 57.12 57.375") + model_type = config.get("model_type", "") + + metadata = { + ("model_info", "model_type"): model_type, + ("model_info", "model_short_name"): model_short_name, + ("model_info", "reverse_input_channels"): 
self._metadata_value(reverse_input_channels), + ("model_info", "mean_values"): mean_values, + ("model_info", "scale_values"): scale_values, + } + + for metadata_field in _MODEL_API_METADATA_FIELDS: + if metadata_field in config and config[metadata_field] is not None: + metadata["model_info", metadata_field] = self._metadata_value(config[metadata_field]) + + # Add labels if specified in config + labels_config = config.get("labels") + if labels_config: + labels = self.get_labels(labels_config) + if labels: + metadata["model_info", "labels"] = labels + self.logger.info(f"Added {labels_config} labels to metadata") + else: + self.logger.warning(f"Could not load labels for: {labels_config}") + + output_path = self.output_dir / model_short_name + fp16_model_path, fp32_model_path = self.export_to_openvino( + model=model, + input_shape=input_shape, + output_path=output_path, + model_config=config, + input_names=input_names, + output_names=output_names, + metadata=metadata, + ) + + # Quantize the model if dataset is available + if self.dataset_path and self.dataset_path.exists(): + self._quantize_and_cleanup( + config, + fp32_model_path, + model_type=model_type, + input_shape=input_shape, + mean_values=mean_values, + scale_values=scale_values, + reverse_input_channels=reverse_input_channels, + ) + + self.logger.info(f"✓ Successfully converted {model_short_name}") + return True + + except (ValueError, RuntimeError, ImportError, FileNotFoundError) as e: + self.logger.error(f"✗ Failed to process model {model_short_name}: {e}") + import traceback + + self.logger.debug(traceback.format_exc()) + return False + + @staticmethod + def _metadata_value(value: Any) -> str: + """Convert config values to Model API rt_info string values.""" + if isinstance(value, (list, tuple)): + return " ".join(str(item) for item in value) + return str(value) + + def process_config_file( + self, + config_path: Path, + model_filter: str | None = None, + ) -> tuple[int, int]: + """ + Process models from a 
configuration file. + + Args: + config_path: Path to JSON configuration file + model_filter: Optional model short name to process (process only this model) + + Returns: + Tuple of (successful_count, failed_count) + """ + try: + with Path(config_path).open() as f: + config = json.load(f) + except Exception as e: + self.logger.error(f"Failed to load configuration file: {e}") + raise + + models = config.get("models", []) + + if not models: + self.logger.warning("No models found in configuration file") + return 0, 0 + + self.logger.info(f"Configuration validated: {len(models)} models found") + + # Filter models if requested + if model_filter: + models = [m for m in models if m.get("model_short_name") == model_filter] + if not models: + self.logger.error(f"Model '{model_filter}' not found in configuration") + return 0, 0 + self.logger.info(f"Processing only model: {model_filter}") + + successful = 0 + failed = 0 + + for model_config in models: + if self.process_model_config(model_config): + successful += 1 + else: + failed += 1 + + return successful, failed + + +def list_models(config_path: Path) -> None: + """List all models in a configuration file.""" + try: + with config_path.open() as f: + config = json.load(f) + except (FileNotFoundError, json.JSONDecodeError, PermissionError) as e: + print(f"Error loading configuration: {e}", file=sys.stderr) + return + + models = config.get("models", []) + + if not models: + print("No models found in configuration") + return + + print(f"\nFound {len(models)} models:\n") + print(f"{'Short Name':<30} {'Full Name':<40} {'Type':<20}") + print("-" * 90) + + for model in models: + short_name = model.get("model_short_name", "N/A") + full_name = model.get("model_full_name", "N/A") + model_type = model.get("model_type", "N/A") + print(f"{short_name:<30} {full_name:<40} {model_type:<20}") + + print() + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Convert PyTorch models to OpenVINO format", + 
formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Convert all models in config + uv run python model_converter.py config.json -o ./models + + # Convert a specific model + uv run python model_converter.py config.json -o ./models --model resnet50 + + # List all models in config + uv run python model_converter.py config.json --list + + # Enable verbose logging + uv run python model_converter.py config.json -o ./models -v + """, + ) + + parser.add_argument( + "config", + type=Path, + help="Path to JSON configuration file", + ) + + parser.add_argument( + "-o", + "--output", + type=Path, + default=Path("./converted_models"), + help="Output directory for converted models (default: ./converted_models)", + ) + + parser.add_argument( + "-c", + "--cache", + type=Path, + default=Path.home() / ".cache" / "torch" / "hub" / "checkpoints", + help="Cache directory for downloaded weights (default: ~/.cache/torch/hub/checkpoints)", + ) + + parser.add_argument( + "-d", + "--dataset", + type=Path, + default=Path.home() / "model_api" / "validation_dataset", + help=("Path to calibration dataset for INT8 quantization (default: ~/model_api/validation_dataset)"), + ) + + parser.add_argument( + "--model", + type=str, + help="Process only the specified model (by model_short_name)", + ) + + parser.add_argument( + "--list", + action="store_true", + help="List all models in the configuration file and exit", + ) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Check if config file exists + if not args.config.exists(): + print(f"Error: Configuration file not found: {args.config}", file=sys.stderr) + return 1 + + # List models and exit + if args.list: + list_models(args.config) + return 0 + + # Setup logging + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - 
%(message)s", + ) + logger = logging.getLogger(__name__) + + logger.info(f"Loading configuration from: {args.config}") + + try: + # Create converter + converter = ModelConverter( + output_dir=args.output, + cache_dir=args.cache, + verbose=args.verbose, + dataset_path=args.dataset, + ) + + logger.info(f"Output directory: {args.output}") + logger.info(f"Cache directory: {args.cache}") + if args.dataset: + logger.info(f"Calibration dataset: {args.dataset}") + + # Process models + successful, failed = converter.process_config_file( + config_path=args.config, + model_filter=args.model, + ) + + # Print summary + logger.info("=" * 80) + logger.info("Conversion Summary:") + logger.info(f" Successful: {successful}") + logger.info(f" Failed: {failed}") + logger.info(f" Total: {successful + failed}") + logger.info("=" * 80) + + return 0 if failed == 0 else 1 + except (ValueError, RuntimeError, ImportError, FileNotFoundError) as e: + logger.error(f"Failed to process model: {e}") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/model_converter/src/model_converter/downloaders/__init__.py b/model_converter/src/model_converter/downloaders/__init__.py new file mode 100644 index 00000000..6fbc9892 --- /dev/null +++ b/model_converter/src/model_converter/downloaders/__init__.py @@ -0,0 +1,11 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Model weight and file downloaders.""" + +from model_converter.downloaders.huggingface import HuggingFaceDownloader +from model_converter.downloaders.url import URLDownloader + +__all__ = ["HuggingFaceDownloader", "URLDownloader"] diff --git a/model_converter/src/model_converter/downloaders/base.py b/model_converter/src/model_converter/downloaders/base.py new file mode 100644 index 00000000..588ee0c6 --- /dev/null +++ b/model_converter/src/model_converter/downloaders/base.py @@ -0,0 +1,25 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + 
+"""Base downloader interface.""" + +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class BaseDownloader: + """Base class for model weight downloaders.""" + + def __init__(self, cache_dir: Path): + """ + Initialize downloader. + + Args: + cache_dir: Directory to cache downloaded files + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) diff --git a/model_converter/src/model_converter/downloaders/huggingface.py b/model_converter/src/model_converter/downloaders/huggingface.py new file mode 100644 index 00000000..db18e2df --- /dev/null +++ b/model_converter/src/model_converter/downloaders/huggingface.py @@ -0,0 +1,59 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Hugging Face Hub downloader.""" + +import logging +from pathlib import Path + +from huggingface_hub import hf_hub_download, snapshot_download + +from model_converter.downloaders.base import BaseDownloader + +logger = logging.getLogger(__name__) + + +class HuggingFaceDownloader(BaseDownloader): + """Download models and files from Hugging Face Hub.""" + + def download( + self, + repo_id: str, + revision: str, + filename: str | None = None, + ) -> Path: + """ + Download model from Hugging Face Hub with caching. 
+ + Args: + repo_id: Hugging Face repository ID (e.g., 'timm/mobilenetv2_050.lamb_in1k') + revision: Immutable revision/commit SHA to download from + filename: Optional specific file to download (if None, downloads the whole repo) + + Returns: + Path to the downloaded model file or directory + """ + logger.info(f"Downloading from Hugging Face Hub: {repo_id}") + + try: + if filename: + cached_file = hf_hub_download( # nosec B615 + repo_id=repo_id, + revision=revision, + filename=filename, + cache_dir=self.cache_dir, + ) + logger.info(f"✓ Downloaded file: {cached_file}") + return Path(cached_file) + cached_dir = snapshot_download( + repo_id=repo_id, + revision=revision, + cache_dir=self.cache_dir, + ) + logger.info(f"✓ Downloaded repository to: {cached_dir}") + return Path(cached_dir) + except Exception as e: + logger.error(f"Failed to download from Hugging Face: {e}") + raise diff --git a/model_converter/src/model_converter/downloaders/url.py b/model_converter/src/model_converter/downloaders/url.py new file mode 100644 index 00000000..90f013f3 --- /dev/null +++ b/model_converter/src/model_converter/downloaders/url.py @@ -0,0 +1,56 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Direct URL downloader.""" + +import logging +import urllib.request +from pathlib import Path + +from model_converter.downloaders.base import BaseDownloader + +logger = logging.getLogger(__name__) + + +class URLDownloader(BaseDownloader): + """Download model weights from direct URLs.""" + + def download( + self, + url: str, + filename: str | None = None, + ) -> Path: + """ + Download model weights from URL with caching. 
+ + Args: + url: URL to download weights from + filename: Optional filename to save as (default: extract from URL) + + Returns: + Path to the downloaded/cached weights file + """ + if filename is None: + filename = url.split("/")[-1] + + cached_file = self.cache_dir / filename + + if cached_file.exists(): + logger.info(f"Using cached weights: {cached_file}") + return cached_file + + logger.info(f"Downloading weights from: {url}") + logger.info(f"Saving to: {cached_file}") + + try: + urllib.request.urlretrieve( # noqa: S310 # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected + url, + cached_file, + ) + logger.info("✓ Download complete") + return cached_file + except Exception as e: + logger.error(f"Failed to download weights: {e}") + raise diff --git a/model_converter/src/model_converter/model_converter.py b/model_converter/src/model_converter/model_converter.py index 39d97c1c..0af62036 100644 --- a/model_converter/src/model_converter/model_converter.py +++ b/model_converter/src/model_converter/model_converter.py @@ -4,1196 +4,16 @@ # """ -PyTorch to OpenVINO Model Converter +Models to OpenVINO Model Converter Usage: - uv run model-converter examples/config.json -o ./output_models + uv run python model_converter.py config.json -o ./output_models """ -import argparse -import importlib -import json -import logging -import shutil import sys -import urllib.request -from pathlib import Path -from typing import Any - -import cv2 -import numpy as np -import torch -import torch.nn as nn -from huggingface_hub import hf_hub_download, snapshot_download - - -class ModelConverter: - """Handles conversion of PyTorch models to OpenVINO format.""" - - def __init__( - self, - output_dir: Path, - cache_dir: Path, - verbose: bool = False, - dataset_path: Path | None = None, - ): - """ - Initialize the ModelConverter. 
- - Args: - output_dir: Directory to save converted models - cache_dir: Directory to cache downloaded weights - verbose: Enable verbose logging - dataset_path: Path to calibration dataset for quantization - """ - self.output_dir = Path(output_dir) - self.cache_dir = Path(cache_dir) - self.dataset_path = Path(dataset_path) if dataset_path else None - self.output_dir.mkdir(parents=True, exist_ok=True) - self.cache_dir.mkdir(parents=True, exist_ok=True) - - # Setup logging - log_level = logging.DEBUG if verbose else logging.INFO - logging.basicConfig( - level=log_level, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - self.logger = logging.getLogger(__name__) - - def get_labels(self, label_set: str) -> str | None: - """ - Get label list for a given label set. - - Args: - label_set: Name of the label set (e.g., "IMAGENET1K_V1") - - Returns: - Space-separated string of labels, or None if not found - """ - if label_set == "IMAGENET1K_V1": - from torchvision.models._meta import _IMAGENET_CATEGORIES - - categories = _IMAGENET_CATEGORIES - categories = [label.replace(" ", "_") for label in categories] - return " ".join(categories) - - if label_set == "IMAGENET21K": - from timm.data import ImageNetInfo - - info = ImageNetInfo("imagenet21k") - categories = info.label_descriptions() - categories = [desc.split(",")[0].strip().replace(" ", "_") for desc in categories] - return " ".join(categories) - - return None - - def download_from_huggingface( - self, - repo_id: str, - revision: str, - filename: str | None = None, - ) -> Path: - """ - Download model from Hugging Face Hub with caching. 
- - Args: - repo_id: Hugging Face repository ID (e.g., 'timm/mobilenetv2_050.lamb_in1k') - revision: Immutable revision/commit SHA to download from - filename: Optional specific file to download (if None, downloads the whole repo) - - Returns: - Path to the downloaded model file or directory - """ - self.logger.info(f"Downloading from Hugging Face Hub: {repo_id}") - - try: - if filename: - # Download a specific file - cached_file = hf_hub_download( # nosec B615 - repo_id=repo_id, - revision=revision, - filename=filename, - cache_dir=self.cache_dir, - ) - self.logger.info(f"✓ Downloaded file: {cached_file}") - return Path(cached_file) - # Download the entire repository - cached_dir = snapshot_download( - repo_id=repo_id, - revision=revision, - cache_dir=self.cache_dir, - ) - self.logger.info(f"✓ Downloaded repository to: {cached_dir}") - return Path(cached_dir) - except Exception as e: - self.logger.error(f"Failed to download from Hugging Face: {e}") - raise - - def download_weights( - self, - url: str, - filename: str | None = None, - ) -> Path: - """ - Download model weights from URL with caching. 
- - Args: - url: URL to download weights from - filename: Optional filename to save as (default: extract from URL) - - Returns: - Path to the downloaded/cached weights file - """ - if filename is None: - filename = url.split("/")[-1] - - cached_file = self.cache_dir / filename - - if cached_file.exists(): - self.logger.info(f"Using cached weights: {cached_file}") - return cached_file - - self.logger.info(f"Downloading weights from: {url}") - self.logger.info(f"Saving to: {cached_file}") - - try: - urllib.request.urlretrieve( # noqa: S310 # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected - url, - cached_file, - ) - self.logger.info("✓ Download complete") - return cached_file - except Exception as e: - self.logger.error(f"Failed to download weights: {e}") - raise - - def load_model_class( - self, - class_path: str, - ) -> type: - """ - Dynamically load a model class from a Python path. - - Args: - class_path: Full Python path to the class (e.g., 'torchvision.models.resnet.resnet50') - - Returns: - The model class - """ - try: - module_path, class_name = class_path.rsplit(".", 1) - self.logger.debug(f"Importing module: {module_path}") - # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import - module = importlib.import_module( - module_path, - ) - model_class = getattr(module, class_name) - self.logger.debug(f"Loaded class: {class_name}") - return model_class - except Exception as e: - self.logger.error(f"Failed to import module {module_path}: {e}") - raise - - def load_checkpoint( - self, - checkpoint_path: Path, - ) -> dict[str, Any]: - """ - Load PyTorch checkpoint file. 
- - Args: - checkpoint_path: Path to checkpoint file - - Returns: - Checkpoint dictionary - """ - try: - checkpoint = torch.load( # nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch - checkpoint_path, - map_location="cpu", - weights_only=True, - ) - self.logger.debug(f"Loaded checkpoint from: {checkpoint_path}") - return checkpoint - except Exception as e: - self.logger.error(f"Failed to load checkpoint: {e}") - raise - - def load_huggingface_model( - self, - repo_id: str, - revision: str, - model_library: str = "timm", - model_params: dict[str, Any] | None = None, - ) -> nn.Module: - """ - Load a model from Hugging Face Hub. - - Args: - repo_id: Hugging Face repository ID - revision: Immutable revision/commit SHA for the Hugging Face repository - model_library: Library to use ('timm', 'transformers', etc.) - model_params: Optional parameters for model loading - - Returns: - Loaded model instance - """ - try: - if model_library == "timm": - import timm - - repo_ref = f"hf-hub:{repo_id}@{revision}" - self.logger.info(f"Loading timm model: {repo_ref}") - model = timm.create_model( - repo_ref, - pretrained=True, - cache_dir=self.cache_dir, - **(model_params or {}), - ) - elif model_library == "transformers": - from transformers import AutoModel - - self.logger.info(f"Loading transformers model: {repo_id}@{revision}") - model = AutoModel.from_pretrained( - repo_id, - revision=revision, - cache_dir=self.cache_dir, - **(model_params or {}), - ) - else: - error_msg = f"Unsupported model library: {model_library}" - raise ValueError(error_msg) - - model.eval() - self.logger.info("✓ Hugging Face model loaded successfully") - return model - - except Exception as e: - self.logger.error(f"Failed to load Hugging Face model: {e}") - raise - - def create_model( - self, - model_class: type, - checkpoint: dict[str, Any], - model_params: dict[str, Any] | None = None, - ) -> nn.Module: - """ - Create and initialize model instance. 
- - Args: - model_class: Model class to instantiate - checkpoint: Checkpoint containing model weights - model_params: Optional parameters for model initialization - - Returns: - Initialized model instance - """ - try: - # Handle torch.nn.Module case (checkpoint contains full model) - if model_class == torch.nn.Module: - if "model" in checkpoint: - model = checkpoint["model"] - elif "state_dict" in checkpoint: - # Cannot reconstruct architecture from state_dict alone - error_msg = ( - "Checkpoint contains only state_dict. Please specify the model class instead of torch.nn.Module" - ) - raise ValueError(error_msg) - else: - # Assume checkpoint is the model itself - model = checkpoint - - if not isinstance(model, nn.Module): - error_msg = "Checkpoint does not contain a valid model" - raise ValueError(error_msg) - else: - # Instantiate model class - model = model_class(**model_params) if model_params else model_class() - - # Load weights - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - elif "model" in checkpoint: - if isinstance(checkpoint["model"], nn.Module): - return checkpoint["model"] - state_dict = checkpoint["model"] - else: - state_dict = checkpoint - - model.load_state_dict(state_dict, strict=False) - - model.eval() - self.logger.info("✓ Model created and loaded successfully") - return model - - except Exception as e: - self.logger.error(f"Failed to create model: {e}") - raise - - def copy_readme( - self, - model_config: dict[str, Any], - output_folder: Path, - variant: str = "fp16", - ) -> None: - """ - Copy README template to model folder and replace placeholders. 
- - Args: - model_config: Model configuration used to fill template placeholders - output_folder: Folder where the model is saved - variant: Model variant ('fp16' or 'int8') - """ - try: - model_short_name = str(model_config.get("model_short_name", "")).strip() - model_library = str(model_config.get("model_library", "timm")).strip() - model_license = str(model_config.get("license", "")).strip() - model_license_link = str(model_config.get("license_link", "")).strip() - docs = str(model_config.get("docs", "")).strip() - - def template_placeholder(name: str) -> str: - return f"<<{name}>>" - - if not model_short_name: - error_msg = "Model config must define a non-empty model_short_name" - raise ValueError(error_msg) - - if not model_license_link: - error_msg = f"Model '{model_short_name}' must define a non-empty license_link" - raise ValueError(error_msg) - - if not model_license: - error_msg = f"Model '{model_short_name}' must define a non-empty license" - raise ValueError(error_msg) - - if not docs: - self.logger.warning( - f"Model '{model_short_name}' does not define 'docs' field. 
Placeholder will be empty.", - ) - - # Determine which README template to use based on model library - template_name = f"README-{model_library}-{variant}.md" - template_path = Path(__file__).parent / "templates" / template_name - - if not template_path.exists(): - self.logger.warning(f"README template not found: {template_path}") - return - - # Read template - readme_content = template_path.read_text() - - placeholders = { - template_placeholder("license"): model_license, - template_placeholder("license_link"): model_license_link, - template_placeholder("model_name"): model_short_name, - template_placeholder("model_short_name"): model_short_name, - template_placeholder("variant"): variant, - template_placeholder("docs"): docs, - } - - for key, value in model_config.items(): - if value is None: - continue - if isinstance(value, (str, int, float, bool)): - placeholders[template_placeholder(key)] = str(value) - - for placeholder, value in placeholders.items(): - readme_content = readme_content.replace(placeholder, value) - - # Write to model folder - output_readme = output_folder / "README.md" - output_readme.write_text(readme_content) - self.logger.debug(f"Copied README to: {output_readme}") - - except (OSError, UnicodeError, ValueError) as e: - self.logger.warning(f"Failed to copy README: {e}") - - def _collect_dataset_entries(self, image_dir: Path) -> list[tuple[Path, int]]: - """Collect dataset image paths with their class labels.""" - image_entries: list[tuple[Path, int]] = [] - for class_dir in sorted(image_dir.iterdir()): - if class_dir.is_dir(): - class_label = int(class_dir.name) - for pattern in ["*.JPEG", "*.jpg", "*.png"]: - for img_path in class_dir.glob(pattern): - image_entries.append((img_path, class_label)) - return image_entries - - def _preprocess_calibration_image( - self, - img_path: Path, - width: int, - height: int, - mean: np.ndarray, - scale: np.ndarray, - reverse_input_channels: bool, - ) -> np.ndarray | None: - """Load and preprocess a 
single calibration image.""" - img = cv2.imread(str(img_path)) - if img is None: - return None - - img = cv2.resize(img, (width, height)) - img = img.astype(np.float32) - - if reverse_input_channels: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - - img = (img - mean) / scale - img = img.transpose(2, 0, 1) - return np.expand_dims(img, axis=0) - - def create_calibration_dataset( - self, - input_shape: list[int], - mean_values: str | None = None, - scale_values: str | None = None, - reverse_input_channels: bool = True, - subset_size: int = 5000, - return_labels: bool = False, - ) -> tuple[list[np.ndarray], list[int]] | list[np.ndarray]: - """ - Create calibration dataset from sample validation images. - - Args: - input_shape: Target input shape [batch, channels, height, width] - mean_values: Space-separated mean values for normalization - scale_values: Space-separated scale values for normalization - reverse_input_channels: Whether to reverse RGB to BGR - subset_size: Number of images to use for calibration - return_labels: Whether to return labels along with images - - Returns: - List of preprocessed image arrays, or tuple of (images, labels) - """ - if not self.dataset_path or not self.dataset_path.exists(): - self.logger.warning("Dataset path not provided or doesn't exist. 
Skipping quantization.") - return [] - - # Parse mean and scale values - mean = np.array([float(x) for x in mean_values.split()]) if mean_values else np.array([0, 0, 0]) - scale = np.array([float(x) for x in scale_values.split()]) if scale_values else np.array([1, 1, 1]) - - _, _, height, width = input_shape - calibration_data: list[np.ndarray] = [] - - # Find all images in the dataset - image_dir = self.dataset_path - if not image_dir.exists(): - self.logger.error(f"Image directory not found: {image_dir}") - return ([], []) - - image_entries = self._collect_dataset_entries(image_dir) - if not image_entries: - self.logger.error("No images found in dataset") - return ([], []) - - self.logger.info(f"Found {len(image_entries)} images in dataset") - self.logger.info(f"Using {min(subset_size, len(image_entries))} images for calibration") - - if return_labels: - labels: list[int] = [] - for i, (img_path, class_label) in enumerate(image_entries[:subset_size]): - try: - img = self._preprocess_calibration_image( - img_path=img_path, - width=width, - height=height, - mean=mean, - scale=scale, - reverse_input_channels=reverse_input_channels, - ) - if img is None: - continue - - calibration_data.append(img) - labels.append(class_label) - - if (i + 1) % 50 == 0: - self.logger.debug(f"Processed {i + 1}/{subset_size} images") - - except (cv2.error, OSError, TypeError, ValueError) as e: - self.logger.warning(f"Failed to process {img_path}: {e}") - continue - - self.logger.info(f"✓ Created calibration dataset with {len(calibration_data)} images") - return calibration_data, labels - - for i, (img_path, _) in enumerate(image_entries[:subset_size]): - try: - img = self._preprocess_calibration_image( - img_path=img_path, - width=width, - height=height, - mean=mean, - scale=scale, - reverse_input_channels=reverse_input_channels, - ) - if img is None: - continue - - calibration_data.append(img) - - if (i + 1) % 50 == 0: - self.logger.debug(f"Processed {i + 1}/{subset_size} images") - - 
except (cv2.error, OSError, TypeError, ValueError) as e: - self.logger.warning(f"Failed to process {img_path}: {e}") - continue - - self.logger.info(f"✓ Created calibration dataset with {len(calibration_data)} images") - return calibration_data, [] - - def validate_model( - self, - model_path: Path, - validation_data: list[np.ndarray], - labels: list[int], - ) -> float: - """ - Validate OpenVINO model and compute top-1 accuracy. - - Args: - model_path: Path to the OpenVINO model (.xml) - validation_data: List of validation images - labels: List of ground truth labels - - Returns: - Top-1 accuracy (0.0 to 1.0) - """ - try: - import openvino as ov - - core = ov.Core() - model = core.read_model(model_path) - compiled_model = core.compile_model(model, device_name="CPU") - output_layer = compiled_model.outputs[0] - - predictions: list[int] = [] - for img in validation_data: - result = compiled_model(img)[output_layer] - pred_class = np.argmax(result, axis=1)[0] - predictions.append(pred_class) - - # Compute accuracy - correct = sum(predicted == label for predicted, label in zip(predictions, labels)) - return correct / len(labels) - - except (ImportError, OSError, RuntimeError, TypeError, ValueError) as e: - self.logger.error(f"Failed to validate model: {e}") - return 0.0 - - def quantize_model( - self, - model_path: Path, - calibration_data: list[np.ndarray], - model_config: dict[str, Any], - preset: str = "accuracy", - validation_data: list[np.ndarray] | None = None, - validation_labels: list[int] | None = None, - ) -> Path: - """ - Quantize OpenVINO model to INT8 using NNCF. 
- - Args: - model_path: Path to the FP32 OpenVINO model (.xml) - calibration_data: List of calibration images - model_config: Model configuration used for README rendering - preset: Quantization preset ('accuracy', 'performance', 'mixed') - validation_data: Optional validation images for accuracy measurement - validation_labels: Optional validation labels for accuracy measurement - - Returns: - Path to the quantized model - """ - if not calibration_data: - self.logger.warning("No calibration data provided. Skipping quantization.") - return model_path - - try: - import nncf - import openvino as ov - - self.logger.info(f"Quantizing model with {len(calibration_data)} calibration samples") - self.logger.info(f"Using preset: {preset}") - - # Load the model - core = ov.Core() - model = core.read_model(model_path) - - # Create calibration dataset generator - def calibration_dataset(): - for data in calibration_data: - yield data - - # Map preset string to NNCF enum - preset_map = { - "performance": nncf.QuantizationPreset.PERFORMANCE, - "mixed": nncf.QuantizationPreset.MIXED, - } - nncf_preset = preset_map.get(preset.lower(), nncf.QuantizationPreset.MIXED) - - # Quantize the model - quantized_model = nncf.quantize( - model, - calibration_dataset=nncf.Dataset(calibration_dataset()), - preset=nncf_preset, - subset_size=len(calibration_data), - ) - - # Extract model name from the FP32 model path - # The FP32 path is like: output_dir/model_name-fp16-ov/model_name_fp32.xml - model_name = model_path.stem # Gets model_name_fp32 from model_name_fp32.xml - # Remove _fp32 suffix if present - if model_name.endswith("_fp32"): - model_name = model_name[:-5] - - # Create output folder with -int8-ov suffix - output_folder = model_path.parent.parent / f"{model_name}-int8-ov" - output_folder.mkdir(parents=True, exist_ok=True) - - # Save quantized model with model name inside the folder - output_path = output_folder / f"{model_name}.xml" - ov.save_model(quantized_model, output_path, 
compress_to_fp16=True) - self.logger.info(f"✓ Quantized model saved: {output_path}") - - # Save model_info as config.json to track downloads - with (output_folder / "config.json").open("w") as f: - json.dump(quantized_model.get_rt_info(["model_info"]).value, f, indent=4) - - # Validate accuracy if validation data provided - if validation_data and validation_labels: - self.logger.info("Validating FP32 model accuracy...") - fp32_accuracy = self.validate_model(model_path, validation_data, validation_labels) - self.logger.info(f"FP32 Top-1 Accuracy: {fp32_accuracy * 100:.2f}%") - - self.logger.info("Validating INT8 model accuracy...") - int8_accuracy = self.validate_model(output_path, validation_data, validation_labels) - self.logger.info(f"INT8 Top-1 Accuracy: {int8_accuracy * 100:.2f}%") - - accuracy_drop = (fp32_accuracy - int8_accuracy) * 100 - self.logger.info(f"Accuracy Drop: {accuracy_drop:.2f}%") - - # Copy .gitattributes file - gitattributes_template = Path(__file__).parent / "templates" / ".gitattributes" - if gitattributes_template.exists(): - shutil.copy2(gitattributes_template, output_folder / ".gitattributes") - self.logger.debug(f"Copied .gitattributes to: {output_folder}") - - # Copy README for INT8 model - self.copy_readme( - model_config, - output_folder, - variant="int8", - ) - - return output_path - - except ImportError: - self.logger.error("NNCF not installed. Install with: pip install nncf") - return model_path - except (OSError, RuntimeError, TypeError, ValueError) as e: - self.logger.error(f"Failed to quantize model: {e}") - import traceback - - self.logger.debug(traceback.format_exc()) - return model_path - - def export_to_openvino( - self, - model: nn.Module, - input_shape: list[int], - output_path: Path, - model_config: dict[str, Any], - input_names: list[str] | None = None, - output_names: list[str] | None = None, - metadata: dict[tuple[str, str], str] | None = None, - ) -> tuple[Path, Path]: - """ - Export PyTorch model to OpenVINO format. 
- - Args: - model: PyTorch model to export - input_shape: Input tensor shape [batch, channels, height, width] - output_path: Path to save the model (without extension) - model_config: Model configuration used for README rendering - input_names: Names for input tensors - output_names: Names for output tensors - metadata: Metadata to embed in the model - - Returns: - Tuple of (fp16_model_path, fp32_model_path) - FP16 for final use, FP32 for quantization - """ - import openvino as ov - - try: - model.eval() - dummy_input = torch.randn(*input_shape) - self.logger.info("Direct PyTorch to OpenVINO conversion") - ov_model = ov.convert_model(model, example_input=dummy_input) - self.logger.info("✓ PyTorch to OpenVINO conversion complete") - - # Reshape model to fixed input shape (remove dynamic dimensions) - first_input = ov_model.input(0) - input_name_for_reshape = next(iter(first_input.get_names())) if first_input.get_names() else 0 - - self.logger.debug(f"Setting fixed input shape: {input_shape}") - ov_model.reshape({input_name_for_reshape: input_shape}) - - # Post-process the model - ov_model = self._postprocess_openvino_model( - ov_model, - input_names=input_names, - output_names=output_names, - metadata=metadata, - ) - - # Create output folder with -fp16-ov suffix - model_name = output_path.name - output_folder = output_path.parent / f"{model_name}-fp16-ov" - output_folder.mkdir(parents=True, exist_ok=True) - - # Save FP32 model for quantization (temporary) - fp32_xml_path = output_folder / f"{model_name}_fp32.xml" - ov.save_model(ov_model, fp32_xml_path, compress_to_fp16=False) - self.logger.debug(f"Saved FP32 model for quantization: {fp32_xml_path}") - - # Save the FP16 model (final) - xml_path = output_folder / f"{model_name}.xml" - ov.save_model(ov_model, xml_path, compress_to_fp16=True) - self.logger.info(f"✓ Model saved: {xml_path}") - - # Save model_info as config.json to track downloads - with (output_folder / "config.json").open("w") as f: - 
json.dump(ov_model.get_rt_info(["model_info"]).value, f, indent=4) - - # Copy .gitattributes file - gitattributes_template = Path(__file__).parent / "templates" / ".gitattributes" - if gitattributes_template.exists(): - shutil.copy2(gitattributes_template, output_folder / ".gitattributes") - self.logger.debug(f"Copied .gitattributes to: {output_folder}") - - # Copy README for FP16 model - self.copy_readme( - model_config, - output_folder, - variant="fp16", - ) - - return xml_path, fp32_xml_path - - except Exception as e: - self.logger.error(f"Failed to export model: {e}") - raise - - def _postprocess_openvino_model( - self, - model: Any, - input_names: list[str] | None = None, - output_names: list[str] | None = None, - metadata: dict[tuple[str, str], str] | None = None, - ) -> Any: - """ - Post-process OpenVINO model (set names, add metadata). - - Args: - model: OpenVINO model - input_names: Names for input tensors - output_names: Names for output tensors - metadata: Metadata to embed - - Returns: - Post-processed model - """ - # Set input names - if input_names: - for i, name in enumerate(input_names): - if i < len(model.inputs): - model.input(i).set_names({name}) - self.logger.debug(f"Set input {i} name to: {name}") - - # Set output names - if output_names: - for i, name in enumerate(output_names): - if i < len(model.outputs): - model.output(i).set_names({name}) - self.logger.debug(f"Set output {i} name to: {name}") - - # Add metadata - if metadata: - for key, value in metadata.items(): - model.set_rt_info(value, list(key)) - self.logger.debug(f"Set metadata {key}: {value}") - - return model - - def process_model_config(self, config: dict[str, Any]) -> bool: - """ - Process a single model configuration. 
- - Args: - config: Model configuration dictionary - - Returns: - True if successful, False otherwise - """ - model_short_name = config.get("model_short_name", "unknown") - model_license = config.get("license") - model_license_link = config.get("license_link") - - # Check if both FP16 and INT8 models already exist - fp16_model_path = self.output_dir / f"{model_short_name}-fp16-ov" / f"{model_short_name}.xml" - int8_model_path = self.output_dir / f"{model_short_name}-int8-ov" / f"{model_short_name}.xml" - - if fp16_model_path.exists() and int8_model_path.exists(): - self.logger.info(f"Skipping {model_short_name}: FP16 and INT8 models already exist") - return True - - try: - if not model_license: - error_msg = f"Model '{model_short_name}' must define 'license' in configuration" - raise ValueError(error_msg) - if not model_license_link: - error_msg = f"Model '{model_short_name}' must define 'license_link' in configuration" - raise ValueError(error_msg) - - self.logger.info("=" * 80) - self.logger.info(f"Processing model: {config.get('model_full_name', model_short_name)}") - self.logger.info(f"Short name: {model_short_name}") - if "description" in config: - self.logger.info(f"Description: {config['description']}") - self.logger.info("=" * 80) - - # Check if this is a Hugging Face model - huggingface_repo = config.get("huggingface_repo") - if huggingface_repo: - huggingface_revision = config.get("huggingface_revision") - if not huggingface_revision: - error_msg = "Hugging Face models must define 'huggingface_revision' with an immutable commit SHA" - raise ValueError(error_msg) - - # Load model from Hugging Face - model_library = config.get("model_library", "timm") - model_params = config.get("model_params") - model = self.load_huggingface_model( - repo_id=huggingface_repo, - revision=huggingface_revision, - model_library=model_library, - model_params=model_params, - ) - else: - # Traditional PyTorch model workflow - # Download weights - weights_url = 
config["weights_url"] - weights_path = self.download_weights(weights_url) - - # Load model class - model_class_name = config.get("model_class_name", "torch.nn.Module") - model_class = self.load_model_class(model_class_name) - - # Load checkpoint - checkpoint = self.load_checkpoint(weights_path) - - # Create model - model_params = config.get("model_params") - model = self.create_model(model_class, checkpoint, model_params) - - # Prepare export parameters - input_shape = config.get("input_shape", [1, 3, 224, 224]) - input_names = config.get("input_names", ["input"]) - output_names = config.get("output_names", ["result"]) - - # Prepare metadata from config (with defaults for normalization) - reverse_input_channels = config.get("reverse_input_channels", True) - mean_values = config.get("mean_values", "123.675 116.28 103.53") - scale_values = config.get("scale_values", "58.395 57.12 57.375") - - metadata = { - ("model_info", "model_type"): config.get("model_type", ""), - ("model_info", "model_short_name"): model_short_name, - ("model_info", "reverse_input_channels"): str(reverse_input_channels), - ("model_info", "mean_values"): mean_values, - ("model_info", "scale_values"): scale_values, - } - - # Add labels if specified in config - labels_config = config.get("labels") - if labels_config: - labels = self.get_labels(labels_config) - if labels: - metadata["model_info", "labels"] = labels - self.logger.info(f"Added {labels_config} labels to metadata") - else: - self.logger.warning(f"Could not load labels for: {labels_config}") - - # Get model library (default to 'timm' for backward compatibility) - model_library = config.get("model_library", "timm") - - output_path = self.output_dir / model_short_name - fp16_model_path, fp32_model_path = self.export_to_openvino( - model=model, - input_shape=input_shape, - output_path=output_path, - model_config=config, - input_names=input_names, - output_names=output_names, - metadata=metadata, - ) - - # Quantize the model if dataset is 
available - if self.dataset_path: - self.logger.info("Creating calibration dataset for INT8 quantization") - has_labels = bool(config.get("labels")) - - self.logger.info("Creating validation dataset for accuracy measurement") - validation_data, validation_labels = self.create_calibration_dataset( - input_shape=input_shape, - mean_values=mean_values, - scale_values=scale_values, - reverse_input_channels=reverse_input_channels, - subset_size=300, - return_labels=has_labels, - ) - - if validation_data: - # Use FP32 model for better quantization accuracy - self.quantize_model( - model_path=fp32_model_path, - calibration_data=validation_data, - model_config=config, - preset="mixed", - validation_data=validation_data if validation_labels else None, - validation_labels=validation_labels or None, - ) - - # Clean up temporary FP32 model after quantization - try: - if fp32_model_path.exists(): - fp32_model_path.unlink() - self.logger.debug(f"Removed temporary FP32 model: {fp32_model_path}") - # Also remove the .bin file - fp32_bin_path = fp32_model_path.with_suffix(".bin") - if fp32_bin_path.exists(): - fp32_bin_path.unlink() - self.logger.debug(f"Removed temporary FP32 weights: {fp32_bin_path}") - except OSError as e: - self.logger.warning(f"Failed to remove temporary FP32 files: {e}") - - self.logger.info(f"✓ Successfully converted {model_short_name}") - return True - - except (ValueError, RuntimeError, ImportError, FileNotFoundError) as e: - self.logger.error(f"✗ Failed to process model {model_short_name}: {e}") - import traceback - - self.logger.debug(traceback.format_exc()) - return False - - def process_config_file( - self, - config_path: Path, - model_filter: str | None = None, - ) -> tuple[int, int]: - """ - Process models from a configuration file. 
- - Args: - config_path: Path to JSON configuration file - model_filter: Optional model short name to process (process only this model) - - Returns: - Tuple of (successful_count, failed_count) - """ - try: - with Path(config_path).open() as f: - config = json.load(f) - except Exception as e: - self.logger.error(f"Failed to load configuration file: {e}") - raise - - models = config.get("models", []) - - if not models: - self.logger.warning("No models found in configuration file") - return 0, 0 - - self.logger.info(f"Configuration validated: {len(models)} models found") - - # Filter models if requested - if model_filter: - models = [m for m in models if m.get("model_short_name") == model_filter] - if not models: - self.logger.error(f"Model '{model_filter}' not found in configuration") - return 0, 0 - self.logger.info(f"Processing only model: {model_filter}") - - successful = 0 - failed = 0 - - for model_config in models: - if self.process_model_config(model_config): - successful += 1 - else: - failed += 1 - - return successful, failed - - -def list_models(config_path: Path): - """List all models in a configuration file.""" - try: - with config_path.open() as f: - config = json.load(f) - except (FileNotFoundError, json.JSONDecodeError, PermissionError) as e: - print(f"Error loading configuration: {e}", file=sys.stderr) - return - - models = config.get("models", []) - - if not models: - print("No models found in configuration") - return - - print(f"\nFound {len(models)} models:\n") - print(f"{'Short Name':<30} {'Full Name':<40} {'Type':<20}") - print("-" * 90) - - for model in models: - short_name = model.get("model_short_name", "N/A") - full_name = model.get("model_full_name", "N/A") - model_type = model.get("model_type", "N/A") - print(f"{short_name:<30} {full_name:<40} {model_type:<20}") - - print() - - -def main(): - """Main entry point.""" - parser = argparse.ArgumentParser( - description="Convert PyTorch models to OpenVINO format", - 
formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Convert all models in config - uv run model-converter examples/config.json -o ./models - - # Convert a specific model - uv run model-converter examples/config.json -o ./models --model resnet50 - - # List all models in config - uv run model-converter examples/config.json --list - - # Enable verbose logging - uv run model-converter examples/config.json -o ./models -v - """, - ) - - parser.add_argument( - "config", - type=Path, - help="Path to JSON configuration file", - ) - - parser.add_argument( - "-o", - "--output", - type=Path, - default=Path("./converted_models"), - help="Output directory for converted models (default: ./converted_models)", - ) - - parser.add_argument( - "-c", - "--cache", - type=Path, - default=Path.home() / ".cache" / "torch" / "hub" / "checkpoints", - help="Cache directory for downloaded weights (default: ~/.cache/torch/hub/checkpoints)", - ) - - parser.add_argument( - "-d", - "--dataset", - type=Path, - default=Path.home() / "model_api" / "validation_dataset", - help=("Path to calibration dataset for INT8 quantization (default: ~/model_api/validation_dataset)"), - ) - - parser.add_argument( - "--model", - type=str, - help="Process only the specified model (by model_short_name)", - ) - - parser.add_argument( - "--list", - action="store_true", - help="List all models in the configuration file and exit", - ) - - parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Enable verbose logging", - ) - - args = parser.parse_args() - - # Check if config file exists - if not args.config.exists(): - print(f"Error: Configuration file not found: {args.config}", file=sys.stderr) - return 1 - - # List models and exit - if args.list: - list_models(args.config) - return 0 - - # Setup logging - log_level = logging.DEBUG if args.verbose else logging.INFO - logging.basicConfig( - level=log_level, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - 
) - logger = logging.getLogger(__name__) - - logger.info(f"Loading configuration from: {args.config}") - - try: - # Create converter - converter = ModelConverter( - output_dir=args.output, - cache_dir=args.cache, - verbose=args.verbose, - dataset_path=args.dataset, - ) - - logger.info(f"Output directory: {args.output}") - logger.info(f"Cache directory: {args.cache}") - if args.dataset: - logger.info(f"Calibration dataset: {args.dataset}") - - # Process models - successful, failed = converter.process_config_file( - config_path=args.config, - model_filter=args.model, - ) - - # Print summary - logger.info("=" * 80) - logger.info("Conversion Summary:") - logger.info(f" Successful: {successful}") - logger.info(f" Failed: {failed}") - logger.info(f" Total: {successful + failed}") - logger.info("=" * 80) - - return 0 if failed == 0 else 1 - except (ValueError, RuntimeError, ImportError, FileNotFoundError) as e: - logger.error(f"Failed to process model: {e}") - return 1 +from model_converter.cli import main if __name__ == "__main__": sys.exit(main()) diff --git a/model_converter/tests/__init__.py b/model_converter/tests/__init__.py new file mode 100644 index 00000000..d2223984 --- /dev/null +++ b/model_converter/tests/__init__.py @@ -0,0 +1,4 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/model_converter/tests/unit/__init__.py b/model_converter/tests/unit/__init__.py new file mode 100644 index 00000000..d2223984 --- /dev/null +++ b/model_converter/tests/unit/__init__.py @@ -0,0 +1,4 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/model_converter/tests/unit/conftest.py b/model_converter/tests/unit/conftest.py new file mode 100644 index 00000000..47f75f5b --- /dev/null +++ b/model_converter/tests/unit/conftest.py @@ -0,0 +1,128 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Shared fixtures for model_converter unit tests.""" 
+ +from unittest.mock import MagicMock + +import pytest + + +@pytest.fixture +def tmp_output_dir(tmp_path): + """Temporary output directory for converted models.""" + return tmp_path / "output" + + +@pytest.fixture +def tmp_cache_dir(tmp_path): + """Temporary cache directory for downloaded weights.""" + return tmp_path / "cache" + + +@pytest.fixture +def sample_model_config(): + """Sample model configuration dictionary.""" + return { + "model_short_name": "test_model", + "model_full_name": "Test Model", + "model_class_name": "torchvision.models.resnet.resnet18", + "weights_url": "https://example.com/weights.pth", + "input_shape": [1, 3, 224, 224], + "input_names": ["input"], + "output_names": ["result"], + "model_type": "Classification", + "license": "Apache-2.0", + "license_link": "https://www.apache.org/licenses/LICENSE-2.0", + "docs": "https://docs.example.com", + "labels": "IMAGENET1K_V1", + "mean_values": "123.675 116.28 103.53", + "scale_values": "58.395 57.12 57.375", + "reverse_input_channels": True, + "description": "A test model", + } + + +@pytest.fixture +def converter(tmp_output_dir, tmp_cache_dir): + """Pre-built ModelConverter instance with temporary directories.""" + from model_converter.cli import ModelConverter + + return ModelConverter( + output_dir=tmp_output_dir, + cache_dir=tmp_cache_dir, + verbose=True, + ) + + +@pytest.fixture +def mock_ov_model(): + """Mock OpenVINO model object.""" + model = MagicMock() + model.inputs = [MagicMock()] + model.outputs = [MagicMock()] + model.input.return_value = MagicMock() + model.output.return_value = MagicMock() + + # Mock input(0).get_names() + input_mock = MagicMock() + input_mock.get_names.return_value = {"input"} + model.input.return_value = input_mock + + return model + + +@pytest.fixture +def config_file(tmp_path, sample_model_config): + """Create a temporary config JSON file.""" + import json + + config = {"models": [sample_model_config]} + config_path = tmp_path / "config.json" + 
config_path.write_text(json.dumps(config)) + return config_path + + +@pytest.fixture +def mock_torch_model(): + """Mock PyTorch model (nn.Module).""" + model = MagicMock() + model.eval = MagicMock(return_value=model) + model.load_state_dict = MagicMock() + return model + + +@pytest.fixture +def dataset_dir(tmp_path): + """Create a temporary calibration dataset directory structure.""" + import numpy as np + + dataset_path = tmp_path / "dataset" + class_dir = dataset_path / "0" + class_dir.mkdir(parents=True) + + # Create a dummy image file + import cv2 + + img = np.zeros((224, 224, 3), dtype=np.uint8) + cv2.imwrite(str(class_dir / "image_001.jpg"), img) + + class_dir2 = dataset_path / "1" + class_dir2.mkdir(parents=True) + cv2.imwrite(str(class_dir2 / "image_002.jpg"), img) + + return dataset_path + + +@pytest.fixture +def template_dir(tmp_path): + """Create a temporary template directory with sample templates.""" + templates = tmp_path / "templates" + templates.mkdir() + (templates / "README-timm-fp16.md").write_text("# <> (<>)\nLicense: <>") + (templates / "README-timm-int8.md").write_text("# <> INT8\nLicense: <>") + (templates / "README-torchvision-fp16.md").write_text("# <> (<>)\nLicense: <>") + (templates / ".gitattributes").write_text("*.bin filter=lfs diff=lfs merge=lfs -text\n") + return templates diff --git a/model_converter/tests/unit/test_adapters.py b/model_converter/tests/unit/test_adapters.py new file mode 100644 index 00000000..35a45970 --- /dev/null +++ b/model_converter/tests/unit/test_adapters.py @@ -0,0 +1,141 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Tests for model_converter.adapters module.""" + +from unittest.mock import MagicMock + +import torch +from model_converter.adapters import TorchvisionMaskRCNNExportAdapter, get_adapter +from model_converter.adapters.base import ExportAdapter as BaseExportAdapter + + +class TestGetAdapter: + """Tests for get_adapter registry function.""" + + def 
test_known_type_maskrcnn(self): + """get_adapter returns MaskRCNN adapter for 'maskrcnn' type.""" + mock_model = MagicMock() + result = get_adapter("maskrcnn", mock_model) + assert isinstance(result, TorchvisionMaskRCNNExportAdapter) + + def test_known_type_case_insensitive(self): + """get_adapter handles case-insensitive lookup.""" + mock_model = MagicMock() + result = get_adapter("MaskRCNN", mock_model) + assert isinstance(result, TorchvisionMaskRCNNExportAdapter) + + def test_unknown_type_returns_model(self): + """get_adapter returns model unchanged for unknown types.""" + mock_model = MagicMock() + result = get_adapter("unknown_model_type", mock_model) + assert result is mock_model + + def test_empty_type_returns_model(self): + """get_adapter returns model unchanged for empty string.""" + mock_model = MagicMock() + result = get_adapter("", mock_model) + assert result is mock_model + + +class TestExportAdapter: + """Tests for ExportAdapter base class.""" + + def test_init_stores_model(self): + """ExportAdapter stores the model as an attribute.""" + mock_model = MagicMock(spec=torch.nn.Module) + adapter = BaseExportAdapter(mock_model) + assert adapter.model is mock_model + + def test_is_nn_module(self): + """ExportAdapter is a subclass of nn.Module.""" + assert issubclass(BaseExportAdapter, torch.nn.Module) + + +class TestTorchvisionMaskRCNNExportAdapter: + """Tests for TorchvisionMaskRCNNExportAdapter.""" + + def test_forward(self): + """Test forward pass transforms MaskRCNN output correctly.""" + # Create a mock MaskRCNN model + mock_model = MagicMock() + + # Mock transform + mock_image_list = MagicMock() + mock_image_list.tensors = torch.randn(1, 3, 224, 224) + mock_image_list.image_sizes = [(224, 224)] + mock_model.transform.return_value = (mock_image_list, None) + + # Mock backbone + mock_features = {"0": torch.randn(1, 256, 56, 56)} + mock_model.backbone.return_value = mock_features + + # Mock RPN + mock_proposals = [torch.randn(100, 4)] + 
mock_model.rpn.return_value = (mock_proposals, None) + + # Mock ROI heads + mock_predictions = [ + { + "boxes": torch.randn(10, 4), + "scores": torch.rand(10), + "labels": torch.randint(1, 80, (10,)), + "masks": torch.rand(10, 1, 28, 28), + }, + ] + mock_model.roi_heads.return_value = (mock_predictions, None) + + # Create adapter and run forward + adapter = TorchvisionMaskRCNNExportAdapter(mock_model) + images = torch.randn(1, 3, 224, 224) + boxes, labels, masks = adapter.forward(images) + + # Verify outputs + assert boxes.shape == (10, 5) # boxes (4) + scores (1) + assert labels.shape == (10,) + assert masks.shape == (10, 28, 28) # squeezed from (10, 1, 28, 28) + + # Labels should be shifted by -1 + expected_labels = mock_predictions[0]["labels"] - 1 + assert torch.equal(labels, expected_labels) + + def test_forward_with_tensor_features(self): + """Test forward when backbone returns a tensor instead of OrderedDict.""" + mock_model = MagicMock() + + mock_image_list = MagicMock() + mock_image_list.tensors = torch.randn(1, 3, 224, 224) + mock_image_list.image_sizes = [(224, 224)] + mock_model.transform.return_value = (mock_image_list, None) + + # Return a tensor instead of dict + mock_model.backbone.return_value = torch.randn(1, 256, 56, 56) + + mock_proposals = [torch.randn(50, 4)] + mock_model.rpn.return_value = (mock_proposals, None) + + mock_predictions = [ + { + "boxes": torch.randn(5, 4), + "scores": torch.rand(5), + "labels": torch.randint(1, 80, (5,)), + "masks": torch.rand(5, 1, 28, 28), + }, + ] + mock_model.roi_heads.return_value = (mock_predictions, None) + + adapter = TorchvisionMaskRCNNExportAdapter(mock_model) + images = torch.randn(1, 3, 224, 224) + boxes, labels, masks = adapter.forward(images) + + # When backbone returns tensor, it should be wrapped in OrderedDict + # Verify rpn was called with correct features format + rpn_call_features = mock_model.rpn.call_args[0][1] + assert isinstance(rpn_call_features, dict) + assert "0" in rpn_call_features + 
+ assert boxes.shape == (5, 5) + assert labels.shape == (5,) + assert masks.shape == (5, 28, 28) diff --git a/model_converter/tests/unit/test_cli.py b/model_converter/tests/unit/test_cli.py new file mode 100644 index 00000000..7cc5b8ae --- /dev/null +++ b/model_converter/tests/unit/test_cli.py @@ -0,0 +1,1762 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Tests for model_converter.cli module.""" + +import json +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import cv2 +import numpy as np +import pytest +import torch +import torch.nn as nn +from model_converter.cli import ModelConverter, list_models, main + + +class TestModelConverterInit: + """Tests for ModelConverter.__init__.""" + + def test_creates_directories(self, tmp_path): + """ModelConverter creates output and cache directories.""" + output_dir = tmp_path / "output" + cache_dir = tmp_path / "cache" + + ModelConverter(output_dir=output_dir, cache_dir=cache_dir) + + assert output_dir.exists() + assert cache_dir.exists() + + def test_verbose_logging(self, tmp_path): + """ModelConverter sets debug logging in verbose mode.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + verbose=True, + ) + assert converter.logger is not None + + def test_dataset_path(self, tmp_path): + """ModelConverter stores dataset path.""" + dataset = tmp_path / "dataset" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset, + ) + assert converter.dataset_path == dataset + + def test_dataset_path_none(self, tmp_path): + """ModelConverter handles None dataset path.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=None, + ) + assert converter.dataset_path is None + + +class TestGetLabels: + """Tests for ModelConverter.get_labels.""" + + def test_imagenet1k_v1(self, converter): + 
"""get_labels returns ImageNet1K labels.""" + mock_categories = ["tabby cat", "golden retriever", "great white shark"] + with ( + patch("model_converter.cli.importlib") as _, + patch.dict( + "sys.modules", + {"torchvision.models._meta": MagicMock(_IMAGENET_CATEGORIES=mock_categories)}, + ), + patch("model_converter.cli.ModelConverter.get_labels", wraps=converter.get_labels), + ): + # Directly test the code path + pass + + # Test with actual mocking of the import + with patch("torchvision.models._meta._IMAGENET_CATEGORIES", ["tabby cat", "golden retriever"], create=True): + result = converter.get_labels("IMAGENET1K_V1") + assert result is not None + assert " " not in result.split()[0] or "_" in result.split()[0] # spaces replaced with underscores + + def test_imagenet21k(self, converter): + """get_labels returns ImageNet21K labels.""" + mock_info = MagicMock() + mock_info.label_descriptions.return_value = ["tabby, tabby cat", "golden retriever, dog"] + + mock_imagenet_info_cls = MagicMock(return_value=mock_info) + with patch("timm.data.ImageNetInfo", mock_imagenet_info_cls): + result = converter.get_labels("IMAGENET21K") + + assert result == "tabby golden_retriever" + + def test_coco_v1(self, converter): + """get_labels returns COCO labels.""" + mock_weights = MagicMock() + mock_weights.COCO_V1.meta = {"categories": ["person", "bicycle", "car"]} + + with patch( + "torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights", + mock_weights, + ): + result = converter.get_labels("COCO_V1") + + assert result == "person bicycle car" + + def test_unknown_label_set(self, converter): + """get_labels returns None for unknown label sets.""" + result = converter.get_labels("NONEXISTENT_LABELS") + assert result is None + + +class TestLoadModelClass: + """Tests for ModelConverter.load_model_class.""" + + def test_successful_import(self, converter): + """load_model_class dynamically imports a class.""" + result = converter.load_model_class("torch.nn.Linear") + assert result is 
torch.nn.Linear + + def test_import_failure(self, converter): + """load_model_class raises on invalid path.""" + with pytest.raises(ModuleNotFoundError): + converter.load_model_class("nonexistent.module.Class") + + +class TestLoadCheckpoint: + """Tests for ModelConverter.load_checkpoint.""" + + def test_successful_load(self, converter, tmp_path): + """load_checkpoint loads a PyTorch checkpoint.""" + checkpoint_path = tmp_path / "checkpoint.pth" + checkpoint_path.touch() + state_dict = {"layer.weight": torch.randn(10, 10)} + + with patch("torch.load", return_value=state_dict): + result = converter.load_checkpoint(checkpoint_path) + + assert "layer.weight" in result + + def test_load_failure(self, converter, tmp_path): + """load_checkpoint raises on invalid file.""" + bad_path = tmp_path / "nonexistent.pth" + with pytest.raises(FileNotFoundError): + converter.load_checkpoint(bad_path) + + +class TestLoadHuggingfaceModel: + """Tests for ModelConverter.load_huggingface_model.""" + + def test_timm_model(self, converter): + """load_huggingface_model loads timm model.""" + mock_model = MagicMock() + mock_model.eval.return_value = mock_model + + with patch("timm.create_model", return_value=mock_model) as mock_create: + result = converter.load_huggingface_model( + repo_id="timm/resnet50", + revision="abc123", + model_library="timm", + ) + + assert result is mock_model + mock_create.assert_called_once() + mock_model.eval.assert_called_once() + + def test_timm_model_with_params(self, converter): + """load_huggingface_model passes model_params to timm.""" + mock_model = MagicMock() + mock_model.eval.return_value = mock_model + + with patch("timm.create_model", return_value=mock_model) as mock_create: + converter.load_huggingface_model( + repo_id="timm/resnet50", + revision="abc123", + model_library="timm", + model_params={"num_classes": 10}, + ) + + call_kwargs = mock_create.call_args[1] + assert call_kwargs["num_classes"] == 10 + + def test_transformers_model(self, 
converter): + """load_huggingface_model loads transformers model.""" + mock_model = MagicMock() + mock_model.eval.return_value = mock_model + + with patch("transformers.AutoModel.from_pretrained", return_value=mock_model) as mock_from: + result = converter.load_huggingface_model( + repo_id="bert-base-uncased", + revision="abc123", + model_library="transformers", + ) + + assert result is mock_model + mock_from.assert_called_once() + + def test_unsupported_library(self, converter): + """load_huggingface_model raises for unsupported library.""" + with pytest.raises(ValueError, match="Unsupported model library"): + converter.load_huggingface_model( + repo_id="some/model", + revision="abc123", + model_library="unsupported_lib", + ) + + def test_load_failure(self, converter): + """load_huggingface_model raises on failure.""" + with ( + patch("timm.create_model", side_effect=RuntimeError("Connection error")), + pytest.raises(RuntimeError, match="Connection error"), + ): + converter.load_huggingface_model( + repo_id="timm/resnet50", + revision="abc123", + model_library="timm", + ) + + +class TestCreateModel: + """Tests for ModelConverter.create_model.""" + + def test_nn_module_with_model_key(self, converter): + """create_model extracts model from checkpoint 'model' key.""" + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + checkpoint = {"model": mock_model} + + result = converter.create_model(torch.nn.Module, checkpoint) + assert result is mock_model + + def test_nn_module_with_state_dict_only_raises(self, converter): + """create_model raises when nn.Module with only state_dict.""" + checkpoint = {"state_dict": {"layer.weight": torch.randn(10)}} + + with pytest.raises(ValueError, match="state_dict"): + converter.create_model(torch.nn.Module, checkpoint) + + def test_nn_module_direct_model_as_checkpoint(self, converter): + """create_model handles nn.Module passed directly (not in dict) gracefully.""" + # When an nn.Module is passed as 
checkpoint, "model" in checkpoint raises TypeError + # which is caught by the except block and re-raised + model = nn.Linear(10, 10) + with pytest.raises(TypeError): + converter.create_model(torch.nn.Module, model) + + def test_nn_module_invalid_checkpoint(self, converter): + """create_model raises when checkpoint is not a valid model.""" + with pytest.raises(ValueError, match="does not contain a valid model"): + converter.create_model(torch.nn.Module, {"some_key": "some_value"}) + + def test_model_class_with_state_dict(self, converter): + """create_model instantiates class and loads state_dict.""" + mock_class = MagicMock() + mock_instance = MagicMock() + mock_instance.eval.return_value = mock_instance + mock_class.return_value = mock_instance + + state_dict = {"layer.weight": torch.randn(10)} + checkpoint = {"state_dict": state_dict} + + result = converter.create_model(mock_class, checkpoint) + assert result is mock_instance + mock_instance.load_state_dict.assert_called_once() + + def test_model_class_with_model_params(self, converter): + """create_model passes model_params to class constructor.""" + mock_class = MagicMock() + mock_instance = MagicMock() + mock_instance.eval.return_value = mock_instance + mock_class.return_value = mock_instance + + checkpoint = {"state_dict": {"w": torch.randn(10)}} + + converter.create_model(mock_class, checkpoint, model_params={"num_classes": 5}) + mock_class.assert_called_once_with(num_classes=5) + + def test_model_class_with_model_key_as_module(self, converter): + """create_model returns checkpoint['model'] if it's an nn.Module.""" + mock_model = MagicMock(spec=nn.Module) + mock_class = MagicMock() + checkpoint = {"model": mock_model} + + result = converter.create_model(mock_class, checkpoint) + assert result is mock_model + + def test_model_class_with_model_key_as_dict(self, converter): + """create_model uses checkpoint['model'] as state_dict.""" + mock_class = MagicMock() + mock_instance = MagicMock() + 
mock_instance.eval.return_value = mock_instance + mock_class.return_value = mock_instance + + checkpoint = {"model": {"layer.weight": torch.randn(10)}} + + result = converter.create_model(mock_class, checkpoint) + assert result is mock_instance + mock_instance.load_state_dict.assert_called_once() + + def test_model_class_bare_state_dict(self, converter): + """create_model uses checkpoint directly as state_dict.""" + mock_class = MagicMock() + mock_instance = MagicMock() + mock_instance.eval.return_value = mock_instance + mock_class.return_value = mock_instance + + checkpoint = {"layer.weight": torch.randn(10)} + + result = converter.create_model(mock_class, checkpoint) + assert result is mock_instance + mock_instance.load_state_dict.assert_called_once() + + def test_create_model_failure(self, converter): + """create_model raises on instantiation failure.""" + mock_class = MagicMock(side_effect=RuntimeError("init failed")) + + with pytest.raises(RuntimeError, match="init failed"): + converter.create_model(mock_class, {}) + + +class TestCopyReadme: + """Tests for ModelConverter.copy_readme.""" + + def test_successful_copy(self, converter, tmp_path): + """copy_readme copies and fills README template.""" + output_folder = tmp_path / "model-fp16-ov" + output_folder.mkdir() + + # We mock the template file reading + template_content = "# <>\nLicense: <>\nLink: <>\nDocs: <>" + + config = { + "model_short_name": "test_model", + "license": "Apache-2.0", + "license_link": "https://apache.org/licenses/LICENSE-2.0", + "docs": "https://docs.example.com", + "model_library": "timm", + } + + with ( + patch.object(Path, "exists", return_value=True), + patch.object(Path, "read_text", return_value=template_content), + ): + converter.copy_readme(config, output_folder, variant="fp16") + + readme = output_folder / "README.md" + assert readme.exists() + content = readme.read_text() + assert "test_model" in content + assert "Apache-2.0" in content + + def test_template_not_found(self, 
converter, tmp_path): + """copy_readme handles missing template gracefully.""" + output_folder = tmp_path / "model-fp16-ov" + output_folder.mkdir() + + config = { + "model_short_name": "test_model", + "license": "MIT", + "license_link": "https://mit.edu", + "docs": "", + "model_library": "timm", + } + + # Template path doesn't exist + with patch.object(Path, "exists", return_value=False): + converter.copy_readme(config, output_folder, variant="fp16") + + # No README should be created + assert not (output_folder / "README.md").exists() + + def test_missing_model_short_name(self, converter, tmp_path): + """copy_readme warns when model_short_name is empty.""" + output_folder = tmp_path / "model-fp16-ov" + output_folder.mkdir() + + config = { + "model_short_name": "", + "license": "MIT", + "license_link": "https://mit.edu", + } + + # Should not raise but log warning + converter.copy_readme(config, output_folder) + + def test_missing_license_link(self, converter, tmp_path): + """copy_readme warns when license_link is empty.""" + output_folder = tmp_path / "model-fp16-ov" + output_folder.mkdir() + + config = { + "model_short_name": "test", + "license": "MIT", + "license_link": "", + } + + converter.copy_readme(config, output_folder) + + def test_missing_license(self, converter, tmp_path): + """copy_readme warns when license is empty.""" + output_folder = tmp_path / "model-fp16-ov" + output_folder.mkdir() + + config = { + "model_short_name": "test", + "license": "", + "license_link": "https://mit.edu", + } + + converter.copy_readme(config, output_folder) + + def test_missing_docs_field(self, converter, tmp_path): + """copy_readme handles missing docs field.""" + output_folder = tmp_path / "model-fp16-ov" + output_folder.mkdir() + + template_content = "# <>" + config = { + "model_short_name": "test_model", + "license": "Apache-2.0", + "license_link": "https://apache.org", + "model_library": "timm", + } + + with ( + patch.object(Path, "exists", return_value=True), + 
patch.object(Path, "read_text", return_value=template_content), + ): + converter.copy_readme(config, output_folder, variant="fp16") + + def test_none_value_in_config(self, converter, tmp_path): + """copy_readme skips None values in config placeholders.""" + output_folder = tmp_path / "model-fp16-ov" + output_folder.mkdir() + + template_content = "# <>" + config = { + "model_short_name": "test_model", + "license": "Apache-2.0", + "license_link": "https://apache.org", + "model_library": "timm", + "optional_field": None, + } + + with ( + patch.object(Path, "exists", return_value=True), + patch.object(Path, "read_text", return_value=template_content), + ): + converter.copy_readme(config, output_folder, variant="fp16") + + +class TestCollectDatasetEntries: + """Tests for ModelConverter._collect_dataset_entries.""" + + def test_collects_entries(self, converter, dataset_dir): + """_collect_dataset_entries finds images with class labels.""" + entries = converter._collect_dataset_entries(dataset_dir) + assert len(entries) == 2 + # Entries are (path, class_label) tuples + assert entries[0][1] == 0 + assert entries[1][1] == 1 + + def test_empty_directory(self, converter, tmp_path): + """_collect_dataset_entries returns empty list for empty dir.""" + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + entries = converter._collect_dataset_entries(empty_dir) + assert entries == [] + + +class TestPreprocessCalibrationImage: + """Tests for ModelConverter._preprocess_calibration_image.""" + + def test_valid_image(self, converter, tmp_path): + """_preprocess_calibration_image processes image correctly.""" + # Create a dummy image + img = np.zeros((100, 100, 3), dtype=np.uint8) + img_path = tmp_path / "test.jpg" + cv2.imwrite(str(img_path), img) + + result = converter._preprocess_calibration_image( + img_path=img_path, + width=224, + height=224, + mean=np.array([123.675, 116.28, 103.53]), + scale=np.array([58.395, 57.12, 57.375]), + reverse_input_channels=True, + ) + + assert result 
is not None + assert result.shape == (1, 3, 224, 224) + + def test_no_channel_reversal(self, converter, tmp_path): + """_preprocess_calibration_image without channel reversal.""" + img = np.zeros((100, 100, 3), dtype=np.uint8) + img_path = tmp_path / "test.jpg" + cv2.imwrite(str(img_path), img) + + result = converter._preprocess_calibration_image( + img_path=img_path, + width=224, + height=224, + mean=np.array([0, 0, 0]), + scale=np.array([1, 1, 1]), + reverse_input_channels=False, + ) + + assert result is not None + assert result.shape == (1, 3, 224, 224) + + def test_invalid_image(self, converter, tmp_path): + """_preprocess_calibration_image returns None for invalid image.""" + bad_path = tmp_path / "notanimage.txt" + bad_path.write_text("not an image") + + result = converter._preprocess_calibration_image( + img_path=bad_path, + width=224, + height=224, + mean=np.array([0, 0, 0]), + scale=np.array([1, 1, 1]), + reverse_input_channels=True, + ) + + assert result is None + + +class TestCreateCalibrationDataset: + """Tests for ModelConverter.create_calibration_dataset.""" + + def test_no_dataset_path(self, tmp_path): + """create_calibration_dataset returns empty when no dataset path.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=None, + ) + result = converter.create_calibration_dataset(input_shape=[1, 3, 224, 224]) + assert result == [] + + def test_nonexistent_dataset_path(self, tmp_path): + """create_calibration_dataset returns empty for missing path.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=tmp_path / "nonexistent", + ) + result = converter.create_calibration_dataset(input_shape=[1, 3, 224, 224]) + assert result == [] + + def test_with_return_labels(self, tmp_path, dataset_dir): + """create_calibration_dataset returns images and labels.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / 
"cache", + dataset_path=dataset_dir, + ) + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=True, + ) + images, labels = result + assert len(images) == 2 + assert len(labels) == 2 + assert labels[0] == 0 + assert labels[1] == 1 + + def test_without_return_labels(self, tmp_path, dataset_dir): + """create_calibration_dataset returns images without labels flag.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_dir, + ) + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=False, + ) + images, labels = result + assert len(images) == 2 + assert labels == [] + + def test_with_mean_scale(self, tmp_path, dataset_dir): + """create_calibration_dataset uses mean and scale values.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_dir, + ) + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + mean_values="123.675 116.28 103.53", + scale_values="58.395 57.12 57.375", + return_labels=True, + ) + images, _labels = result + assert len(images) == 2 + + def test_empty_dataset(self, tmp_path): + """create_calibration_dataset handles empty dataset directory.""" + empty_dataset = tmp_path / "empty_dataset" + empty_dataset.mkdir() + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=empty_dataset, + ) + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=True, + ) + assert result == ([], []) + + def test_subset_size(self, tmp_path, dataset_dir): + """create_calibration_dataset respects subset_size.""" + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_dir, + ) + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + subset_size=1, + 
return_labels=True, + ) + images, _labels = result + assert len(images) == 1 + + def test_image_processing_error(self, tmp_path): + """create_calibration_dataset skips images that raise exceptions (with labels).""" + dataset_path = tmp_path / "dataset" + class_dir = dataset_path / "0" + class_dir.mkdir(parents=True) + # Create a valid image file that will trigger an exception in preprocessing + img = np.zeros((10, 10, 3), dtype=np.uint8) + cv2.imwrite(str(class_dir / "image_001.jpg"), img) + + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_path, + ) + # Mock _preprocess_calibration_image to raise an exception + with patch.object(converter, "_preprocess_calibration_image", side_effect=ValueError("bad image")): + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=True, + ) + images, _labels = result + assert len(images) == 0 + + def test_image_processing_error_no_labels(self, tmp_path): + """create_calibration_dataset skips images that raise exceptions (without labels).""" + dataset_path = tmp_path / "dataset" + class_dir = dataset_path / "0" + class_dir.mkdir(parents=True) + img = np.zeros((10, 10, 3), dtype=np.uint8) + cv2.imwrite(str(class_dir / "image_001.jpg"), img) + + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_path, + ) + with patch.object(converter, "_preprocess_calibration_image", side_effect=OSError("read error")): + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=False, + ) + images, _labels = result + assert len(images) == 0 + + def test_image_returns_none_with_labels(self, tmp_path): + """create_calibration_dataset skips None images (with labels).""" + dataset_path = tmp_path / "dataset" + class_dir = dataset_path / "0" + class_dir.mkdir(parents=True) + img = np.zeros((10, 10, 3), dtype=np.uint8) + cv2.imwrite(str(class_dir / 
"image_001.jpg"), img) + + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_path, + ) + with patch.object(converter, "_preprocess_calibration_image", return_value=None): + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=True, + ) + images, _labels = result + assert len(images) == 0 + + def test_image_returns_none_without_labels(self, tmp_path): + """create_calibration_dataset skips None images (without labels).""" + dataset_path = tmp_path / "dataset" + class_dir = dataset_path / "0" + class_dir.mkdir(parents=True) + img = np.zeros((10, 10, 3), dtype=np.uint8) + cv2.imwrite(str(class_dir / "image_001.jpg"), img) + + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_path, + ) + with patch.object(converter, "_preprocess_calibration_image", return_value=None): + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=False, + ) + images, _labels = result + assert len(images) == 0 + + def test_progress_logging_with_labels(self, tmp_path): + """create_calibration_dataset logs progress every 50 images (with labels).""" + dataset_path = tmp_path / "dataset" + class_dir = dataset_path / "0" + class_dir.mkdir(parents=True) + + # Create 51 images to trigger the progress logging at i=49 (i+1=50) + img = np.zeros((10, 10, 3), dtype=np.uint8) + for i in range(51): + cv2.imwrite(str(class_dir / f"image_{i:03d}.jpg"), img) + + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_path, + verbose=True, + ) + result = converter.create_calibration_dataset( + input_shape=[1, 3, 10, 10], + return_labels=True, + ) + images, _labels = result + assert len(images) == 51 + + def test_progress_logging_without_labels(self, tmp_path): + """create_calibration_dataset logs progress every 50 images (without labels).""" + 
dataset_path = tmp_path / "dataset" + class_dir = dataset_path / "0" + class_dir.mkdir(parents=True) + + img = np.zeros((10, 10, 3), dtype=np.uint8) + for i in range(51): + cv2.imwrite(str(class_dir / f"image_{i:03d}.jpg"), img) + + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_path, + verbose=True, + ) + result = converter.create_calibration_dataset( + input_shape=[1, 3, 10, 10], + return_labels=False, + ) + images, _labels = result + assert len(images) == 51 + + def test_dataset_dir_removed_after_init_check(self, tmp_path): + """create_calibration_dataset handles dir removed between checks.""" + + dataset_path = tmp_path / "dataset" + dataset_path.mkdir() + + converter = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_path, + ) + + # Remove the directory after converter init but before calibration runs + # Patch exists() to return True on first call (line 415) then False on second (line 428) + original_exists = Path.exists + call_count = [0] + + def mock_exists(self_path): + if self_path == dataset_path: + call_count[0] += 1 + return call_count[0] == 1 + return original_exists(self_path) + + with patch.object(Path, "exists", mock_exists): + result = converter.create_calibration_dataset( + input_shape=[1, 3, 224, 224], + return_labels=True, + ) + assert result == ([], []) + + +class TestValidateModel: + """Tests for ModelConverter.validate_model.""" + + def test_correct_predictions(self, converter, tmp_path): + """validate_model computes accuracy correctly.""" + mock_output_layer = MagicMock() + mock_compiled = MagicMock() + mock_compiled.outputs = [mock_output_layer] + + # Setup callable to return predictions matching labels + mock_compiled.return_value = {mock_output_layer: np.array([[0.1, 0.9, 0.0]])} + call_results = [ + {mock_output_layer: np.array([[0.1, 0.9, 0.0]])}, # pred = 1 + {mock_output_layer: np.array([[0.0, 0.0, 0.9]])}, # pred = 
2 + ] + mock_compiled.side_effect = call_results + + mock_core = MagicMock() + mock_core.read_model.return_value = MagicMock() + mock_core.compile_model.return_value = mock_compiled + + with patch("openvino.Core", return_value=mock_core): + accuracy = converter.validate_model( + model_path=tmp_path / "model.xml", + validation_data=[np.zeros((1, 3, 224, 224)), np.zeros((1, 3, 224, 224))], + labels=[1, 2], + ) + + assert accuracy == pytest.approx(1.0) + + def test_partial_accuracy(self, converter, tmp_path): + """validate_model returns partial accuracy.""" + mock_output_layer = MagicMock() + mock_compiled = MagicMock() + mock_compiled.outputs = [mock_output_layer] + + call_results = [ + {mock_output_layer: np.array([[0.9, 0.1]])}, # pred = 0 (correct) + {mock_output_layer: np.array([[0.9, 0.1]])}, # pred = 0 (wrong, label=1) + ] + mock_compiled.side_effect = call_results + + mock_core = MagicMock() + mock_core.read_model.return_value = MagicMock() + mock_core.compile_model.return_value = mock_compiled + + with patch("openvino.Core", return_value=mock_core): + accuracy = converter.validate_model( + model_path=tmp_path / "model.xml", + validation_data=[np.zeros((1, 3, 224, 224)), np.zeros((1, 3, 224, 224))], + labels=[0, 1], + ) + + assert accuracy == pytest.approx(0.5) + + def test_validation_failure(self, converter, tmp_path): + """validate_model returns 0.0 on error.""" + with patch("openvino.Core", side_effect=RuntimeError("OV error")): + accuracy = converter.validate_model( + model_path=tmp_path / "model.xml", + validation_data=[np.zeros((1, 3, 224, 224))], + labels=[0], + ) + + assert accuracy == pytest.approx(0.0) + + +class TestQuantizeModel: + """Tests for ModelConverter.quantize_model.""" + + def test_no_calibration_data(self, converter, sample_model_config, tmp_path): + """quantize_model returns model_path when no calibration data.""" + model_path = tmp_path / "model.xml" + result = converter.quantize_model( + model_path=model_path, + calibration_data=[], + 
model_config=sample_model_config, + ) + assert result == model_path + + def test_successful_quantization(self, converter, sample_model_config, tmp_path): + """quantize_model performs INT8 quantization.""" + # Setup model path + fp16_dir = tmp_path / "test_model-fp16-ov" + fp16_dir.mkdir(parents=True) + model_path = fp16_dir / "test_model_fp32.xml" + model_path.write_text("") + + mock_ov_model = MagicMock() + mock_quantized = MagicMock() + mock_quantized.get_rt_info.return_value = MagicMock(value={"model_type": "Classification"}) + + mock_core = MagicMock() + mock_core.read_model.return_value = mock_ov_model + + calibration_data = [np.zeros((1, 3, 224, 224))] + + def consume_dataset(gen): + """Mock nncf.Dataset that consumes the generator.""" + list(gen) # Consume the generator to cover lines 572-573 + return MagicMock() + + with ( + patch("openvino.Core", return_value=mock_core), + patch("nncf.quantize", return_value=mock_quantized), + patch("nncf.Dataset", side_effect=consume_dataset), + patch("nncf.QuantizationPreset") as mock_preset, + patch("openvino.save_model"), + patch.object(Path, "exists", return_value=True), + patch("shutil.copy2"), + patch.object(converter, "copy_readme"), + ): + mock_preset.MIXED = "mixed" + mock_preset.PERFORMANCE = "performance" + result = converter.quantize_model( + model_path=model_path, + calibration_data=calibration_data, + model_config=sample_model_config, + preset="mixed", + ) + + assert result != model_path # Should return new quantized path + + def test_nncf_not_installed(self, converter, sample_model_config, tmp_path): + """quantize_model handles missing NNCF.""" + model_path = tmp_path / "model.xml" + calibration_data = [np.zeros((1, 3, 224, 224))] + + with ( + patch.dict("sys.modules", {"nncf": None}), + patch("builtins.__import__", side_effect=ImportError("No module named 'nncf'")), + ): + result = converter.quantize_model( + model_path=model_path, + calibration_data=calibration_data, + model_config=sample_model_config, + 
) + + assert result == model_path + + def test_quantization_with_validation(self, converter, sample_model_config, tmp_path): + """quantize_model validates accuracy when validation data provided.""" + fp16_dir = tmp_path / "test_model-fp16-ov" + fp16_dir.mkdir(parents=True) + model_path = fp16_dir / "test_model_fp32.xml" + model_path.write_text("") + + mock_quantized = MagicMock() + mock_quantized.get_rt_info.return_value = MagicMock(value={"model_type": "Classification"}) + + mock_core = MagicMock() + mock_core.read_model.return_value = MagicMock() + + calibration_data = [np.zeros((1, 3, 224, 224))] + validation_data = [np.zeros((1, 3, 224, 224))] + validation_labels = [0] + + def consume_dataset(gen): + list(gen) + return MagicMock() + + with ( + patch("openvino.Core", return_value=mock_core), + patch("nncf.quantize", return_value=mock_quantized), + patch("nncf.Dataset", side_effect=consume_dataset), + patch("nncf.QuantizationPreset") as mock_preset, + patch("openvino.save_model"), + patch.object(Path, "exists", return_value=True), + patch("shutil.copy2"), + patch.object(converter, "copy_readme"), + patch.object(converter, "validate_model", return_value=0.95), + ): + mock_preset.MIXED = "mixed" + converter.quantize_model( + model_path=model_path, + calibration_data=calibration_data, + model_config=sample_model_config, + validation_data=validation_data, + validation_labels=validation_labels, + ) + + def test_quantization_runtime_error(self, converter, sample_model_config, tmp_path): + """quantize_model handles runtime errors gracefully.""" + model_path = tmp_path / "model.xml" + calibration_data = [np.zeros((1, 3, 224, 224))] + + with patch("openvino.Core", side_effect=RuntimeError("OV error")): + result = converter.quantize_model( + model_path=model_path, + calibration_data=calibration_data, + model_config=sample_model_config, + ) + + assert result == model_path + + +class TestExportToOpenvino: + """Tests for ModelConverter.export_to_openvino.""" + + def 
test_successful_export(self, converter, sample_model_config, tmp_path): + """export_to_openvino exports model to OV format.""" + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + + mock_ov_model = MagicMock() + mock_input = MagicMock() + mock_input.get_names.return_value = {"input"} + mock_ov_model.input.return_value = mock_input + mock_ov_model.inputs = [mock_input] + mock_ov_model.outputs = [MagicMock()] + mock_ov_model.get_rt_info.return_value = MagicMock(value={"model_type": "Classification"}) + + output_path = converter.output_dir / "test_model" + + with ( + patch("openvino.convert_model", return_value=mock_ov_model), + patch("openvino.save_model"), + patch.object(Path, "exists", return_value=True), + patch("shutil.copy2"), + patch.object(converter, "copy_readme"), + ): + fp16_path, _fp32_path = converter.export_to_openvino( + model=mock_model, + input_shape=[1, 3, 224, 224], + output_path=output_path, + model_config=sample_model_config, + input_names=["input"], + output_names=["result"], + metadata={("model_info", "model_type"): "Classification"}, + ) + + assert "fp16" in str(fp16_path.parent) or fp16_path.name == "test_model.xml" + + def test_export_failure(self, converter, sample_model_config, tmp_path): + """export_to_openvino raises on conversion failure.""" + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + + output_path = converter.output_dir / "test_model" + + with ( + patch("openvino.convert_model", side_effect=RuntimeError("Conversion failed")), + pytest.raises(RuntimeError, match="Conversion failed"), + ): + converter.export_to_openvino( + model=mock_model, + input_shape=[1, 3, 224, 224], + output_path=output_path, + model_config=sample_model_config, + ) + + +class TestPrepareModelForExport: + """Tests for ModelConverter._prepare_model_for_export.""" + + def test_with_adapter(self, converter): + """_prepare_model_for_export applies adapter for known model type.""" + mock_model = 
MagicMock(spec=nn.Module) + config = {"model_type": "MaskRCNN"} + + with patch("model_converter.cli.get_adapter") as mock_get_adapter: + mock_adapted = MagicMock() + mock_get_adapter.return_value = mock_adapted + result = converter._prepare_model_for_export(mock_model, config) + + assert result is mock_adapted + + def test_without_adapter(self, converter): + """_prepare_model_for_export returns model unchanged for no adapter.""" + mock_model = MagicMock(spec=nn.Module) + config = {"model_type": "Classification"} + + with patch("model_converter.cli.get_adapter", return_value=mock_model): + result = converter._prepare_model_for_export(mock_model, config) + + assert result is mock_model + + +class TestCreateExampleInput: + """Tests for ModelConverter._create_example_input.""" + + def test_maskrcnn_input(self, converter): + """_create_example_input uses rand for maskrcnn.""" + config = {"model_type": "MaskRCNN"} + result = converter._create_example_input([1, 3, 224, 224], config) + assert result.shape == (1, 3, 224, 224) + assert result.min() >= 0 # rand produces [0, 1) + + def test_default_input(self, converter): + """_create_example_input uses randn for non-maskrcnn.""" + config = {"model_type": "Classification"} + result = converter._create_example_input([1, 3, 224, 224], config) + assert result.shape == (1, 3, 224, 224) + + +class TestPostprocessOpenvinoModel: + """Tests for ModelConverter._postprocess_openvino_model.""" + + def test_set_input_names(self, converter, mock_ov_model): + """_postprocess_openvino_model sets input tensor names.""" + converter._postprocess_openvino_model( + mock_ov_model, + input_names=["images"], + ) + mock_ov_model.input(0).set_names.assert_called_with({"images"}) + + def test_set_output_names(self, converter, mock_ov_model): + """_postprocess_openvino_model sets output tensor names.""" + converter._postprocess_openvino_model( + mock_ov_model, + output_names=["predictions"], + ) + 
mock_ov_model.output(0).set_names.assert_called_with({"predictions"}) + + def test_set_metadata(self, converter, mock_ov_model): + """_postprocess_openvino_model adds metadata.""" + metadata = { + ("model_info", "model_type"): "Classification", + ("model_info", "labels"): "cat dog", + } + converter._postprocess_openvino_model(mock_ov_model, metadata=metadata) + assert mock_ov_model.set_rt_info.call_count == 2 + + def test_no_operations(self, converter, mock_ov_model): + """_postprocess_openvino_model handles None params.""" + result = converter._postprocess_openvino_model(mock_ov_model) + assert result is mock_ov_model + + +class TestLoadModelFromConfig: + """Tests for ModelConverter._load_model_from_config.""" + + def test_huggingface_path(self, converter): + """_load_model_from_config loads from HuggingFace.""" + config = { + "huggingface_repo": "timm/resnet50", + "huggingface_revision": "abc123", + "model_library": "timm", + } + mock_model = MagicMock() + with patch.object(converter, "load_huggingface_model", return_value=mock_model): + result = converter._load_model_from_config(config) + assert result is mock_model + + def test_huggingface_missing_revision(self, converter): + """_load_model_from_config raises when HF revision is missing.""" + config = { + "huggingface_repo": "timm/resnet50", + } + with pytest.raises(ValueError, match="huggingface_revision"): + converter._load_model_from_config(config) + + def test_url_path(self, converter, tmp_path): + """_load_model_from_config loads from URL.""" + config = { + "weights_url": "https://example.com/weights.pth", + "model_class_name": "torch.nn.Module", + } + + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + + with ( + patch.object(converter._url_downloader, "download", return_value=tmp_path / "weights.pth"), + patch.object(converter, "load_model_class", return_value=torch.nn.Module), + patch.object(converter, "load_checkpoint", return_value={"model": mock_model}), + 
patch.object(converter, "create_model", return_value=mock_model), + ): + result = converter._load_model_from_config(config) + + assert result is mock_model + + def test_url_path_default_class(self, converter, tmp_path): + """_load_model_from_config uses torch.nn.Module as default class.""" + config = { + "weights_url": "https://example.com/weights.pth", + } + + mock_model = MagicMock(spec=nn.Module) + + with ( + patch.object(converter._url_downloader, "download", return_value=tmp_path / "weights.pth"), + patch.object(converter, "load_model_class", return_value=torch.nn.Module) as mock_load_class, + patch.object(converter, "load_checkpoint", return_value={}), + patch.object(converter, "create_model", return_value=mock_model), + ): + converter._load_model_from_config(config) + + mock_load_class.assert_called_once_with("torch.nn.Module") + + +class TestQuantizeAndCleanup: + """Tests for ModelConverter._quantize_and_cleanup.""" + + def test_with_classification_labels(self, converter, tmp_path, sample_model_config): + """_quantize_and_cleanup runs validation for classification with labels.""" + fp32_path = tmp_path / "model_fp32.xml" + fp32_path.write_text("") + fp32_bin = tmp_path / "model_fp32.bin" + fp32_bin.write_text("weights") + + validation_data = [np.zeros((1, 3, 224, 224))] + validation_labels = [0] + + with ( + patch.object( + converter, + "create_calibration_dataset", + return_value=(validation_data, validation_labels), + ), + patch.object(converter, "quantize_model"), + ): + converter._quantize_and_cleanup( + sample_model_config, + fp32_path, + model_type="Classification", + input_shape=[1, 3, 224, 224], + mean_values="123.675 116.28 103.53", + scale_values="58.395 57.12 57.375", + reverse_input_channels=True, + ) + + # FP32 files should be cleaned up + assert not fp32_path.exists() + assert not fp32_bin.exists() + + def test_without_classification_labels(self, converter, tmp_path, sample_model_config): + """_quantize_and_cleanup skips validation for 
non-classification.""" + fp32_path = tmp_path / "model_fp32.xml" + fp32_path.write_text("") + + config = {**sample_model_config, "labels": None} + validation_data = [np.zeros((1, 3, 224, 224))] + + with ( + patch.object(converter, "create_calibration_dataset", return_value=(validation_data, [])), + patch.object(converter, "quantize_model") as mock_quantize, + ): + converter._quantize_and_cleanup( + config, + fp32_path, + model_type="Detection", + input_shape=[1, 3, 224, 224], + mean_values="123.675 116.28 103.53", + scale_values="58.395 57.12 57.375", + reverse_input_channels=True, + ) + + # Quantize should be called with no validation data/labels + mock_quantize.assert_called_once() + call_kwargs = mock_quantize.call_args[1] + assert call_kwargs["validation_data"] is None + assert call_kwargs["validation_labels"] is None + + def test_empty_calibration_data(self, converter, tmp_path, sample_model_config): + """_quantize_and_cleanup skips quantization when no data.""" + fp32_path = tmp_path / "model_fp32.xml" + fp32_path.write_text("") + + with ( + patch.object(converter, "create_calibration_dataset", return_value=([], [])), + patch.object(converter, "quantize_model") as mock_quantize, + ): + converter._quantize_and_cleanup( + sample_model_config, + fp32_path, + model_type="Classification", + input_shape=[1, 3, 224, 224], + mean_values="123.675 116.28 103.53", + scale_values="58.395 57.12 57.375", + reverse_input_channels=True, + ) + + mock_quantize.assert_not_called() + + def test_cleanup_failure(self, converter, tmp_path, sample_model_config): + """_quantize_and_cleanup handles cleanup failure gracefully.""" + fp32_path = tmp_path / "model_fp32.xml" + fp32_path.write_text("") + + with ( + patch.object(converter, "create_calibration_dataset", return_value=([], [])), + patch.object(Path, "exists", return_value=True), + patch.object(Path, "unlink", side_effect=OSError("Permission denied")), + ): + # Should not raise + converter._quantize_and_cleanup( + 
sample_model_config, + fp32_path, + model_type="Classification", + input_shape=[1, 3, 224, 224], + mean_values="123.675 116.28 103.53", + scale_values="58.395 57.12 57.375", + reverse_input_channels=True, + ) + + +class TestProcessModelConfig: + """Tests for ModelConverter.process_model_config.""" + + def test_already_exists(self, converter, sample_model_config): + """process_model_config skips when both models already exist.""" + # Create existing model files + fp16_dir = converter.output_dir / "test_model-fp16-ov" + fp16_dir.mkdir(parents=True) + (fp16_dir / "test_model.xml").write_text("") + + int8_dir = converter.output_dir / "test_model-int8-ov" + int8_dir.mkdir(parents=True) + (int8_dir / "test_model.xml").write_text("") + + result = converter.process_model_config(sample_model_config) + assert result is True + + def test_missing_license(self, converter): + """process_model_config fails when license is missing.""" + config = {"model_short_name": "test", "license_link": "https://example.com"} + result = converter.process_model_config(config) + assert result is False + + def test_missing_license_link(self, converter): + """process_model_config fails when license_link is missing.""" + config = {"model_short_name": "test", "license": "MIT"} + result = converter.process_model_config(config) + assert result is False + + def test_successful_conversion(self, converter, sample_model_config): + """process_model_config successfully converts a model.""" + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + + fp16_path = converter.output_dir / "test_model-fp16-ov" / "test_model.xml" + fp32_path = converter.output_dir / "test_model-fp16-ov" / "test_model_fp32.xml" + + with ( + patch.object(converter, "_load_model_from_config", return_value=mock_model), + patch.object(converter, "get_labels", return_value="cat dog"), + patch.object(converter, "export_to_openvino", return_value=(fp16_path, fp32_path)), + ): + result = 
converter.process_model_config(sample_model_config) + + assert result is True + + def test_successful_conversion_with_dataset(self, tmp_path, sample_model_config, dataset_dir): + """process_model_config quantizes when dataset is available.""" + conv = ModelConverter( + output_dir=tmp_path / "out", + cache_dir=tmp_path / "cache", + dataset_path=dataset_dir, + ) + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + + fp16_path = conv.output_dir / "test_model-fp16-ov" / "test_model.xml" + fp32_path = conv.output_dir / "test_model-fp16-ov" / "test_model_fp32.xml" + + with ( + patch.object(conv, "_load_model_from_config", return_value=mock_model), + patch.object(conv, "get_labels", return_value="cat dog"), + patch.object(conv, "export_to_openvino", return_value=(fp16_path, fp32_path)), + patch.object(conv, "_quantize_and_cleanup"), + ): + result = conv.process_model_config(sample_model_config) + + assert result is True + + def test_conversion_failure(self, converter, sample_model_config): + """process_model_config returns False on failure.""" + with patch.object(converter, "_load_model_from_config", side_effect=RuntimeError("load failed")): + result = converter.process_model_config(sample_model_config) + + assert result is False + + def test_no_labels_configured(self, converter): + """process_model_config works without labels in config.""" + config = { + "model_short_name": "test_model", + "license": "Apache-2.0", + "license_link": "https://apache.org", + "weights_url": "https://example.com/weights.pth", + "model_class_name": "torch.nn.Module", + "input_shape": [1, 3, 224, 224], + "model_type": "Classification", + } + + mock_model = MagicMock(spec=nn.Module) + fp16_path = converter.output_dir / "test_model-fp16-ov" / "test_model.xml" + fp32_path = converter.output_dir / "test_model-fp16-ov" / "test_model_fp32.xml" + + with ( + patch.object(converter, "_load_model_from_config", return_value=mock_model), + patch.object(converter, 
"export_to_openvino", return_value=(fp16_path, fp32_path)), + ): + result = converter.process_model_config(config) + + assert result is True + + def test_labels_not_found(self, converter, sample_model_config): + """process_model_config handles unknown label set.""" + mock_model = MagicMock(spec=nn.Module) + fp16_path = converter.output_dir / "test_model-fp16-ov" / "test_model.xml" + fp32_path = converter.output_dir / "test_model-fp16-ov" / "test_model_fp32.xml" + + with ( + patch.object(converter, "_load_model_from_config", return_value=mock_model), + patch.object(converter, "get_labels", return_value=None), + patch.object(converter, "export_to_openvino", return_value=(fp16_path, fp32_path)), + ): + result = converter.process_model_config(sample_model_config) + + assert result is True + + def test_metadata_fields(self, converter): + """process_model_config includes optional metadata fields.""" + config = { + "model_short_name": "test_model", + "license": "Apache-2.0", + "license_link": "https://apache.org", + "weights_url": "https://example.com/weights.pth", + "input_shape": [1, 3, 224, 224], + "model_type": "Detection", + "confidence_threshold": "0.5", + "iou_threshold": "0.45", + "resize_type": "standard", + } + + mock_model = MagicMock(spec=nn.Module) + fp16_path = converter.output_dir / "test_model-fp16-ov" / "test_model.xml" + fp32_path = converter.output_dir / "test_model-fp16-ov" / "test_model_fp32.xml" + + with ( + patch.object(converter, "_load_model_from_config", return_value=mock_model), + patch.object(converter, "export_to_openvino", return_value=(fp16_path, fp32_path)) as mock_export, + ): + converter.process_model_config(config) + + # Check metadata was passed + call_kwargs = mock_export.call_args[1] + metadata = call_kwargs["metadata"] + assert ("model_info", "confidence_threshold") in metadata + assert ("model_info", "iou_threshold") in metadata + + +class TestMetadataValue: + """Tests for ModelConverter._metadata_value.""" + + def 
test_string(self): + """_metadata_value converts string to string.""" + assert ModelConverter._metadata_value("hello") == "hello" + + def test_integer(self): + """_metadata_value converts int to string.""" + assert ModelConverter._metadata_value(42) == "42" + + def test_float(self): + """_metadata_value converts float to string.""" + assert ModelConverter._metadata_value(0.5) == "0.5" + + def test_boolean(self): + """_metadata_value converts bool to string.""" + assert ModelConverter._metadata_value(True) == "True" + + def test_list(self): + """_metadata_value joins list with spaces.""" + assert ModelConverter._metadata_value([1, 2, 3]) == "1 2 3" + + def test_tuple(self): + """_metadata_value joins tuple with spaces.""" + assert ModelConverter._metadata_value(("a", "b")) == "a b" + + +class TestProcessConfigFile: + """Tests for ModelConverter.process_config_file.""" + + def test_multiple_models(self, converter, tmp_path): + """process_config_file processes multiple models.""" + config = { + "models": [ + { + "model_short_name": "model1", + "license": "MIT", + "license_link": "https://mit.edu", + "weights_url": "https://example.com/1.pth", + }, + { + "model_short_name": "model2", + "license": "MIT", + "license_link": "https://mit.edu", + "weights_url": "https://example.com/2.pth", + }, + ], + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + with patch.object(converter, "process_model_config", return_value=True) as mock_process: + successful, failed = converter.process_config_file(config_path) + + assert successful == 2 + assert failed == 0 + assert mock_process.call_count == 2 + + def test_filter_match(self, converter, tmp_path): + """process_config_file filters to specific model.""" + config = { + "models": [ + {"model_short_name": "model1", "license": "MIT", "license_link": "x"}, + {"model_short_name": "model2", "license": "MIT", "license_link": "x"}, + ], + } + config_path = tmp_path / "config.json" + 
config_path.write_text(json.dumps(config)) + + with patch.object(converter, "process_model_config", return_value=True) as mock_process: + successful, _failed = converter.process_config_file(config_path, model_filter="model2") + + assert successful == 1 + mock_process.assert_called_once() + + def test_filter_no_match(self, converter, tmp_path): + """process_config_file returns 0,0 when filter doesn't match.""" + config = { + "models": [ + {"model_short_name": "model1"}, + ], + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + successful, failed = converter.process_config_file(config_path, model_filter="nonexistent") + assert successful == 0 + assert failed == 0 + + def test_empty_models(self, converter, tmp_path): + """process_config_file handles empty models list.""" + config = {"models": []} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + successful, failed = converter.process_config_file(config_path) + assert successful == 0 + assert failed == 0 + + def test_invalid_json(self, converter, tmp_path): + """process_config_file raises on invalid JSON.""" + config_path = tmp_path / "bad.json" + config_path.write_text("not valid json {{{") + + with pytest.raises(json.JSONDecodeError): + converter.process_config_file(config_path) + + def test_model_failure(self, converter, tmp_path): + """process_config_file counts failed models.""" + config = { + "models": [ + {"model_short_name": "model1", "license": "MIT", "license_link": "x"}, + ], + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + with patch.object(converter, "process_model_config", return_value=False): + successful, failed = converter.process_config_file(config_path) + + assert successful == 0 + assert failed == 1 + + +class TestListModels: + """Tests for list_models function.""" + + def test_normal_output(self, tmp_path, capsys): + """list_models prints model information.""" + config = { + 
"models": [ + { + "model_short_name": "resnet50", + "model_full_name": "ResNet-50", + "model_type": "Classification", + }, + ], + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + list_models(config_path) + + captured = capsys.readouterr() + assert "resnet50" in captured.out + assert "ResNet-50" in captured.out + assert "Classification" in captured.out + + def test_file_not_found(self, tmp_path, capsys): + """list_models handles missing config file.""" + list_models(tmp_path / "nonexistent.json") + + captured = capsys.readouterr() + assert "Error" in captured.err + + def test_empty_models(self, tmp_path, capsys): + """list_models handles empty models list.""" + config = {"models": []} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + list_models(config_path) + + captured = capsys.readouterr() + assert "No models found" in captured.out + + def test_invalid_json(self, tmp_path, capsys): + """list_models handles invalid JSON.""" + config_path = tmp_path / "bad.json" + config_path.write_text("not json") + + list_models(config_path) + + captured = capsys.readouterr() + assert "Error" in captured.err + + +class TestMain: + """Tests for main() CLI entry point.""" + + def test_missing_config_file(self, tmp_path, monkeypatch): + """main returns 1 when config file doesn't exist.""" + monkeypatch.setattr(sys, "argv", ["model_converter", str(tmp_path / "nonexistent.json")]) + result = main() + assert result == 1 + + def test_list_flag(self, tmp_path, monkeypatch, capsys): + """main --list flag lists models and exits.""" + config = {"models": [{"model_short_name": "test", "model_type": "cls"}]} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr(sys, "argv", ["model_converter", str(config_path), "--list"]) + result = main() + assert result == 0 + + captured = capsys.readouterr() + assert "test" in captured.out + + def 
test_successful_run(self, tmp_path, monkeypatch): + """main runs conversion successfully.""" + config = {"models": [{"model_short_name": "m1", "license": "MIT", "license_link": "x"}]} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + sys, + "argv", + [ + "model_converter", + str(config_path), + "-o", + str(tmp_path / "output"), + "-c", + str(tmp_path / "cache"), + "-d", + str(tmp_path / "dataset"), + ], + ) + + with patch.object(ModelConverter, "process_config_file", return_value=(1, 0)): + result = main() + + assert result == 0 + + def test_failed_run(self, tmp_path, monkeypatch): + """main returns 1 when models fail.""" + config = {"models": [{"model_short_name": "m1"}]} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + sys, + "argv", + [ + "model_converter", + str(config_path), + "-o", + str(tmp_path / "output"), + "-c", + str(tmp_path / "cache"), + ], + ) + + with patch.object(ModelConverter, "process_config_file", return_value=(0, 1)): + result = main() + + assert result == 1 + + def test_verbose_flag(self, tmp_path, monkeypatch): + """main enables verbose logging with -v flag.""" + config = {"models": []} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + sys, + "argv", + [ + "model_converter", + str(config_path), + "-o", + str(tmp_path / "output"), + "-c", + str(tmp_path / "cache"), + "-v", + ], + ) + + with patch.object(ModelConverter, "process_config_file", return_value=(0, 0)): + result = main() + + assert result == 0 + + def test_model_filter(self, tmp_path, monkeypatch): + """main passes --model filter to process_config_file.""" + config = {"models": [{"model_short_name": "target"}]} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + sys, + "argv", + [ + "model_converter", + str(config_path), + "-o", + 
str(tmp_path / "output"), + "-c", + str(tmp_path / "cache"), + "--model", + "target", + ], + ) + + with patch.object(ModelConverter, "process_config_file", return_value=(1, 0)) as mock_process: + main() + + mock_process.assert_called_once_with(config_path=config_path, model_filter="target") + + def test_exception_during_processing(self, tmp_path, monkeypatch): + """main returns 1 on unhandled exception.""" + config = {"models": [{"model_short_name": "m1"}]} + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + sys, + "argv", + [ + "model_converter", + str(config_path), + "-o", + str(tmp_path / "output"), + "-c", + str(tmp_path / "cache"), + ], + ) + + with patch.object(ModelConverter, "process_config_file", side_effect=ValueError("bad config")): + result = main() + + assert result == 1 diff --git a/model_converter/tests/unit/test_downloaders.py b/model_converter/tests/unit/test_downloaders.py new file mode 100644 index 00000000..7f329b18 --- /dev/null +++ b/model_converter/tests/unit/test_downloaders.py @@ -0,0 +1,171 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Tests for model_converter.downloaders module.""" + +from pathlib import Path +from unittest.mock import patch + +import pytest +from model_converter.downloaders import HuggingFaceDownloader, URLDownloader +from model_converter.downloaders.base import BaseDownloader + + +class TestBaseDownloader: + """Tests for BaseDownloader.""" + + def test_creates_cache_dir(self, tmp_path): + """BaseDownloader creates cache directory on init.""" + cache_dir = tmp_path / "new_cache" / "nested" + downloader = BaseDownloader(cache_dir=cache_dir) + assert downloader.cache_dir.exists() + assert downloader.cache_dir == cache_dir + + def test_existing_cache_dir(self, tmp_path): + """BaseDownloader handles existing cache directory.""" + cache_dir = tmp_path / "existing_cache" + cache_dir.mkdir() + downloader = 
BaseDownloader(cache_dir=cache_dir) + assert downloader.cache_dir.exists() + + def test_cache_dir_is_path(self, tmp_path): + """BaseDownloader converts cache_dir to Path.""" + cache_dir = tmp_path / "cache" + downloader = BaseDownloader(cache_dir=cache_dir) + assert isinstance(downloader.cache_dir, Path) + + +class TestURLDownloader: + """Tests for URLDownloader.""" + + def test_cache_hit(self, tmp_path): + """URLDownloader returns cached file if it exists.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + cached_file = cache_dir / "weights.pth" + cached_file.write_text("dummy weights") + + downloader = URLDownloader(cache_dir=cache_dir) + result = downloader.download("https://example.com/weights.pth") + + assert result == cached_file + + def test_download_success(self, tmp_path): + """URLDownloader downloads file and returns path.""" + cache_dir = tmp_path / "cache" + downloader = URLDownloader(cache_dir=cache_dir) + + with patch("urllib.request.urlretrieve") as mock_urlretrieve: + mock_urlretrieve.side_effect = lambda url, path: Path(path).write_text("data") + result = downloader.download("https://example.com/model_weights.pth") + + assert result == cache_dir / "model_weights.pth" + mock_urlretrieve.assert_called_once_with( + "https://example.com/model_weights.pth", + cache_dir / "model_weights.pth", + ) + + def test_download_with_custom_filename(self, tmp_path): + """URLDownloader uses custom filename when provided.""" + cache_dir = tmp_path / "cache" + downloader = URLDownloader(cache_dir=cache_dir) + + with patch("urllib.request.urlretrieve") as mock_urlretrieve: + mock_urlretrieve.side_effect = lambda url, path: Path(path).write_text("data") + result = downloader.download("https://example.com/v1/download", filename="custom.pth") + + assert result == cache_dir / "custom.pth" + + def test_download_failure(self, tmp_path): + """URLDownloader raises on download failure.""" + cache_dir = tmp_path / "cache" + downloader = URLDownloader(cache_dir=cache_dir) 
+ + with patch("urllib.request.urlretrieve") as mock_urlretrieve: + mock_urlretrieve.side_effect = OSError("Connection refused") + with pytest.raises(OSError, match="Connection refused"): + downloader.download("https://example.com/weights.pth") + + def test_filename_extracted_from_url(self, tmp_path): + """URLDownloader extracts filename from URL when not provided.""" + cache_dir = tmp_path / "cache" + downloader = URLDownloader(cache_dir=cache_dir) + + with patch("urllib.request.urlretrieve") as mock_urlretrieve: + mock_urlretrieve.side_effect = lambda url, path: Path(path).write_text("data") + result = downloader.download("https://example.com/path/to/resnet50.pth") + + assert result.name == "resnet50.pth" + + +class TestHuggingFaceDownloader: + """Tests for HuggingFaceDownloader.""" + + def test_download_single_file(self, tmp_path): + """HuggingFaceDownloader downloads a single file.""" + cache_dir = tmp_path / "cache" + downloader = HuggingFaceDownloader(cache_dir=cache_dir) + + with patch("model_converter.downloaders.huggingface.hf_hub_download") as mock_hf: + mock_hf.return_value = str(cache_dir / "model.safetensors") + result = downloader.download( + repo_id="timm/resnet50", + revision="abc123", + filename="model.safetensors", + ) + + assert result == cache_dir / "model.safetensors" + mock_hf.assert_called_once_with( + repo_id="timm/resnet50", + revision="abc123", + filename="model.safetensors", + cache_dir=cache_dir, + ) + + def test_download_snapshot(self, tmp_path): + """HuggingFaceDownloader downloads full repository.""" + cache_dir = tmp_path / "cache" + downloader = HuggingFaceDownloader(cache_dir=cache_dir) + + with patch("model_converter.downloaders.huggingface.snapshot_download") as mock_snap: + mock_snap.return_value = str(cache_dir / "repo_snapshot") + result = downloader.download( + repo_id="timm/resnet50", + revision="abc123", + ) + + assert result == cache_dir / "repo_snapshot" + mock_snap.assert_called_once_with( + repo_id="timm/resnet50", + 
revision="abc123", + cache_dir=cache_dir, + ) + + def test_download_failure(self, tmp_path): + """HuggingFaceDownloader raises on download failure.""" + cache_dir = tmp_path / "cache" + downloader = HuggingFaceDownloader(cache_dir=cache_dir) + + with patch("model_converter.downloaders.huggingface.hf_hub_download") as mock_hf: + mock_hf.side_effect = Exception("Repository not found") + with pytest.raises(Exception, match="Repository not found"): + downloader.download( + repo_id="nonexistent/model", + revision="abc123", + filename="weights.bin", + ) + + def test_download_snapshot_failure(self, tmp_path): + """HuggingFaceDownloader raises on snapshot download failure.""" + cache_dir = tmp_path / "cache" + downloader = HuggingFaceDownloader(cache_dir=cache_dir) + + with patch("model_converter.downloaders.huggingface.snapshot_download") as mock_snap: + mock_snap.side_effect = Exception("Network error") + with pytest.raises(Exception, match="Network error"): + downloader.download( + repo_id="timm/resnet50", + revision="abc123", + ) diff --git a/model_converter/tests/unit/test_entrypoints.py b/model_converter/tests/unit/test_entrypoints.py new file mode 100644 index 00000000..d154d208 --- /dev/null +++ b/model_converter/tests/unit/test_entrypoints.py @@ -0,0 +1,70 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Tests for model_converter entry points.""" + +import runpy +import sys +from unittest.mock import patch + + +class TestInitExports: + """Tests for model_converter.__init__ exports.""" + + def test_exports_model_converter(self): + """__init__ exports ModelConverter class.""" + from model_converter import ModelConverter + + assert ModelConverter is not None + + def test_exports_list_models(self): + """__init__ exports list_models function.""" + from model_converter import list_models + + assert callable(list_models) + + def test_exports_main(self): + """__init__ exports main function.""" + from model_converter import main 
+ + assert callable(main) + + def test_all_attribute(self): + """__init__ defines __all__ correctly.""" + import model_converter + + assert "ModelConverter" in model_converter.__all__ + assert "list_models" in model_converter.__all__ + assert "main" in model_converter.__all__ + + +class TestMainModule: + """Tests for model_converter.__main__.""" + + def test_calls_main(self): + """__main__ calls cli.main() and sys.exit.""" + with ( + patch("model_converter.cli.main", return_value=0) as mock_main, + patch.object(sys, "exit") as mock_exit, + ): + runpy.run_module("model_converter", run_name="__main__", alter_sys=True) + + mock_main.assert_called_once() + mock_exit.assert_called_once_with(0) + + +class TestModelConverterScript: + """Tests for model_converter.model_converter legacy entry point.""" + + def test_calls_main_when_run(self): + """model_converter.py calls main() when run as __main__.""" + with ( + patch("model_converter.cli.main", return_value=0) as mock_main, + patch.object(sys, "exit") as mock_exit, + ): + runpy.run_module("model_converter.model_converter", run_name="__main__", alter_sys=True) + + mock_main.assert_called_once() + mock_exit.assert_called_once_with(0) diff --git a/model_converter/tests/unit/test_yolo.py b/model_converter/tests/unit/test_yolo.py new file mode 100644 index 00000000..b6d48ccc --- /dev/null +++ b/model_converter/tests/unit/test_yolo.py @@ -0,0 +1,172 @@ +# +# Copyright (C) 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Tests for model_converter.yolo module.""" + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + + +class TestCopyReadmeTemplate: + """Tests for copy_readme_template function.""" + + def test_copies_and_replaces_placeholder(self, tmp_path): + """copy_readme_template replaces <> in template.""" + from model_converter.yolo.yolo import copy_readme_template + + # Create a template in the actual templates dir (we need to mock it) + template_content = "# YOLO11<> 
Model\nSize: <>" + dest_dir = tmp_path / "output" + dest_dir.mkdir() + + with patch.object(Path, "read_text", return_value=template_content): + copy_readme_template("README-yolo-fp16.md", dest_dir, "n") + + readme = dest_dir / "README.md" + assert readme.exists() + content = readme.read_text() + assert "YOLO11n Model" in content + assert "Size: n" in content + assert "<>" not in content + + +class TestUpdateModelTypeInXml: + """Tests for update_model_type_in_xml function.""" + + def test_updates_model_type(self, tmp_path): + """update_model_type_in_xml updates model_type value in XML.""" + from model_converter.yolo.yolo import update_model_type_in_xml + + # Create a valid OpenVINO XML structure + xml_content = """ + + + + + + + +""" + xml_path = tmp_path / "model.xml" + xml_path.write_text(xml_content) + + update_model_type_in_xml(xml_path, "YOLO11") + + # Verify XML was updated + from defusedxml.ElementTree import parse + + tree = parse(xml_path) + root = tree.getroot() + model_type_elem = root.find(".//rt_info/model_info/model_type") + assert model_type_elem is not None + assert model_type_elem.get("value") == "YOLO11" + + # Verify config.json was created + config_json = tmp_path / "config.json" + assert config_json.exists() + config = json.loads(config_json.read_text()) + assert config["model_type"] == "YOLO11" + assert config["labels"] == "person car" + + def test_handles_parse_error(self, tmp_path, capsys): + """update_model_type_in_xml handles invalid XML gracefully.""" + from model_converter.yolo.yolo import update_model_type_in_xml + + xml_path = tmp_path / "bad.xml" + xml_path.write_text("not valid xml <<<>>>") + + # Should not raise + update_model_type_in_xml(xml_path, "YOLO11") + + captured = capsys.readouterr() + assert "Failed to update" in captured.out + + def test_handles_file_not_found(self, tmp_path, capsys): + """update_model_type_in_xml handles missing file gracefully.""" + from model_converter.yolo.yolo import update_model_type_in_xml + + 
xml_path = tmp_path / "nonexistent.xml" + + update_model_type_in_xml(xml_path, "YOLO11") + + captured = capsys.readouterr() + assert "Failed to update" in captured.out + + +class TestConvertYoloModels: + """Tests for convert_yolo_models function.""" + + @patch("model_converter.yolo.yolo.copy_readme_template") + @patch("model_converter.yolo.yolo.update_model_type_in_xml") + @patch("model_converter.yolo.yolo.YOLO") + @patch("shutil.rmtree") + def test_convert_single_model(self, mock_rmtree, mock_yolo_class, mock_update_xml, mock_copy_readme, tmp_path): + """convert_yolo_models converts a single YOLO variant.""" + from model_converter.yolo.yolo import convert_yolo_models + + mock_model = MagicMock() + mock_yolo_class.return_value = mock_model + + # Mock that the output folders exist after export + with ( + patch("model_converter.yolo.yolo.Path"), + patch.object(Path, "exists", return_value=False), + ): + convert_yolo_models(["yolo11n"]) + + mock_yolo_class.assert_called_once_with("yolo11n.pt") + assert mock_model.export.call_count == 2 + + @patch("model_converter.yolo.yolo.copy_readme_template") + @patch("model_converter.yolo.yolo.update_model_type_in_xml") + @patch("model_converter.yolo.yolo.YOLO") + def test_convert_with_existing_output( + self, + mock_yolo_class, + mock_update_xml, + mock_copy_readme, + tmp_path, + monkeypatch, + ): + """convert_yolo_models handles existing output directories.""" + from model_converter.yolo.yolo import convert_yolo_models + + monkeypatch.chdir(tmp_path) + + mock_model = MagicMock() + mock_yolo_class.return_value = mock_model + + # Create the output dirs that YOLO export would create + fp16_dir = tmp_path / "yolo11n_openvino_model" + fp16_dir.mkdir() + (fp16_dir / "yolo11n.xml").write_text("") + + int8_dir = tmp_path / "yolo11n_int8_openvino_model" + int8_dir.mkdir() + (int8_dir / "yolo11n.xml").write_text("") + + # Create target dirs that already exist (to test rmtree path) + (tmp_path / "YOLO11n-fp16-ov").mkdir() + (tmp_path / 
"YOLO11n-int8-ov").mkdir() + + with ( + patch("model_converter.yolo.yolo.update_model_type_in_xml"), + patch("model_converter.yolo.yolo.copy_readme_template"), + ): + convert_yolo_models(["yolo11n"]) + + +class TestYoloMain: + """Tests for yolo main function.""" + + @patch("model_converter.yolo.yolo.convert_yolo_models") + def test_main_returns_zero(self, mock_convert): + """main() calls convert_yolo_models and returns 0.""" + from model_converter.yolo.yolo import main + + result = main() + assert result == 0 + mock_convert.assert_called_once_with()