Skip to content

Commit 3e5f3e8

Browse files
committed
Added RitS to settings
1 parent f1c57b0 commit 3e5f3e8

1 file changed

Lines changed: 68 additions & 3 deletions

File tree

arc/settings/settings.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import string
1111
import sys
12+
from typing import Optional
1213

1314
# Users should update the following server dictionary.
1415
# Instructions for RSA key generation can be found here:
@@ -72,6 +73,7 @@
7273
'cfour': 'local',
7374
'gaussian': ['local', 'server2'],
7475
'gcn': 'local',
76+
'rits': 'local',
7577
'mockter': 'local',
7678
'molpro': ['local', 'server2'],
7779
'onedmin': 'server1',
@@ -89,7 +91,7 @@
8991
supported_ess = ['cfour', 'gaussian', 'mockter', 'molpro', 'orca', 'qchem', 'terachem', 'onedmin', 'xtb', 'torchani', 'openbabel']
9092

9193
# TS methods to try when appropriate for a reaction (other than user guesses which are always allowed):
92-
ts_adapters = ['heuristics', 'AutoTST', 'GCN', 'xtb_gsm', 'orca_neb']
94+
ts_adapters = ['heuristics', 'AutoTST', 'GCN', 'RitS', 'xtb_gsm', 'orca_neb']
9395

9496
# List here job types to execute by default
9597
default_job_types = {'conf_opt': True, # defaults to True if not specified
@@ -172,6 +174,7 @@
172174
output_filenames = {'cfour': 'output.out',
173175
'gaussian': 'input.log',
174176
'gcn': 'output.yml',
177+
'rits': 'output.yml',
175178
'mockter': 'output.yml',
176179
'molpro': 'input.out',
177180
'onedmin': 'output.out',
@@ -321,8 +324,9 @@
321324
ARC_FAMILIES_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'data', 'families')
322325

323326
# Default environment names for sister repos
324-
TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \
325-
None, None, None, None, None, None, None, None, None
327+
TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, RITS_PYTHON, RITS_REPO_PATH, RITS_CKPT_PATH, \
328+
ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \
329+
None, None, None, None, None, None, None, None, None, None, None, None
326330

327331
home = os.getenv("HOME") or os.path.expanduser("~")
328332

@@ -362,11 +366,72 @@ def find_executable(env_name, executable_name='python'):
362366
OB_PYTHON = find_executable('ob_env')
363367
TS_GCN_PYTHON = find_executable('ts_gcn')
364368
AUTOTST_PYTHON = find_executable('tst_env')
369+
RITS_PYTHON = find_executable('rits_env')
365370
ARC_PYTHON = find_executable('arc_env')
366371
RMG_ENV_NAME = 'rmg_env'
367372
RMG_PYTHON = find_executable('rmg_env')
368373
XTB = find_executable('xtb_env', 'xtb')
369374

375+
376+
def find_rits_repo() -> Optional[str]:
377+
"""
378+
Locate a RitS source checkout. Used by the RitS TS adapter to find
379+
'scripts/sample_transition_state.py' and 'scripts/conf/rits.yaml',
380+
which are not part of the importable 'megalodon' package.
381+
382+
Search order:
383+
1. ``ARC_RITS_REPO`` environment variable (explicit override).
384+
2. ``~/Code/RitS`` (default for ARC dev machines).
385+
3. Sibling-of-ARC location ``<parent-of-arc-repo>/RitS`` —
386+
matches what ``devtools/install_rits.sh`` produces.
387+
388+
Returns:
389+
Optional[str]: Absolute path to the repo root, or ``None`` if
390+
nothing was found. The repo is considered "found" only if it
391+
contains ``scripts/sample_transition_state.py``.
392+
"""
393+
candidates = list()
394+
env_override = os.getenv('ARC_RITS_REPO')
395+
if env_override:
396+
candidates.append(env_override)
397+
candidates.append(os.path.join(home, 'Code', 'RitS'))
398+
arc_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
399+
candidates.append(os.path.join(os.path.dirname(arc_root), 'RitS'))
400+
for path in candidates:
401+
if path and os.path.isfile(os.path.join(path, 'scripts', 'sample_transition_state.py')):
402+
return os.path.abspath(path)
403+
return None
404+
405+
406+
def find_rits_ckpt(repo_path: Optional[str] = None) -> Optional[str]:
407+
"""
408+
Locate the pretrained RitS checkpoint file ('rits.ckpt').
409+
410+
Search order:
411+
1. ``ARC_RITS_CKPT`` environment variable (explicit override).
412+
2. ``<repo_path>/data/rits.ckpt`` — what ``install_rits.sh`` writes.
413+
414+
Args:
415+
repo_path (Optional[str]): The RitS repo path returned by
416+
``find_rits_repo()``. If ``None``, only the env-var override
417+
is consulted.
418+
419+
Returns:
420+
Optional[str]: Absolute path to the checkpoint, or ``None``.
421+
"""
422+
env_override = os.getenv('ARC_RITS_CKPT')
423+
if env_override and os.path.isfile(env_override):
424+
return os.path.abspath(env_override)
425+
if repo_path:
426+
candidate = os.path.join(repo_path, 'data', 'rits.ckpt')
427+
if os.path.isfile(candidate):
428+
return os.path.abspath(candidate)
429+
return None
430+
431+
432+
RITS_REPO_PATH = find_rits_repo()
433+
RITS_CKPT_PATH = find_rits_ckpt(RITS_REPO_PATH)
434+
370435
# Set RMG_DB_PATH with fallback methods
371436
rmg_db_candidates, rmg_candidates = list(), list()
372437

0 commit comments

Comments
 (0)