Skip to content

Commit a932851

Browse files
committed
Added the RitS adapter
1 parent 3e5f3e8 commit a932851

6 files changed

Lines changed: 704 additions & 4 deletions

File tree

arc/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def check_ess_settings(ess_settings: Optional[dict] = None) -> dict:
141141
f'strings. Got: {server_list} which is a {type(server_list)}')
142142
# run checks:
143143
for ess, server_list in settings_dict.items():
144-
if ess.lower() not in supported_ess + ['gcn', 'heuristics', 'autotst', 'kinbot', 'xtb_gsm', 'orca_neb']:
144+
if ess.lower() not in supported_ess + ['gcn', 'heuristics', 'autotst', 'kinbot', 'rits', 'xtb_gsm', 'orca_neb']:
145145
raise SettingsError(f'Recognized ESS software are {supported_ess}. Got: {ess}')
146146
for server in server_list:
147147
if not isinstance(server, bool) and server.lower() not in [s.lower() for s in servers.keys()]:

arc/job/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class JobEnum(str, Enum):
9898
heuristics = 'heuristics' # ARC's heuristics
9999
kinbot = 'kinbot' # KinBot, 10.1016/j.cpc.2019.106947
100100
gcn = 'gcn' # Graph neural network for isomerization, https://doi.org/10.1021/acs.jpclett.0c00500
101+
rits = 'rits' # Right into the Saddle, flow-matching TS generator, https://github.com/isayevlab/RitS, 10.26434/chemrxiv.15001681/v1
101102
user = 'user' # user guesses
102103
xtb_gsm = 'xtb_gsm' # Double ended growing string method (DE-GSM), [10.1021/ct400319w, 10.1063/1.4804162] via xTB
103104
orca_neb = 'orca_neb'

arc/job/adapters/common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@
7373
'Singlet_Carbene_Intra_Disproportionation': ['gcn', 'xtb_gsm', 'orca_neb'],
7474
}
7575

76-
all_families_ts_adapters = []
76+
all_families_ts_adapters = ['rits']
7777
adapters_that_do_not_require_a_level_arg = ['xtb', 'torchani']
7878

7979
# Default is "queue", "pipe" will be called whenever needed. So just list 'incore'.
80-
default_incore_adapters = ['autotst', 'crest', 'gcn', 'heuristics', 'kinbot', 'psi4', 'xtb', 'xtb_gsm', 'torchani',
81-
'openbabel']
80+
default_incore_adapters = ['autotst', 'crest', 'gcn', 'heuristics', 'kinbot', 'psi4', 'rits',
81+
'xtb', 'xtb_gsm', 'torchani', 'openbabel']
8282

8383

