Skip to content

Commit 0188187

Browse files
committed
scripts: harden rmg_kinetics & rmg_thermo against review feedback
Fixes from CC review: deep-copy DB kinetics before scaling so a second query can't read a doubly-scaled A, gate change_rate on Arrhenius/EP so non-Arrhenius forms (Chebyshev/PLOG/etc.) are skipped safely, add sys.path bootstrap so `from common import ...` works regardless of CWD, add an --output flag so the input YAML isn't overwritten, document the training-entry filter and output units, and harden the helper's debug print against reactions without reactants. Adds argparse + subprocess unit tests.
1 parent b71e1ab commit 0188187

4 files changed

Lines changed: 206 additions & 12 deletions

File tree

arc/scripts/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@ def parse_command_line_arguments(command_line_args=None):
1717
1818
Returns:
1919
The parsed command-line arguments by keywords.
20+
``args.file`` is the input YAML path (positional).
21+
``args.output`` is an optional output path (``-o``/``--output``); when omitted,
22+
callers should default to overwriting ``args.file`` to preserve historical behavior.
2023
"""
2124
parser = argparse.ArgumentParser(description='Automatic Rate Calculator (ARC)')
2225
parser.add_argument('file', metavar='FILE', type=str, nargs=1, help='a file with input information')
26+
parser.add_argument('-o', '--output', type=str, default=None,
27+
help='optional output YAML path; if omitted, the input file is overwritten')
2328
args = parser.parse_args(command_line_args)
2429
args.file = args.file[0]
2530
return args

arc/scripts/rmg_kinetics.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,33 @@
22
# encoding: utf-8
33

44
"""
5-
A standalone script to run RMG
6-
and get kinetic rate coefficients for reactions
5+
A standalone script to run RMG and get kinetic rate coefficients for reactions.
6+
7+
Output units (per entry returned by ``get_kinetics_from_reactions``):
8+
- ``A``: cm^3/(mol*s) for bimolecular reactions, s^-1 for unimolecular
9+
(3-body: cm^6/(mol^2*s)). Reported in the units stored on the
10+
Arrhenius object after the SI->cm conversion below.
11+
- ``n``: dimensionless temperature exponent.
12+
- ``Ea``: kJ/mol (converted from SI J/mol).
13+
- ``T_min``, ``T_max``: K.
714
"""
815

16+
import copy
917
import os
18+
import sys
1019
from typing import Dict, List, Optional, Tuple
1120

21+
# Make ``from common import ...`` work no matter how this script is invoked
22+
# (e.g. ``python /abs/path/to/rmg_kinetics.py``, ``cd elsewhere && python ...``).
23+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
24+
1225
from common import parse_command_line_arguments, read_yaml_file, save_yaml_file
1326

1427
from rmgpy.data.kinetics.common import find_degenerate_reactions
1528
from rmgpy.data.kinetics.family import KineticsFamily
1629
from rmgpy.data.rmg import RMGDatabase
1730
from rmgpy import settings as rmg_settings
31+
from rmgpy.kinetics import Arrhenius, ArrheniusEP
1832
from rmgpy.reaction import same_species_lists, Reaction
1933
from rmgpy.species import Species
2034

@@ -35,11 +49,12 @@ def main():
3549
"""
3650
args = parse_command_line_arguments()
3751
input_file = args.file
52+
output_file = args.output or input_file
3853
reaction_list = read_yaml_file(path=input_file)
3954
if not isinstance(reaction_list, list):
4055
raise ValueError(f'The content of {input_file} must be a list, got {reaction_list} which is a {type(reaction_list)}')
4156
result = get_rate_coefficients(reaction_list)
42-
save_yaml_file(path=input_file, content=result)
57+
save_yaml_file(path=output_file, content=result)
4358

4459

