Commit 1eb2637

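feat(infer): align the YOLO classification inference pipeline and improve C++ confidence consistency
- Adjust the C++ classification preprocessing to follow the default Ultralytics YOLO flow more closely: short-edge resize, rounded center crop, and a matching interpolation strategy (AREA when downsampling, LINEAR otherwise).
- Keep reading the ONNX output directly (no extra softmax), which improved the C++-side confidence scores on the existing models.
- Add the ONNX parameter inspection script `py/inspect_onnx_for_cpp.py`, which exports `task/imgsz/names/args` and related information to `*.infer_params.json` for alignment and troubleshooting.
- Add class-name support on the C++ side: add `modelClassNames`, which parses class names from the ONNX metadata `names` entry and serves as a fallback source.
- Clarify the runtime behavior: C++ does not read `best.infer_params.json`; the file is only for offline verification.
- Update the README with parameter-alignment notes and a script usage example.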
1 parent 74a3881 commit 1eb2637

6 files changed

Lines changed: 509 additions & 12 deletions

README.md

Lines changed: 48 additions & 0 deletions
@@ -117,6 +117,54 @@ cmake --build build
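Click **加载文件夹** (Load Folder) → the program recursively collects the images in all subfolders → click **批量推理全部** (Batch Infer All)
5. Click any item in the result list below to switch back to that image and view its result

## Parameter alignment notes (2026-04-25)

This change mainly aligns the parameters of "Python YOLO inference" with those of "C++ ORT inference". Key points:

- Added the script [`py/inspect_onnx_for_cpp.py`](./py/inspect_onnx_for_cpp.py), which reads metadata (`task/imgsz/names/args`) from the ONNX model and exports a parameter file (for example `best.infer_params.json`).
- Refined the C++ classification preprocessing details (rounding of the short-edge resize, rounding of the center crop, interpolation strategy when downsampling) to follow the default Ultralytics classification flow more closely.
- Added ONNX metadata reading and class-name parsing (`names`) on the C++ side.
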
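The rounding rules in the preprocessing bullet above can be illustrated without any image library. The following is a minimal sketch, assuming the same rules that the "C++ style" path of `py/inspect_onnx_for_cpp.py` applies (floor when scaling, round-half-to-even when centering, `INTER_AREA` only when the resize shrinks the image). The function names `resize_and_crop_geometry` and `picks_area_interpolation` are hypothetical and do not appear in the repository.

```python
import math


def resize_and_crop_geometry(src_w: int, src_h: int, target_w: int, target_h: int):
    """Return (resized_w, resized_h, crop_x, crop_y) for short-edge resize + center crop."""
    # Scale so that both edges cover the target; the short edge lands on the target size.
    scale = max(target_w / src_w, target_h / src_h)
    # Floor the scaled size, but never drop below the target (guards float error).
    resized_w = max(target_w, math.floor(src_w * scale))
    resized_h = max(target_h, math.floor(src_h * scale))
    # Center-crop offsets; Python's round() rounds halves to even, like np.rint.
    crop_x = max(0, round((resized_w - target_w) / 2))
    crop_y = max(0, round((resized_h - target_h) / 2))
    return resized_w, resized_h, crop_x, crop_y


def picks_area_interpolation(src_w: int, src_h: int, resized_w: int, resized_h: int) -> bool:
    """True when the resize shrinks the image (AREA), False otherwise (LINEAR)."""
    return resized_w < src_w or resized_h < src_h


if __name__ == "__main__":
    geometry = resize_and_crop_geometry(640, 480, 224, 224)
    print(geometry)  # (298, 224, 37, 0)
    print(picks_area_interpolation(640, 480, geometry[0], geometry[1]))  # True
```

For a 640x480 source and a 224x224 target, the short edge (height) is scaled to 224, the width becomes 298, and the crop starts 37 pixels from the left.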
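### Current role of `best.infer_params.json`

`best.infer_params.json` is currently a **report file for alignment and troubleshooting**: it shows the model's export parameters so that they can be checked against the C++ implementation.
The current C++ inference flow **does not read** this JSON file automatically.

Usage example for `inspect_onnx_for_cpp.py`:

```bash
# Run from the repository root (cat-vs-dog model as an example).
# In Windows cmd, replace each trailing `\` with `^`.
python py/inspect_onnx_for_cpp.py \
  --model models/cat_vs_dog/best.onnx \
  --image assets/val/dog/dog_21.jpg \
  --out models/cat_vs_dog/best.infer_params.json
```

The script prints the parsed result to the terminal and writes `models/cat_vs_dog/best.infer_params.json`.
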
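Since the report is meant for offline checking, a short script can summarize its `probe` section. This is a minimal sketch, not part of the repository: the helper name `summarize_probe` and the `0.05` threshold are assumptions chosen for illustration, while the keys `probe`, `top1_same`, and `top1_score_abs_diff` are the ones the script writes.

```python
import json


def summarize_probe(report: dict, max_abs_diff: float = 0.05) -> str:
    """Summarize the optional `probe` section of an infer_params report."""
    probe = report.get("probe")
    if probe is None:
        # The script only adds `probe` when --image is given.
        return "no probe section (script was run without --image)"
    diff = probe["top1_score_abs_diff"]
    same = probe["top1_same"]
    # Flag the report when the two preprocessing paths disagree noticeably.
    status = "OK" if same and diff <= max_abs_diff else "CHECK"
    return f"{status}: top1_same={same}, top1_score_abs_diff={diff:.4f}"


if __name__ == "__main__":
    # Values copied from the report committed with this change.
    sample = json.loads(
        '{"probe": {"top1_same": true, "top1_score_abs_diff": 0.1209680438041687}}'
    )
    print(summarize_probe(sample))
```

To check a real report, load it with `json.loads(Path(...).read_text(encoding="utf-8"))` and pass the resulting dictionary to the helper.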
### Current class-name lookup order in C++

After a `.onnx` model is selected, C++ looks for class names in this order:

1. Files next to the model: `labels.txt` first, then `class_names.txt` if that is not found
2. ONNX metadata: read `names`
3. Compatibility fallback: if the model directory is named `cat_vs_dog`, use `cat/dog`
4. If none of the above is available, display `class_N`

Relevant code:

- [`MainWindow.cpp`](./src/MainWindow.cpp): `selectModel()`
- [`OnnxClassifier.cpp`](./src/OnnxClassifier.cpp): `loadModel()`, `modelClassNames()`

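The lookup order above can be mirrored in Python for offline experiments. This is a minimal sketch rather than a port of the C++ code: `resolve_class_names`, `read_label_file`, `names_from_metadata`, and `display_name` are hypothetical names, the one-name-per-line label file format is an assumption, and so is the choice to fall through to the next source when a label file is empty.

```python
from __future__ import annotations

import ast
import tempfile
from pathlib import Path


def read_label_file(path: Path) -> list[str]:
    """Read one class name per line, skipping blank lines (format is an assumption)."""
    lines = path.read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def names_from_metadata(names_value: str) -> list[str]:
    """Parse an Ultralytics-style `names` string such as "{0: 'cat', 1: 'dog'}"."""
    try:
        parsed = ast.literal_eval(names_value)
    except (ValueError, SyntaxError):
        return []
    if not isinstance(parsed, dict):
        return []
    # Assumes contiguous integer keys; sparse indices are not padded here.
    return [str(name) for _, name in sorted(parsed.items())]


def resolve_class_names(model_path: Path, metadata_names: str | None) -> list[str]:
    model_dir = model_path.parent
    # 1. Label files next to the model.
    for filename in ("labels.txt", "class_names.txt"):
        candidate = model_dir / filename
        if candidate.is_file():
            names = read_label_file(candidate)
            if names:
                return names
    # 2. ONNX metadata `names`.
    if metadata_names:
        names = names_from_metadata(metadata_names)
        if names:
            return names
    # 3. Compatibility fallback for the cat_vs_dog directory.
    if model_dir.name == "cat_vs_dog":
        return ["cat", "dog"]
    # 4. Nothing found: the caller falls back to class_N.
    return []


def display_name(names: list[str], index: int) -> str:
    return names[index] if 0 <= index < len(names) else f"class_{index}"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        model = Path(tmp) / "reagent" / "best.onnx"
        model.parent.mkdir()
        names = resolve_class_names(model, "{0: 'cat', 1: 'dog'}")
        print(names, display_name(names, 5))
```

The demo finds no label file, so it uses the metadata string and prints `class_5` for an index that has no name.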
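### Recommendations for the cat-vs-dog model and the reagent model

Keep each of the two models in its own directory, and make sure at least one of the following class-name sources is available for that directory:

- `labels.txt` (recommended)
- or `names` in the ONNX metadata

That way C++ automatically loads the matching class names when switching between the cat-vs-dog model and the reagent model, which avoids mixing up the classes.
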
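One way to produce the recommended `labels.txt` is to derive it from the report that `py/inspect_onnx_for_cpp.py` writes. This is a minimal sketch under stated assumptions: `write_labels_from_report` is a hypothetical helper, and the one-name-per-line layout of `labels.txt` is assumed, since the README does not specify the file format. The keys `cpp_recommended_config.class_names` and `io.output_shape` are the ones present in the report.

```python
import json
import tempfile
from pathlib import Path


def write_labels_from_report(report_path: Path, labels_path: Path) -> list:
    """Write class names from an infer_params report, one per line (assumed format)."""
    report = json.loads(report_path.read_text(encoding="utf-8"))
    names = report["cpp_recommended_config"]["class_names"]
    num_outputs = report["io"]["output_shape"][-1]
    # Guard against mixed-up models: the name count should match the output width.
    # Dynamic shapes may be reported as strings, so only compare integer sizes.
    if isinstance(num_outputs, int) and len(names) != num_outputs:
        raise ValueError(f"{len(names)} names but the model outputs {num_outputs} classes")
    labels_path.write_text("\n".join(names) + "\n", encoding="utf-8")
    return names


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        report_path = Path(tmp) / "best.infer_params.json"
        # Trimmed-down report with the two sections the helper needs.
        report_path.write_text(
            json.dumps(
                {
                    "io": {"output_shape": [1, 2]},
                    "cpp_recommended_config": {"class_names": ["cat", "dog"]},
                }
            ),
            encoding="utf-8",
        )
        labels_path = Path(tmp) / "labels.txt"
        print(write_labels_from_report(report_path, labels_path))
        print(labels_path.read_text(encoding="utf-8"))
```

The length check is a design choice for this sketch: it refuses to write a label file whose name count disagrees with the model's output width, which is the kind of mismatch that separate model directories are meant to prevent.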
## Model source

The models are classification models trained with YOLO and exported by the Python scripts in the [`py`](./py/) directory. Example training script:
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
{
  "model_path": "E:\\YOLO\\python-train-cpp-infer-demo\\models\\cat_vs_dog\\best.onnx",
  "io": {
    "input_name": "images",
    "input_shape": [
      1,
      3,
      224,
      224
    ],
    "input_type": "tensor(float)",
    "output_name": "output0",
    "output_shape": [
      1,
      2
    ],
    "output_type": "tensor(float)"
  },
  "metadata_raw": {
    "date": "2026-04-06T10:48:26.236324",
    "description": "Ultralytics YOLO11n-cls model trained on assets",
    "author": "Ultralytics",
    "version": "8.4.33",
    "task": "classify",
    "license": "AGPL-3.0 License (https://ultralytics.com/license)",
    "docs": "https://docs.ultralytics.com",
    "stride": "1",
    "batch": "1",
    "imgsz": "[224, 224]",
    "names": "{0: 'cat', 1: 'dog'}",
    "args": "{'batch': 1, 'half': False, 'dynamic': False, 'simplify': True, 'opset': None, 'nms': False}",
    "channels": "3",
    "end2end": "False"
  },
  "metadata_parsed": {
    "date": "2026-04-06T10:48:26.236324",
    "description": "Ultralytics YOLO11n-cls model trained on assets",
    "author": "Ultralytics",
    "version": "8.4.33",
    "task": "classify",
    "license": "AGPL-3.0 License (https://ultralytics.com/license)",
    "docs": "https://docs.ultralytics.com",
    "stride": 1,
    "batch": 1,
    "imgsz": [
      224,
      224
    ],
    "names": {
      "0": "cat",
      "1": "dog"
    },
    "args": {
      "batch": 1,
      "half": false,
      "dynamic": false,
      "simplify": true,
      "opset": null,
      "nms": false
    },
    "channels": 3,
    "end2end": false
  },
  "cpp_recommended_config": {
    "task": "classify",
    "imgsz_hw": [
      224,
      224
    ],
    "layout": "NCHW",
    "color_order": "RGB",
    "pixel_range": "[0, 1]",
    "normalize_mean": [
      0.0,
      0.0,
      0.0
    ],
    "normalize_std": [
      1.0,
      1.0,
      1.0
    ],
    "resize_rule": "short_edge_to_target_then_center_crop",
    "resize_long_edge_rounding": "floor(int(target * long / short))",
    "center_crop_rounding": "round((resized - target) / 2)",
    "interpolation_hint": "PIL.BILINEAR in YOLO; C++ approximate: INTER_AREA when downsample else INTER_LINEAR",
    "softmax_hint": "Do not add softmax again if model output already sums close to 1.",
    "class_names": [
      "cat",
      "dog"
    ]
  },
  "probe": {
    "image_path": "E:\\YOLO\\python-train-cpp-infer-demo\\assets\\val\\dog\\dog_21.jpg",
    "cpp_style": {
      "shape": [
        1,
        2
      ],
      "sum": 0.9999999701976776,
      "min": 0.29744085669517517,
      "max": 0.7025591135025024,
      "top1_index": 1,
      "top1_score": 0.7025591135025024
    },
    "ultralytics_reference_style": {
      "shape": [
        1,
        2
      ],
      "sum": 1.0000000298023224,
      "min": 0.41840896010398865,
      "max": 0.5815910696983337,
      "top1_index": 1,
      "top1_score": 0.5815910696983337
    },
    "top1_same": true,
    "top1_score_abs_diff": 0.1209680438041687
  }
}
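The `softmax_hint` in the report can be turned into a small guard. This is a minimal sketch, assuming a caller that wants probabilities regardless of whether the model already ends in a softmax layer. The helper names and the `1e-3` tolerance are assumptions, not part of the repository. The heuristic can misfire on raw logits that happen to lie in [0, 1] and sum to about 1, so treat it as a diagnostic rather than a guarantee.

```python
import math


def looks_like_probabilities(scores, tol: float = 1e-3) -> bool:
    """True when every score lies in [0, 1] and the scores sum to about 1."""
    return all(0.0 <= s <= 1.0 for s in scores) and abs(sum(scores) - 1.0) <= tol


def softmax(scores):
    """Numerically stable softmax over a flat sequence of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def to_probabilities(scores):
    """Return the scores unchanged if they already look like probabilities."""
    scores = list(scores)
    return scores if looks_like_probabilities(scores) else softmax(scores)


if __name__ == "__main__":
    # The two `cpp_style` scores from the probe section of the report.
    model_output = [0.29744085669517517, 0.7025591135025024]
    print(to_probabilities(model_output))  # passed through unchanged
    print(to_probabilities([2.0, 0.0]))    # raw logits get a softmax
```

Applying softmax a second time to scores that are already probabilities flattens them toward a uniform distribution, which is why the report advises against it.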

py/inspect_onnx_for_cpp.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
from __future__ import annotations

import argparse
import ast
import json
from pathlib import Path
from typing import Any

import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image


def _safe_literal(value: Any) -> Any:
    if not isinstance(value, str):
        return value
    try:
        return ast.literal_eval(value)
    except Exception:
        return value


def _normalize_names(names_value: Any) -> list[str]:
    if isinstance(names_value, dict):
        pairs: list[tuple[int, str]] = []
        for key, value in names_value.items():
            try:
                idx = int(key)
            except Exception:
                continue
            pairs.append((idx, str(value)))
        if not pairs:
            return []
        pairs.sort(key=lambda item: item[0])
        max_index = pairs[-1][0]
        out = [f"class_{i}" for i in range(max_index + 1)]
        for idx, name in pairs:
            out[idx] = name
        return out
    if isinstance(names_value, (list, tuple)):
        return [str(x) for x in names_value]
    return []


def _parse_imgsz(imgsz_value: Any, input_shape: list[Any]) -> list[int]:
    if isinstance(imgsz_value, (list, tuple)) and len(imgsz_value) == 2:
        try:
            return [int(imgsz_value[0]), int(imgsz_value[1])]
        except Exception:
            pass

    if len(input_shape) == 4:
        h = input_shape[2]
        w = input_shape[3]
        if isinstance(h, int) and isinstance(w, int):
            return [h, w]
    return [224, 224]


def _bgr_to_chw_float01(arr_bgr: np.ndarray) -> np.ndarray:
    arr_rgb = cv2.cvtColor(arr_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    return np.transpose(arr_rgb, (2, 0, 1))[None]


def preprocess_cpp_style(arr_bgr: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    h, w = arr_bgr.shape[:2]
    scale = max(target_w / float(w), target_h / float(h))
    resized_w = max(target_w, int(np.floor(w * scale)))
    resized_h = max(target_h, int(np.floor(h * scale)))
    interpolation = cv2.INTER_AREA if resized_w < w or resized_h < h else cv2.INTER_LINEAR
    resized = cv2.resize(arr_bgr, (resized_w, resized_h), interpolation=interpolation)

    crop_x = max(0, int(np.rint((resized_w - target_w) / 2.0)))
    crop_y = max(0, int(np.rint((resized_h - target_h) / 2.0)))
    cropped = resized[crop_y : crop_y + target_h, crop_x : crop_x + target_w]
    return _bgr_to_chw_float01(cropped)


def preprocess_ultralytics_reference(arr_bgr: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    arr_rgb = cv2.cvtColor(arr_bgr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(arr_rgb)
    src_w, src_h = img.size

    if src_w <= src_h:
        resized_w = target_w
        resized_h = max(1, int(target_w * src_h / src_w))
    else:
        resized_h = target_h
        resized_w = max(1, int(target_h * src_w / src_h))

    img = img.resize((resized_w, resized_h), resample=Image.BILINEAR)
    crop_x = int(round((resized_w - target_w) / 2.0))
    crop_y = int(round((resized_h - target_h) / 2.0))
    img = img.crop((crop_x, crop_y, crop_x + target_w, crop_y + target_h))

    arr = np.asarray(img, dtype=np.float32) / 255.0
    return np.transpose(arr, (2, 0, 1))[None]


def run_probe(session: ort.InferenceSession, input_name: str, tensor: np.ndarray) -> dict[str, Any]:
    output = session.run(None, {input_name: tensor})[0]
    logits = np.asarray(output).reshape(-1).astype(np.float64)
    top1_idx = int(np.argmax(logits))
    return {
        "shape": list(np.asarray(output).shape),
        "sum": float(logits.sum()),
        "min": float(logits.min()),
        "max": float(logits.max()),
        "top1_index": top1_idx,
        "top1_score": float(logits[top1_idx]),
    }


def inspect_model(model_path: Path, image_path: Path | None = None) -> dict[str, Any]:
    session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
    input0 = session.get_inputs()[0]
    output0 = session.get_outputs()[0]
    metadata_raw = dict(session.get_modelmeta().custom_metadata_map or {})
    metadata = {k: _safe_literal(v) for k, v in metadata_raw.items()}
    names = _normalize_names(metadata.get("names"))
    imgsz = _parse_imgsz(metadata.get("imgsz"), list(input0.shape))

    result: dict[str, Any] = {
        "model_path": str(model_path.resolve()),
        "io": {
            "input_name": input0.name,
            "input_shape": list(input0.shape),
            "input_type": input0.type,
            "output_name": output0.name,
            "output_shape": list(output0.shape),
            "output_type": output0.type,
        },
        "metadata_raw": metadata_raw,
        "metadata_parsed": metadata,
        "cpp_recommended_config": {
            "task": metadata.get("task", "classify"),
            "imgsz_hw": imgsz,
            "layout": "NCHW",
            "color_order": "RGB",
            "pixel_range": "[0, 1]",
            "normalize_mean": [0.0, 0.0, 0.0],
            "normalize_std": [1.0, 1.0, 1.0],
            "resize_rule": "short_edge_to_target_then_center_crop",
            "resize_long_edge_rounding": "floor(int(target * long / short))",
            "center_crop_rounding": "round((resized - target) / 2)",
            "interpolation_hint": "PIL.BILINEAR in YOLO; C++ approximate: INTER_AREA when downsample else INTER_LINEAR",
            "softmax_hint": "Do not add softmax again if model output already sums close to 1.",
            "class_names": names,
        },
    }

    if image_path is not None:
        buf = np.fromfile(str(image_path), dtype=np.uint8)
        arr_bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
        if arr_bgr is None:
            raise RuntimeError(f"Failed to read image: {image_path}")

        target_h, target_w = imgsz
        cpp_tensor = preprocess_cpp_style(arr_bgr, target_h, target_w)
        yolo_tensor = preprocess_ultralytics_reference(arr_bgr, target_h, target_w)
        probe_cpp = run_probe(session, input0.name, cpp_tensor)
        probe_yolo = run_probe(session, input0.name, yolo_tensor)

        result["probe"] = {
            "image_path": str(image_path.resolve()),
            "cpp_style": probe_cpp,
            "ultralytics_reference_style": probe_yolo,
            "top1_same": probe_cpp["top1_index"] == probe_yolo["top1_index"],
            "top1_score_abs_diff": abs(probe_cpp["top1_score"] - probe_yolo["top1_score"]),
        }

    return result


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Inspect Ultralytics ONNX metadata and export C++ inference parameters."
    )
    parser.add_argument(
        "--model",
        type=Path,
        default=Path("models/cat_vs_dog/best.onnx"),
        help="Path to ONNX model.",
    )
    parser.add_argument(
        "--image",
        type=Path,
        default=None,
        help="Optional image path for probe inference comparison.",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output JSON path (default: <model>.infer_params.json).",
    )
    parser.add_argument(
        "--print-only",
        action="store_true",
        help="Only print JSON, do not write file.",
    )
    args = parser.parse_args()

    payload = inspect_model(args.model, args.image)
    output_text = json.dumps(payload, ensure_ascii=False, indent=2)
    print(output_text)

    if args.print_only:
        return

    out_path = args.out or args.model.with_suffix(".infer_params.json")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(output_text + "\n", encoding="utf-8")
    print(f"\nSaved: {out_path.resolve()}")


if __name__ == "__main__":
    main()
