|
18 | 18 | import logging |
19 | 19 | from pathlib import Path |
20 | 20 | from typing import Any, List, Tuple, Dict |
21 | | -from itertools import chain |
22 | 21 |
|
23 | 22 | import yaml |
24 | 23 | import numpy as np |
@@ -175,21 +174,29 @@ def compare_values(val1: Any, val2: Any, path: str, atol: float = 1e-12, |
175 | 174 |
|
176 | 175 | ### SPECIAL CASES |
177 | 176 | # Special handling for 'elements' path - normalize to title case and sort |
178 | | - if path.split('.')[-1] == 'elements': |
| 177 | + if path.endswith('elements') and val1 and val2: |
179 | 178 | if isinstance(val1[0], str) and isinstance(val2[0], str): |
| 179 | + # the elements list in a phase, eg. elements: [H, C, O, N, Ne, X] |
180 | 180 | val1_normalized = sorted([v.title() for v in val1]) |
181 | 181 | val2_normalized = sorted([v.title() for v in val2]) |
182 | 182 | if val1_normalized != val2_normalized: |
183 | 183 | differences.append(f"Elements list mismatch at {path}: {val1_normalized} vs {val2_normalized}") |
184 | 184 | return differences |
185 | 185 | if isinstance(val1[0], dict) and isinstance(val2[0], dict): |
186 | | - for d in chain(val1, val2): |
187 | | - d['symbol'] = d['symbol'].title() |
| 186 | + # The top level elements list like elements:[{symbol: D, atomic-weight: 2.014102}] |
| 187 | + val1 = [dict(d, symbol=d['symbol'].title()) for d in val1] |
| 188 | + val2 = [dict(d, symbol=d['symbol'].title()) for d in val2] |
188 | 189 | val1 = sorted([d for d in val1], key=lambda x: x['symbol']) |
189 | 190 | val2 = sorted([d for d in val2], key=lambda x: x['symbol']) |
190 | 191 |
|
191 | 192 | if path.endswith('Troe.T1') or path.endswith('Troe.T2') or path.endswith('Troe.T3'): |
192 | | - rtol = 5e-3 # Relax tolerance due to rounding. |
| 193 | + rtol = 0.005 # Relax tolerance due to rounding. |
| 194 | + |
| 195 | + if path.endswith('rate-constant.b'): |
| 196 | + atol = 0.0005 # ck rounds to 0.001 |
| 197 | + |
| 198 | + if path.endswith('rate-constant.Ea'): |
| 199 | + atol = 4185/2 # ck rounds to 0.001 kcal/mol which is 4184 J/kmol |
193 | 200 |
|
194 | 201 | ### END OF SPECIAL CASES |
195 | 202 |
|
@@ -226,13 +233,13 @@ def compare_values(val1: Any, val2: Any, path: str, atol: float = 1e-12, |
226 | 233 |
|
227 | 234 | # Handle numeric values with tolerance |
228 | 235 | elif is_numeric(val1) and is_numeric(val2): |
229 | | - # Use numpy.allclose for comparison |
| 236 | + # Use numpy.isclose for comparison |
230 | 237 | if not np.isclose(val1, val2, atol=atol, rtol=rtol): |
231 | 238 | # Compute the difference for reporting |
232 | | - abs_diff = abs(val1 - val2) |
233 | | - rel_diff = abs(abs_diff / val2) if val2 != 0 else float('inf') |
| 239 | + diff = (val2 - val1) |
| 240 | + ratio = (val2 / val1) if val1 != 0 else float('inf') |
234 | 241 | differences.append(f"Numerical difference at {path}: {val1} vs {val2} " |
235 | | - f"(abs_diff={abs_diff:.2e}, rel_diff={rel_diff:.2e})") |
| 242 | + f"(difference={diff:.2g}, ratio={ratio:.2g})") |
236 | 243 |
|
237 | 244 | # Handle strings and other comparable types |
238 | 245 | elif val1 != val2: |
@@ -500,9 +507,9 @@ def main(): |
500 | 507 | parser.add_argument("file1", help="First Cantera YAML file") |
501 | 508 | parser.add_argument("file2", help="Second Cantera YAML file") |
502 | 509 | parser.add_argument("--abs-tol", type=float, default=1e-11, |
503 | | - help="Absolute tolerance for numerical comparisons (default: 1e-9)") |
| 510 | + help="Absolute tolerance for numerical comparisons (default: 1e-11)") |
504 | 511 | parser.add_argument("--rel-tol", type=float, default=1e-3, |
505 | | - help="Relative tolerance for numerical comparisons (default: 1e-9)") |
| 512 | + help="Relative tolerance for numerical comparisons (default: 1e-3)") |
506 | 513 |
|
507 | 514 | args = parser.parse_args() |
508 | 515 |
|
@@ -535,14 +542,11 @@ def main(): |
535 | 542 | logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") |
536 | 543 | if len(sys.argv) == 1: |
537 | 544 | logging.info("No arguments provided. Using default test files for demonstration.") |
538 | | - # sys.argv.extend([ |
539 | | - # "test/rmgpy/test_data/yaml_writer_data/chemkin/from_main_test.yaml", |
540 | | - # "test/rmgpy/test_data/yaml_writer_data/cantera/from_main_test.yaml" |
541 | | - # ]) |
| 545 | + rmg_root = Path(__file__).resolve().parents[2] |
542 | 546 |
|
543 | 547 | sys.argv.extend([ |
544 | | - "/Users/rwest/Code/RMG-Py/testing/eg0/cantera_from_ck/chem.yaml", |
545 | | - "/Users/rwest/Code/RMG-Py/testing/eg0/cantera2/chem.yaml" |
| 548 | + str(rmg_root / "testing/eg0/cantera_from_ck/chem.yaml"), |
| 549 | + str(rmg_root / "testing/eg0/cantera1/chem.yaml") |
546 | 550 | ]) |
547 | 551 |
|
548 | 552 | main() |
0 commit comments