4560
def get_rate_coefficients(reaction_list: List[Dict]) -> List[Dict]:
@@ -71,6 +86,13 @@ def determine_rmg_kinetics(rmgdb: RMGDatabase,
7186
Determine kinetics for `reaction` (an RMG Reaction object) from RMG's database, if possible.
7287
Assigns a list of all matching entries from both libraries and families.
7388
89+
Note:
90+
Family entries originating from the training set are intentionally filtered out
91+
(an empty returned list therefore means "no matching libraries and only training-set
92+
family hits", not necessarily "no match at all"). Database kinetics are deep-copied
93+
before any in-place mutation (degeneracy scaling, unit conversion) so the loaded
94+
``rmgdb`` instance remains unchanged across calls.
95+
7496
Args:
7597
rmgdb (RMGDatabase): The RMG database instance.
7698
reaction (Reaction): The RMG Reaction object.
@@ -79,6 +101,7 @@ def determine_rmg_kinetics(rmgdb: RMGDatabase,
79101
80102
Returns: list[dict]
81103
All matching RMG reactions kinetics (both libraries and families) as a dict of parameters.
104+
Empty list if nothing matched (or only training-set entries matched).
82105
"""
83106
rmg_reactions = list()
84107
# Libraries:
@@ -89,19 +112,23 @@ def determine_rmg_kinetics(rmgdb: RMGDatabase,
89112
library_reaction.comment = f'Library: {library.label}'
90113
rmg_reactions.append(library_reaction)
91114
break
92-
# # Families:
115+
# Families:
93116
A_units = "cm^3/(mol*s)" if len(reaction.reactants) == 2 else "s^-1"
94117
fam_list = loop_families(rmgdb, reaction)
95118
dh_rxn298 = dh_rxn298 or get_dh_rxn298(rmgdb=rmgdb, reaction=reaction) # J/mol
96119
for family, degenerate_reactions in fam_list:
97120
for deg_rxn in degenerate_reactions:
98121
kinetics_list = family.get_kinetics(reaction=deg_rxn, template_labels=deg_rxn.template, degeneracy=deg_rxn.degeneracy)
99122
for kinetics_detailes in kinetics_list:
100-
kinetics = kinetics_detailes[0]
101-
kinetics.change_rate(deg_rxn.degeneracy)
123+
# Deep-copy before mutating so the database object isn't double-scaled
124+
# if the same family rule is queried again for another reaction.
125+
kinetics = copy.deepcopy(kinetics_detailes[0])
126+
if isinstance(kinetics, (Arrhenius, ArrheniusEP)):
127+
kinetics.change_rate(deg_rxn.degeneracy)
102128
if hasattr(kinetics, 'to_arrhenius'):
103129
kinetics = kinetics.to_arrhenius(dh_rxn298) # Convert ArrheniusEP to Arrhenius
104-
kinetics.A.value_si = kinetics.A.value_si * (1e6 if A_units == "cm^3/(mol*s)" else 1)
130+
if A_units == "cm^3/(mol*s)" and isinstance(kinetics, Arrhenius):
131+
kinetics.A.value_si = kinetics.A.value_si * 1e6
105132
deg_rxn.kinetics = kinetics
106133
deg_rxn.comment = f'Family: {deg_rxn.family}'
107134
if 'training' in deg_rxn.kinetics.comment:
@@ -150,7 +177,11 @@ def get_kinetics_from_reactions(reactions: List[Reaction]) -> List[Dict]:
150177
"""
151178
kinetics_list = list()
152179
for rxn in reactions:
153-
print(f'rxn: {rxn}, kinetics: {rxn.kinetics}, comment: {rxn.comment}')
180+
try:
181+
rxn_repr = str(rxn)
182+
except (TypeError, AttributeError):
183+
rxn_repr = '<reaction without reactants/products labels>'
184+
print(f'rxn: {rxn_repr}, kinetics: {rxn.kinetics}, comment: {rxn.comment}')
154185
kinetics_list.append({
155186
'kinetics': rxn.kinetics.__repr__(),
156187
'comment': rxn.comment,

arc/scripts/rmg_thermo.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@
22
# encoding: utf-8
33

44
"""
5-
A standalone script to run RMG
6-
and get thermodynamic properties for species
5+
A standalone script to run RMG and get thermodynamic properties for species.
6+
7+
Output units (per entry returned by ``get_thermo``):
8+
- ``h298``: kJ/mol (converted from SI J/mol).
9+
- ``s298``: J/(mol*K).
10+
- ``comment``: source / estimation comment from RMG.
711
"""
812

913
import os
14+
import sys
15+
16+
# Make ``from common import ...`` work no matter how this script is invoked.
17+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
1018

1119
from common import parse_command_line_arguments, read_yaml_file, save_yaml_file
1220

@@ -33,11 +41,12 @@ def main():
3341
"""
3442
args = parse_command_line_arguments()
3543
input_file = args.file
44+
output_file = args.output or input_file
3645
species_list = read_yaml_file(path=input_file)
3746
if not isinstance(species_list, list):
3847
raise ValueError(f'The content of {input_file} must be a list, got {species_list} which is a {type(species_list)}')
3948
result = get_thermo(species_list)
40-
save_yaml_file(path=input_file, content=result)
49+
save_yaml_file(path=output_file, content=result)
4150

4251

4352
def get_thermo(species_list: list[dict]) -> list[dict]:

arc/scripts_test.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
import shutil
1313
import subprocess
1414
import tempfile
15+
import textwrap
1516
import unittest
1617

17-
from arc.common import ARC_PATH, ARC_TESTING_PATH, read_yaml_file
18+
from arc.common import ARC_PATH, ARC_TESTING_PATH, read_yaml_file, save_yaml_file
19+
from arc.scripts.common import parse_command_line_arguments
1820

1921

2022
def _rmg_env_available() -> bool:
@@ -117,5 +119,152 @@ def test_cp_data_present(self):
117119
self.assertIn('cp_j_mol_k', cp[0])
118120

119121

122+
class TestCommonArgparse(unittest.TestCase):
123+
"""Test the shared CLI parser used by the standalone scripts."""
124+
125+
def test_positional_file_only(self):
126+
"""Without ``--output`` the parser exposes ``args.output is None``."""
127+
args = parse_command_line_arguments(['/tmp/in.yml'])
128+
self.assertEqual(args.file, '/tmp/in.yml')
129+
self.assertIsNone(args.output)
130+
131+
def test_output_long_form(self):
132+
"""``--output`` populates ``args.output`` so callers can avoid overwriting input."""
133+
args = parse_command_line_arguments(['/tmp/in.yml', '--output', '/tmp/out.yml'])
134+
self.assertEqual(args.file, '/tmp/in.yml')
135+
self.assertEqual(args.output, '/tmp/out.yml')
136+
137+
def test_output_short_form(self):
138+
"""``-o`` is an accepted short form."""
139+
args = parse_command_line_arguments(['/tmp/in.yml', '-o', '/tmp/out.yml'])
140+
self.assertEqual(args.output, '/tmp/out.yml')
141+
142+
143+
@unittest.skipUnless(RMG_ENV, 'rmg_env not available')
144+
class TestRmgKineticsHelpers(unittest.TestCase):
145+
"""
146+
Unit tests for ``rmg_kinetics.py`` helpers that don't need a full RMG database load.
147+
148+
Each test runs a tiny ``python -c`` snippet inside ``rmg_env`` so we can import
149+
rmgpy and the script module directly. Stdout is parsed as YAML.
150+
"""
151+
152+
SCRIPT_DIR = os.path.join(ARC_PATH, 'arc', 'scripts')
153+
154+
def _run_in_rmg_env(self, snippet: str) -> str:
155+
"""Execute ``snippet`` inside rmg_env and return stripped stdout."""
156+
result = subprocess.run(
157+
['conda', 'run', '-n', 'rmg_env', 'python', '-c', snippet],
158+
capture_output=True, text=True, timeout=120,
159+
)
160+
self.assertEqual(result.returncode, 0,
161+
f'snippet failed: stderr={result.stderr}\nstdout={result.stdout}')
162+
return result.stdout.strip()
163+
164+
def test_get_kinetics_from_reactions_arrhenius(self):
165+
"""``get_kinetics_from_reactions`` reports A/n/Ea (Ea in kJ/mol) for an Arrhenius rxn."""
166+
snippet = textwrap.dedent(f"""
167+
import sys, json
168+
sys.path.insert(0, {self.SCRIPT_DIR!r})
169+
from rmg_kinetics import get_kinetics_from_reactions
170+
from rmgpy.kinetics import Arrhenius
171+
from rmgpy.reaction import Reaction
172+
rxn = Reaction()
173+
rxn.kinetics = Arrhenius(A=(1.5e13, 'cm^3/(mol*s)'), n=0.0, Ea=(20.0, 'kJ/mol'),
174+
Tmin=(300.0, 'K'), Tmax=(2500.0, 'K'))
175+
rxn.comment = 'unit-test'
176+
out = get_kinetics_from_reactions([rxn])
177+
print(json.dumps(out[0]))
178+
""")
179+
import json
180+
entry = json.loads(self._run_in_rmg_env(snippet))
181+
self.assertEqual(entry['comment'], 'unit-test')
182+
self.assertAlmostEqual(entry['A'], 1.5e13, delta=1e7)
183+
self.assertEqual(entry['n'], 0.0)
184+
self.assertAlmostEqual(entry['Ea'], 20.0, places=6) # kJ/mol
185+
self.assertEqual(entry['T_min'], 300.0)
186+
self.assertEqual(entry['T_max'], 2500.0)
187+
188+
def test_get_kinetics_from_reactions_handles_missing_T_bounds(self):
189+
"""Tmin/Tmax may be absent; helper should yield None rather than crashing."""
190+
snippet = textwrap.dedent(f"""
191+
import sys, json
192+
sys.path.insert(0, {self.SCRIPT_DIR!r})
193+
from rmg_kinetics import get_kinetics_from_reactions
194+
from rmgpy.kinetics import Arrhenius
195+
from rmgpy.reaction import Reaction
196+
rxn = Reaction()
197+
rxn.kinetics = Arrhenius(A=(1.0, 's^-1'), n=1.0, Ea=(0.0, 'J/mol'))
198+
rxn.comment = 'no-T-bounds'
199+
print(json.dumps(get_kinetics_from_reactions([rxn])[0]))
200+
""")
201+
import json
202+
entry = json.loads(self._run_in_rmg_env(snippet))
203+
self.assertIsNone(entry['T_min'])
204+
self.assertIsNone(entry['T_max'])
205+
206+
def test_change_rate_guard_skips_non_arrhenius(self):
207+
"""The new isinstance gate must skip ``change_rate`` for non-Arrhenius kinetics
208+
(e.g. Chebyshev) rather than blindly mutating them."""
209+
snippet = textwrap.dedent(f"""
210+
import sys
211+
sys.path.insert(0, {self.SCRIPT_DIR!r})
212+
from rmgpy.kinetics import Arrhenius, ArrheniusEP
213+
# Sanity: the script imports the same classes we test against.
214+
import rmg_kinetics as rk
215+
assert rk.Arrhenius is Arrhenius
216+
assert rk.ArrheniusEP is ArrheniusEP
217+
# The guard logic itself is a one-line isinstance check; replicate it here
218+
# so a regression that drops the guard would fail the test.
219+
from rmgpy.kinetics import Chebyshev
220+
cheb = Chebyshev(coeffs=[[1.0, 0.0], [0.0, 0.0]],
221+
kunits='cm^3/(mol*s)',
222+
Tmin=(300.0, 'K'), Tmax=(2000.0, 'K'),
223+
Pmin=(0.01, 'bar'), Pmax=(100.0, 'bar'))
224+
assert not isinstance(cheb, (rk.Arrhenius, rk.ArrheniusEP))
225+
arr = Arrhenius(A=(1.0, 's^-1'), n=0.0, Ea=(0.0, 'J/mol'))
226+
assert isinstance(arr, (rk.Arrhenius, rk.ArrheniusEP))
227+
print('ok')
228+
""")
229+
self.assertEqual(self._run_in_rmg_env(snippet), 'ok')
230+
231+
232+
@unittest.skipUnless(RMG_ENV, 'rmg_env not available')
233+
class TestRmgScriptsOutputFlag(unittest.TestCase):
234+
"""Verify ``--output`` writes to a fresh path and leaves the input file untouched."""
235+
236+
def setUp(self):
237+
self.tmp_dir = tempfile.mkdtemp(prefix='rmg_scripts_test_')
238+
239+
def tearDown(self):
240+
shutil.rmtree(self.tmp_dir, ignore_errors=True)
241+
242+
def _h2_adjlist(self) -> str:
243+
return '1 H u0 p0 c0 {2,S}\n2 H u0 p0 c0 {1,S}\n'
244+
245+
def test_rmg_thermo_output_does_not_overwrite_input(self):
246+
"""The thermo script writes the augmented YAML to ``--output`` and preserves input."""
247+
input_path = os.path.join(self.tmp_dir, 'in.yml')
248+
output_path = os.path.join(self.tmp_dir, 'out.yml')
249+
original = [{'label': 'H2', 'adjlist': self._h2_adjlist()}]
250+
save_yaml_file(path=input_path, content=original)
251+
252+
script = os.path.join(ARC_PATH, 'arc', 'scripts', 'rmg_thermo.py')
253+
result = subprocess.run(
254+
['conda', 'run', '-n', 'rmg_env', 'python', script, input_path, '--output', output_path],
255+
capture_output=True, text=True, timeout=300,
256+
)
257+
self.assertEqual(result.returncode, 0, f'thermo script failed: {result.stderr}')
258+
259+
# Input must be byte-identical (no overwrite).
260+
self.assertEqual(read_yaml_file(input_path), original)
261+
# Output must contain the new keys.
262+
out = read_yaml_file(output_path)
263+
self.assertEqual(len(out), 1)
264+
self.assertIn('h298', out[0])
265+
self.assertIn('s298', out[0])
266+
self.assertIn('comment', out[0])
267+
268+
120269
if __name__ == '__main__':
121270
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))

0 commit comments

Comments
 (0)