|
9 | 9 | import os |
10 | 10 | import string |
11 | 11 | import sys |
| 12 | +from typing import Optional |
12 | 13 |
|
13 | 14 | # Users should update the following server dictionary. |
14 | 15 | # Instructions for RSA key generation can be found here: |
|
72 | 73 | 'cfour': 'local', |
73 | 74 | 'gaussian': ['local', 'server2'], |
74 | 75 | 'gcn': 'local', |
| 76 | + 'rits': 'local', |
75 | 77 | 'mockter': 'local', |
76 | 78 | 'molpro': ['local', 'server2'], |
77 | 79 | 'onedmin': 'server1', |
|
89 | 91 | supported_ess = ['cfour', 'gaussian', 'mockter', 'molpro', 'orca', 'qchem', 'terachem', 'onedmin', 'xtb', 'torchani', 'openbabel'] |
90 | 92 |
|
91 | 93 | # TS methods to try when appropriate for a reaction (other than user guesses which are always allowed): |
92 | | -ts_adapters = ['heuristics', 'AutoTST', 'GCN', 'xtb_gsm', 'orca_neb'] |
| 94 | +ts_adapters = ['heuristics', 'AutoTST', 'GCN', 'RitS', 'xtb_gsm', 'orca_neb'] |
93 | 95 |
|
94 | 96 | # List here job types to execute by default |
95 | 97 | default_job_types = {'conf_opt': True, # defaults to True if not specified |
|
172 | 174 | output_filenames = {'cfour': 'output.out', |
173 | 175 | 'gaussian': 'input.log', |
174 | 176 | 'gcn': 'output.yml', |
| 177 | + 'rits': 'output.yml', |
175 | 178 | 'mockter': 'output.yml', |
176 | 179 | 'molpro': 'input.out', |
177 | 180 | 'onedmin': 'output.out', |
|
321 | 324 | ARC_FAMILIES_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'data', 'families') |
322 | 325 |
|
323 | 326 | # Default environment names for sister repos |
324 | | -TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \ |
325 | | - None, None, None, None, None, None, None, None, None |
| 327 | +TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, RITS_PYTHON, RITS_REPO_PATH, RITS_CKPT_PATH, \ |
| 328 | + ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \ |
| 329 | + None, None, None, None, None, None, None, None, None, None, None, None |
326 | 330 |
|
327 | 331 | home = os.getenv("HOME") or os.path.expanduser("~") |
328 | 332 |
|
@@ -362,11 +366,72 @@ def find_executable(env_name, executable_name='python'): |
362 | 366 | OB_PYTHON = find_executable('ob_env') |
363 | 367 | TS_GCN_PYTHON = find_executable('ts_gcn') |
364 | 368 | AUTOTST_PYTHON = find_executable('tst_env') |
| 369 | +RITS_PYTHON = find_executable('rits_env') |
365 | 370 | ARC_PYTHON = find_executable('arc_env') |
366 | 371 | RMG_ENV_NAME = 'rmg_env' |
367 | 372 | RMG_PYTHON = find_executable('rmg_env') |
368 | 373 | XTB = find_executable('xtb_env', 'xtb') |
369 | 374 |
|
| 375 | + |
| 376 | +def find_rits_repo() -> Optional[str]: |
| 377 | + """ |
| 378 | + Locate a RitS source checkout. Used by the RitS TS adapter to find |
| 379 | + 'scripts/sample_transition_state.py' and 'scripts/conf/rits.yaml', |
| 380 | + which are not part of the importable 'megalodon' package. |
| 381 | +
|
| 382 | + Search order: |
| 383 | + 1. ``ARC_RITS_REPO`` environment variable (explicit override). |
| 384 | + 2. ``~/Code/RitS`` (default for ARC dev machines). |
| 385 | + 3. Sibling-of-ARC location ``<parent-of-arc-repo>/RitS`` — |
| 386 | + matches what ``devtools/install_rits.sh`` produces. |
| 387 | +
|
| 388 | + Returns: |
| 389 | + Optional[str]: Absolute path to the repo root, or ``None`` if |
| 390 | + nothing was found. The repo is considered "found" only if it |
| 391 | + contains ``scripts/sample_transition_state.py``. |
| 392 | + """ |
| 393 | + candidates = list() |
| 394 | + env_override = os.getenv('ARC_RITS_REPO') |
| 395 | + if env_override: |
| 396 | + candidates.append(env_override) |
| 397 | + candidates.append(os.path.join(home, 'Code', 'RitS')) |
| 398 | + arc_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| 399 | + candidates.append(os.path.join(os.path.dirname(arc_root), 'RitS')) |
| 400 | + for path in candidates: |
| 401 | + if path and os.path.isfile(os.path.join(path, 'scripts', 'sample_transition_state.py')): |
| 402 | + return os.path.abspath(path) |
| 403 | + return None |
| 404 | + |
| 405 | + |
| 406 | +def find_rits_ckpt(repo_path: Optional[str] = None) -> Optional[str]: |
| 407 | + """ |
| 408 | + Locate the pretrained RitS checkpoint file ('rits.ckpt'). |
| 409 | +
|
| 410 | + Search order: |
| 411 | + 1. ``ARC_RITS_CKPT`` environment variable (explicit override). |
| 412 | + 2. ``<repo_path>/data/rits.ckpt`` — what ``install_rits.sh`` writes. |
| 413 | +
|
| 414 | + Args: |
| 415 | + repo_path (Optional[str]): The RitS repo path returned by |
| 416 | + ``find_rits_repo()``. If ``None``, only the env-var override |
| 417 | + is consulted. |
| 418 | +
|
| 419 | + Returns: |
| 420 | + Optional[str]: Absolute path to the checkpoint, or ``None``. |
| 421 | + """ |
| 422 | + env_override = os.getenv('ARC_RITS_CKPT') |
| 423 | + if env_override and os.path.isfile(env_override): |
| 424 | + return os.path.abspath(env_override) |
| 425 | + if repo_path: |
| 426 | + candidate = os.path.join(repo_path, 'data', 'rits.ckpt') |
| 427 | + if os.path.isfile(candidate): |
| 428 | + return os.path.abspath(candidate) |
| 429 | + return None |
| 430 | + |
| 431 | + |
| 432 | +RITS_REPO_PATH = find_rits_repo() |
| 433 | +RITS_CKPT_PATH = find_rits_ckpt(RITS_REPO_PATH) |
| 434 | + |
370 | 435 | # Set RMG_DB_PATH with fallback methods |
371 | 436 | rmg_db_candidates, rmg_candidates = list(), list() |
372 | 437 |
|
|
0 commit comments