|
2 | 2 | import os |
3 | 3 | import time |
4 | 4 | import subprocess |
5 | | -from argparse import ArgumentParser, BooleanOptionalAction |
| 5 | +from argparse import ArgumentParser, BooleanOptionalAction, RawTextHelpFormatter |
6 | 6 | from typing import Any, Dict, List, Tuple |
7 | 7 | import onnx |
8 | 8 |
|
@@ -50,10 +50,13 @@ def get_torch_dtype_from_command_line_args(dtype: str) -> "torch.dtype": # noqa |
50 | 50 | return torch_dtype[dtype] |
51 | 51 |
|
52 | 52 |
|
53 | | -def get_parser(name: str) -> ArgumentParser: |
| 53 | +def get_parser(name: str, epilog: str = "") -> ArgumentParser: |
54 | 54 | """Creates a default parser for many models.""" |
55 | 55 | parser = ArgumentParser( |
56 | | - prog=name, description=f"""Export command line for model {name!r}.""" |
| 56 | + prog=name, |
| 57 | + description=f"""Export command line for model {name!r}.""", |
| 58 | + epilog=epilog, |
| 59 | + formatter_class=RawTextHelpFormatter, |
57 | 60 | ) |
58 | 61 | parser.add_argument( |
59 | 62 | "-m", |
@@ -110,7 +113,7 @@ def get_parser(name: str) -> ArgumentParser: |
110 | 113 | "-a", |
111 | 114 | "--atol", |
112 | 115 | type=float, |
113 | | - default=1.0, |
| 116 | + default=2.0, |
114 | 117 | help="fails if the maximum discrepancy is above that threshold", |
115 | 118 | ) |
116 | 119 | parser.add_argument( |
@@ -311,7 +314,8 @@ def fprint(s): |
311 | 314 | diff = max_diff(flat_export_expected, small, hist=[0.1, 0.01]) |
312 | 315 | fprint(f"-- discrepancies={diff}") |
313 | 316 | assert diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01, ( |
314 | | - f"absolution tolerance is above {atol} or number of mismatches is above " |
| 317 | + f"absolute error {diff['abs']} is above {atol} or number of " |
| 318 | + f"mismatches ({diff['rep']['>0.1'] / diff['n']}) is above " |
315 | 319 | f"{mismatch01}, dicrepancies={string_diff(diff)}" |
316 | 320 | ) |
317 | 321 |
|
@@ -362,8 +366,9 @@ def fprint(s): |
362 | 366 | assert ( |
363 | 367 | diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01 |
364 | 368 | ), ( |
365 | | - f"absolution tolerance is above {atol} or number of mismatches is " |
366 | | - f"above {mismatch01}, dicrepancies={string_diff(diff)}" |
| 369 | + f"absolute error {diff['abs']} is above {atol} or number " |
| 370 | + f" of mismatches ({diff['rep']['>0.1'] / diff['n']}) " |
| 371 | + f"is above {mismatch01}, dicrepancies={string_diff(diff)}" |
367 | 372 | ) |
368 | 373 | js = string_diff(diff, js=True, ratio=True, inputs=se, **info) |
369 | 374 | fs.write(js) |
|
0 commit comments