|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import argparse |
| 4 | +import ast |
| 5 | +import json |
| 6 | +from pathlib import Path |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +import cv2 |
| 10 | +import numpy as np |
| 11 | +import onnxruntime as ort |
| 12 | +from PIL import Image |
| 13 | + |
| 14 | + |
| 15 | +def _safe_literal(value: Any) -> Any: |
| 16 | + if not isinstance(value, str): |
| 17 | + return value |
| 18 | + try: |
| 19 | + return ast.literal_eval(value) |
| 20 | + except Exception: |
| 21 | + return value |
| 22 | + |
| 23 | + |
| 24 | +def _normalize_names(names_value: Any) -> list[str]: |
| 25 | + if isinstance(names_value, dict): |
| 26 | + pairs: list[tuple[int, str]] = [] |
| 27 | + for key, value in names_value.items(): |
| 28 | + try: |
| 29 | + idx = int(key) |
| 30 | + except Exception: |
| 31 | + continue |
| 32 | + pairs.append((idx, str(value))) |
| 33 | + if not pairs: |
| 34 | + return [] |
| 35 | + pairs.sort(key=lambda item: item[0]) |
| 36 | + max_index = pairs[-1][0] |
| 37 | + out = [f"class_{i}" for i in range(max_index + 1)] |
| 38 | + for idx, name in pairs: |
| 39 | + out[idx] = name |
| 40 | + return out |
| 41 | + if isinstance(names_value, (list, tuple)): |
| 42 | + return [str(x) for x in names_value] |
| 43 | + return [] |
| 44 | + |
| 45 | + |
| 46 | +def _parse_imgsz(imgsz_value: Any, input_shape: list[Any]) -> list[int]: |
| 47 | + if isinstance(imgsz_value, (list, tuple)) and len(imgsz_value) == 2: |
| 48 | + try: |
| 49 | + return [int(imgsz_value[0]), int(imgsz_value[1])] |
| 50 | + except Exception: |
| 51 | + pass |
| 52 | + |
| 53 | + if len(input_shape) == 4: |
| 54 | + h = input_shape[2] |
| 55 | + w = input_shape[3] |
| 56 | + if isinstance(h, int) and isinstance(w, int): |
| 57 | + return [h, w] |
| 58 | + return [224, 224] |
| 59 | + |
| 60 | + |
| 61 | +def _bgr_to_chw_float01(arr_bgr: np.ndarray) -> np.ndarray: |
| 62 | + arr_rgb = cv2.cvtColor(arr_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 |
| 63 | + return np.transpose(arr_rgb, (2, 0, 1))[None] |
| 64 | + |
| 65 | + |
| 66 | +def preprocess_cpp_style(arr_bgr: np.ndarray, target_h: int, target_w: int) -> np.ndarray: |
| 67 | + h, w = arr_bgr.shape[:2] |
| 68 | + scale = max(target_w / float(w), target_h / float(h)) |
| 69 | + resized_w = max(target_w, int(np.floor(w * scale))) |
| 70 | + resized_h = max(target_h, int(np.floor(h * scale))) |
| 71 | + interpolation = cv2.INTER_AREA if resized_w < w or resized_h < h else cv2.INTER_LINEAR |
| 72 | + resized = cv2.resize(arr_bgr, (resized_w, resized_h), interpolation=interpolation) |
| 73 | + |
| 74 | + crop_x = max(0, int(np.rint((resized_w - target_w) / 2.0))) |
| 75 | + crop_y = max(0, int(np.rint((resized_h - target_h) / 2.0))) |
| 76 | + cropped = resized[crop_y : crop_y + target_h, crop_x : crop_x + target_w] |
| 77 | + return _bgr_to_chw_float01(cropped) |
| 78 | + |
| 79 | + |
| 80 | +def preprocess_ultralytics_reference(arr_bgr: np.ndarray, target_h: int, target_w: int) -> np.ndarray: |
| 81 | + arr_rgb = cv2.cvtColor(arr_bgr, cv2.COLOR_BGR2RGB) |
| 82 | + img = Image.fromarray(arr_rgb) |
| 83 | + src_w, src_h = img.size |
| 84 | + |
| 85 | + if src_w <= src_h: |
| 86 | + resized_w = target_w |
| 87 | + resized_h = max(1, int(target_w * src_h / src_w)) |
| 88 | + else: |
| 89 | + resized_h = target_h |
| 90 | + resized_w = max(1, int(target_h * src_w / src_h)) |
| 91 | + |
| 92 | + img = img.resize((resized_w, resized_h), resample=Image.BILINEAR) |
| 93 | + crop_x = int(round((resized_w - target_w) / 2.0)) |
| 94 | + crop_y = int(round((resized_h - target_h) / 2.0)) |
| 95 | + img = img.crop((crop_x, crop_y, crop_x + target_w, crop_y + target_h)) |
| 96 | + |
| 97 | + arr = np.asarray(img, dtype=np.float32) / 255.0 |
| 98 | + return np.transpose(arr, (2, 0, 1))[None] |
| 99 | + |
| 100 | + |
| 101 | +def run_probe(session: ort.InferenceSession, input_name: str, tensor: np.ndarray) -> dict[str, Any]: |
| 102 | + output = session.run(None, {input_name: tensor})[0] |
| 103 | + logits = np.asarray(output).reshape(-1).astype(np.float64) |
| 104 | + top1_idx = int(np.argmax(logits)) |
| 105 | + return { |
| 106 | + "shape": list(np.asarray(output).shape), |
| 107 | + "sum": float(logits.sum()), |
| 108 | + "min": float(logits.min()), |
| 109 | + "max": float(logits.max()), |
| 110 | + "top1_index": top1_idx, |
| 111 | + "top1_score": float(logits[top1_idx]), |
| 112 | + } |
| 113 | + |
| 114 | + |
| 115 | +def inspect_model(model_path: Path, image_path: Path | None = None) -> dict[str, Any]: |
| 116 | + session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) |
| 117 | + input0 = session.get_inputs()[0] |
| 118 | + output0 = session.get_outputs()[0] |
| 119 | + metadata_raw = dict(session.get_modelmeta().custom_metadata_map or {}) |
| 120 | + metadata = {k: _safe_literal(v) for k, v in metadata_raw.items()} |
| 121 | + names = _normalize_names(metadata.get("names")) |
| 122 | + imgsz = _parse_imgsz(metadata.get("imgsz"), list(input0.shape)) |
| 123 | + |
| 124 | + result: dict[str, Any] = { |
| 125 | + "model_path": str(model_path.resolve()), |
| 126 | + "io": { |
| 127 | + "input_name": input0.name, |
| 128 | + "input_shape": list(input0.shape), |
| 129 | + "input_type": input0.type, |
| 130 | + "output_name": output0.name, |
| 131 | + "output_shape": list(output0.shape), |
| 132 | + "output_type": output0.type, |
| 133 | + }, |
| 134 | + "metadata_raw": metadata_raw, |
| 135 | + "metadata_parsed": metadata, |
| 136 | + "cpp_recommended_config": { |
| 137 | + "task": metadata.get("task", "classify"), |
| 138 | + "imgsz_hw": imgsz, |
| 139 | + "layout": "NCHW", |
| 140 | + "color_order": "RGB", |
| 141 | + "pixel_range": "[0, 1]", |
| 142 | + "normalize_mean": [0.0, 0.0, 0.0], |
| 143 | + "normalize_std": [1.0, 1.0, 1.0], |
| 144 | + "resize_rule": "short_edge_to_target_then_center_crop", |
| 145 | + "resize_long_edge_rounding": "floor(int(target * long / short))", |
| 146 | + "center_crop_rounding": "round((resized - target) / 2)", |
| 147 | + "interpolation_hint": "PIL.BILINEAR in YOLO; C++ approximate: INTER_AREA when downsample else INTER_LINEAR", |
| 148 | + "softmax_hint": "Do not add softmax again if model output already sums close to 1.", |
| 149 | + "class_names": names, |
| 150 | + }, |
| 151 | + } |
| 152 | + |
| 153 | + if image_path is not None: |
| 154 | + buf = np.fromfile(str(image_path), dtype=np.uint8) |
| 155 | + arr_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR) |
| 156 | + if arr_bgr is None: |
| 157 | + raise RuntimeError(f"Failed to read image: {image_path}") |
| 158 | + |
| 159 | + target_h, target_w = imgsz |
| 160 | + cpp_tensor = preprocess_cpp_style(arr_bgr, target_h, target_w) |
| 161 | + yolo_tensor = preprocess_ultralytics_reference(arr_bgr, target_h, target_w) |
| 162 | + probe_cpp = run_probe(session, input0.name, cpp_tensor) |
| 163 | + probe_yolo = run_probe(session, input0.name, yolo_tensor) |
| 164 | + |
| 165 | + result["probe"] = { |
| 166 | + "image_path": str(image_path.resolve()), |
| 167 | + "cpp_style": probe_cpp, |
| 168 | + "ultralytics_reference_style": probe_yolo, |
| 169 | + "top1_same": probe_cpp["top1_index"] == probe_yolo["top1_index"], |
| 170 | + "top1_score_abs_diff": abs(probe_cpp["top1_score"] - probe_yolo["top1_score"]), |
| 171 | + } |
| 172 | + |
| 173 | + return result |
| 174 | + |
| 175 | + |
| 176 | +def main() -> None: |
| 177 | + parser = argparse.ArgumentParser( |
| 178 | + description="Inspect Ultralytics ONNX metadata and export C++ inference parameters." |
| 179 | + ) |
| 180 | + parser.add_argument( |
| 181 | + "--model", |
| 182 | + type=Path, |
| 183 | + default=Path("models/cat_vs_dog/best.onnx"), |
| 184 | + help="Path to ONNX model.", |
| 185 | + ) |
| 186 | + parser.add_argument( |
| 187 | + "--image", |
| 188 | + type=Path, |
| 189 | + default=None, |
| 190 | + help="Optional image path for probe inference comparison.", |
| 191 | + ) |
| 192 | + parser.add_argument( |
| 193 | + "--out", |
| 194 | + type=Path, |
| 195 | + default=None, |
| 196 | + help="Output JSON path (default: <model>.infer_params.json).", |
| 197 | + ) |
| 198 | + parser.add_argument( |
| 199 | + "--print-only", |
| 200 | + action="store_true", |
| 201 | + help="Only print JSON, do not write file.", |
| 202 | + ) |
| 203 | + args = parser.parse_args() |
| 204 | + |
| 205 | + payload = inspect_model(args.model, args.image) |
| 206 | + output_text = json.dumps(payload, ensure_ascii=False, indent=2) |
| 207 | + print(output_text) |
| 208 | + |
| 209 | + if args.print_only: |
| 210 | + return |
| 211 | + |
| 212 | + out_path = args.out or args.model.with_suffix(".infer_params.json") |
| 213 | + out_path.parent.mkdir(parents=True, exist_ok=True) |
| 214 | + out_path.write_text(output_text + "\n", encoding="utf-8") |
| 215 | + print(f"\nSaved: {out_path.resolve()}") |
| 216 | + |
| 217 | + |
| 218 | +if __name__ == "__main__": |
| 219 | + main() |
0 commit comments