8484
def _initialize_adapter(obj: 'JobAdapter',
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
#!/usr/bin/env python3
2+
# encoding: utf-8
3+
4+
"""
5+
A standalone script to run RitS (Right into the Saddle) and emit TS guesses
6+
as a YAML file consumable by ARC's RitSAdapter.
7+
8+
This script must be invoked from inside the ``rits_env`` conda environment
9+
(it does NOT import ``megalodon`` directly — RitS's own
10+
``scripts/sample_transition_state.py`` does that). The parent ARC process
11+
shells out to this script via ``subprocess.run`` so that ARC's main env
12+
does not have to carry the heavy ML dependency stack.
13+
14+
Input file (``input.yml``)
15+
--------------------------
16+
Required keys:
17+
reactant_xyz_path : str absolute path to a plain XYZ file (atom-mapped)
18+
product_xyz_path : str absolute path to the matching product XYZ
19+
rits_repo_path : str absolute path to the RitS source checkout
20+
ckpt_path : str absolute path to the pretrained ``rits.ckpt``
21+
output_xyz_path : str absolute path RitS should write its raw output to
22+
yml_out_path : str absolute path this script writes the parsed TSGuess list to
23+
24+
Optional keys (with defaults):
25+
config_path : str defaults to ``<rits_repo_path>/scripts/conf/rits.yaml``
26+
n_samples : int default 10
27+
batch_size : int default 32
28+
charge : int default 0
29+
device : str default 'auto' (RitS picks GPU if visible, else CPU)
30+
add_stereo : bool default False
31+
num_steps : int default None (use config value)
32+
33+
Output (``yml_out_path``)
34+
-------------------------
35+
A YAML *list* of TSGuess dictionaries. Each entry has:
36+
method : 'RitS'
37+
method_direction : 'F'
38+
method_index : int (0-based sample index)
39+
initial_xyz : str (XYZ-format coordinate block, no header lines)
40+
success : bool
41+
execution_time : str (str(datetime.timedelta))
42+
43+
If RitS fails to produce any usable output, the script writes a list with a
44+
single failed-guess entry instead of raising — the parent adapter then logs
45+
the failure but continues running other TS methods.
46+
"""
47+
48+
import argparse
49+
import datetime
50+
import os
51+
import subprocess
52+
import sys
53+
import traceback
54+
from typing import List, Optional
55+
56+
import yaml
57+
58+
59+
def read_yaml_file(path: str) -> dict:
60+
"""Read a YAML file and return its contents as a dict."""
61+
with open(path, 'r') as f:
62+
return yaml.load(stream=f, Loader=yaml.FullLoader)
63+
64+
65+
def string_representer(dumper, data):
66+
"""YAML representer that uses block literals for multi-line strings."""
67+
if len(data.splitlines()) > 1:
68+
return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|')
69+
return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data)
70+
71+
72+
def save_yaml_file(path: str, content) -> None:
73+
"""Save ``content`` to a YAML file at ``path``."""
74+
yaml.add_representer(str, string_representer)
75+
with open(path, 'w') as f:
76+
f.write(yaml.dump(data=content))
77+
78+
79+
def parse_multi_frame_xyz(xyz_path: str) -> List[str]:
80+
"""
81+
Parse a (possibly multi-frame) XYZ file into a list of coordinate-block strings.
82+
83+
RitS writes a single XYZ file when ``--n_samples == 1`` and a multi-frame
84+
XYZ when ``n_samples > 1`` (frames concatenated, each prefixed by an atom
85+
count line and a blank/comment line). This parser handles both.
86+
87+
Args:
88+
xyz_path (str): Path to the XYZ file emitted by RitS.
89+
90+
Returns:
91+
List[str]: One coordinate block per frame, suitable for passing to
92+
``arc.species.converter.str_to_xyz`` (atom symbols + xyz only — no
93+
header / comment lines).
94+
"""
95+
if not os.path.isfile(xyz_path):
96+
return list()
97+
with open(xyz_path, 'r') as f:
98+
raw_lines = [line.rstrip('\n') for line in f]
99+
frames = list()
100+
i, n = 0, len(raw_lines)
101+
while i < n:
102+
# Skip blank lines between frames
103+
while i < n and not raw_lines[i].strip():
104+
i += 1
105+
if i >= n:
106+
break
107+
# First non-blank line of a frame should be the atom count
108+
try:
109+
n_atoms = int(raw_lines[i].strip())
110+
except ValueError:
111+
# Not a frame header — bail on this row to avoid an infinite loop
112+
i += 1
113+
continue
114+
i += 1
115+
# Comment / energy line (may be blank)
116+
if i < n:
117+
i += 1
118+
# The next n_atoms lines are coordinates
119+
coord_lines = list()
120+
for _ in range(n_atoms):
121+
if i >= n:
122+
break
123+
coord_lines.append(raw_lines[i])
124+
i += 1
125+
if len(coord_lines) == n_atoms:
126+
frames.append('\n'.join(coord_lines))
127+
return frames
128+
129+
130+
def run_rits(input_dict: dict) -> List[dict]:
131+
"""
132+
Invoke ``scripts/sample_transition_state.py`` from the RitS source tree
133+
and parse the resulting XYZ frames into a list of TSGuess dictionaries.
134+
135+
Args:
136+
input_dict (dict): The parsed contents of ``input.yml``.
137+
138+
Returns:
139+
List[dict]: One TSGuess-shaped dict per generated sample. Always at
140+
least one entry — a failed sentinel if RitS produced nothing.
141+
"""
142+
repo = input_dict['rits_repo_path']
143+
sample_script = os.path.join(repo, 'scripts', 'sample_transition_state.py')
144+
config_path = input_dict.get('config_path') or os.path.join(repo, 'scripts', 'conf', 'rits.yaml')
145+
output_xyz = input_dict['output_xyz_path']
146+
n_samples = int(input_dict.get('n_samples', 10))
147+
batch_size = int(input_dict.get('batch_size', 32))
148+
charge = int(input_dict.get('charge', 0))
149+
device = str(input_dict.get('device', 'auto'))
150+
add_stereo = bool(input_dict.get('add_stereo', False))
151+
num_steps = input_dict.get('num_steps')
152+
153+
cmd = [
154+
sys.executable, sample_script,
155+
'--reactant_xyz', input_dict['reactant_xyz_path'],
156+
'--product_xyz', input_dict['product_xyz_path'],
157+
'--config', config_path,
158+
'--ckpt', input_dict['ckpt_path'],
159+
'--output', output_xyz,
160+
'--n_samples', str(n_samples),
161+
'--batch_size', str(batch_size),
162+
'--charge', str(charge),
163+
'--device', device,
164+
]
165+
if add_stereo:
166+
cmd.append('--add_stereo')
167+
if num_steps is not None:
168+
cmd.extend(['--num_steps', str(num_steps)])
169+
170+
t0 = datetime.datetime.now()
171+
print(f'[rits_script] running: {" ".join(cmd)}', flush=True)
172+
completed = subprocess.run(cmd, cwd=repo)
173+
elapsed = datetime.datetime.now() - t0
174+
175+
if completed.returncode != 0:
176+
print(f'[rits_script] sample_transition_state.py exited with code {completed.returncode}', flush=True)
177+
return [_failed_guess(elapsed, index=0)]
178+
179+
frames = parse_multi_frame_xyz(output_xyz)
180+
if not frames:
181+
print(f'[rits_script] no frames parsed from {output_xyz}', flush=True)
182+
return [_failed_guess(elapsed, index=0)]
183+
184+
tsgs = list()
185+
for i, coord_block in enumerate(frames):
186+
tsgs.append({
187+
'method': 'RitS',
188+
'method_direction': 'F',
189+
'method_index': i,
190+
'initial_xyz': coord_block,
191+
'success': True,
192+
'execution_time': str(elapsed),
193+
})
194+
return tsgs
195+
196+
197+
def _failed_guess(elapsed: datetime.timedelta, index: int = 0) -> dict:
198+
"""Return a failed-TSGuess sentinel dict."""
199+
return {
200+
'method': 'RitS',
201+
'method_direction': 'F',
202+
'method_index': index,
203+
'initial_xyz': None,
204+
'success': False,
205+
'execution_time': str(elapsed),
206+
}
207+
208+
209+
def parse_command_line_arguments(command_line_args: Optional[list] = None) -> argparse.Namespace:
210+
"""Parse the script's command-line arguments."""
211+
parser = argparse.ArgumentParser(description='Run RitS to generate TS guesses for an ARC reaction.')
212+
parser.add_argument('--yml_in_path', metavar='input', type=str, default='input.yml',
213+
help='Path to the input YAML file (default: ./input.yml).')
214+
return parser.parse_args(command_line_args)
215+
216+
217+
def main():
218+
"""Entry point: read input.yml, run RitS, write output YAML."""
219+
args = parse_command_line_arguments()
220+
yml_in_path = str(args.yml_in_path)
221+
if not os.path.isfile(yml_in_path):
222+
print(f'[rits_script] input file not found: {yml_in_path}', file=sys.stderr)
223+
sys.exit(1)
224+
input_dict = read_yaml_file(yml_in_path)
225+
226+
try:
227+
tsgs = run_rits(input_dict)
228+
except Exception:
229+
traceback.print_exc()
230+
tsgs = [_failed_guess(datetime.timedelta(0), index=0)]
231+
232+
save_yaml_file(path=input_dict['yml_out_path'], content=tsgs)
233+
n_ok = sum(1 for tsg in tsgs if tsg.get('success'))
234+
print(f'[rits_script] wrote {len(tsgs)} TSGuess entries ({n_ok} successful) to {input_dict["yml_out_path"]}',
235+
flush=True)
236+
237+
238+
if __name__ == '__main__':
239+
main()

arc/job/adapters/ts/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
import arc.job.adapters.ts.gcn_ts
33
import arc.job.adapters.ts.heuristics
44
import arc.job.adapters.ts.kinbot_ts
5+
import arc.job.adapters.ts.rits_ts
56
import arc.job.adapters.ts.xtb_gsm
67
import arc.job.adapters.ts.orca_neb

0 commit comments

Comments
 (0)