Skip to content

Commit a160ef7

Browse files
committed
Fix switch_ts to reset state & clean up IRC when switching TS guesses
When a TS guess fails validation (e.g., NMD check), switch_ts picks the next guess but previously left stale state behind: 1. IRC species from the invalidated guess were never cleaned up. delete_all_species_jobs('TS0') only deletes jobs under the TS0 label, but IRC species like IRC_TS0_1 are separate entries in running_jobs/species_dict/etc. These orphaned species continued running in parallel with the new guess, potentially interfering with job processing. 2. job_types flags (freq, sp, opt) were never reset. After guess N's freq completed, job_types['freq'] = True carried over to guess N+1, causing the scheduler to skip re-running freq for the new geometry. 3. convergence was never reset to None. 4. The old line self.output[label]['geo'] = ... wrote to the wrong dict level (top-level keys instead of self.output[label]['paths']), making it dead code. 5. Pending pipe batches from the old guess were never discarded.
1 parent 5c50f88 commit a160ef7

2 files changed

Lines changed: 145 additions & 2 deletions

File tree

arc/scheduler.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2763,7 +2763,36 @@ def switch_ts(self, label: str):
27632763
logger.info(f'Switching a TS guess for {label}...')
27642764
self.determine_most_likely_ts_conformer(label=label) # Look for a different TS guess.
27652765
self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS.
2766-
self.output[label]['geo'] = self.output[label]['freq'] = self.output[label]['sp'] = self.output[label]['composite'] = ''
2766+
# Clean up IRC species spawned from the invalidated TS guess.
2767+
irc_labels_str = self.species_dict[label].irc_label
2768+
if irc_labels_str:
2769+
for irc_label in irc_labels_str.split():
2770+
if irc_label in self.job_dict and irc_label in self.output:
2771+
self.delete_all_species_jobs(irc_label)
2772+
if irc_label in self.running_jobs:
2773+
del self.running_jobs[irc_label]
2774+
if irc_label in self.job_dict:
2775+
del self.job_dict[irc_label]
2776+
if irc_label in self.output:
2777+
del self.output[irc_label]
2778+
if irc_label in self.species_dict:
2779+
self.species_list = [spc for spc in self.species_list if spc.label != irc_label]
2780+
del self.species_dict[irc_label]
2781+
if irc_label in self.unique_species_labels:
2782+
self.unique_species_labels.remove(irc_label)
2783+
logger.info(f'Deleted IRC species {irc_label} from invalidated TS guess.')
2784+
self.species_dict[label].irc_label = None
2785+
# Reset job_types so the new guess's pipeline runs from scratch.
2786+
for job_type in self.output[label]['job_types']:
2787+
if job_type in ['rotors', 'bde']:
2788+
continue
2789+
self.output[label]['job_types'][job_type] = False
2790+
self.output[label]['convergence'] = None
2791+
# Discard any pending pipe jobs queued for the OLD guess geometry.
2792+
self._pending_pipe_sp.discard(label)
2793+
self._pending_pipe_freq.discard(label)
2794+
self._pending_pipe_irc.discard((label, 'forward'))
2795+
self._pending_pipe_irc.discard((label, 'reverse'))
27672796
freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out')
27682797
if os.path.isfile(freq_path):
27692798
os.remove(freq_path)

arc/scheduler_test.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import unittest
9+
from unittest.mock import patch
910
import os
1011
import shutil
1112

@@ -19,7 +20,7 @@
1920
from arc.imports import settings
2021
from arc.reaction import ARCReaction
2122
from arc.species.converter import str_to_xyz
22-
from arc.species.species import ARCSpecies
23+
from arc.species.species import ARCSpecies, TSGuess
2324

2425

2526
default_levels_of_theory = settings['default_levels_of_theory']
@@ -757,6 +758,119 @@ def test_add_label_to_unique_species_labels(self):
757758
self.assertEqual(unique_label, 'new_species_15_1')
758759
self.assertEqual(self.sched2.unique_species_labels, ['methylamine', 'C2H6', 'CtripCO', 'new_species_15', 'new_species_15_0', 'new_species_15_1'])
759760

761+
@patch('arc.scheduler.Scheduler.run_opt_job')
762+
def test_switch_ts_cleanup(self, mock_run_opt):
763+
"""Test that switch_ts resets job_types, convergence, cleans up IRC species, and clears pending pipes."""
764+
ts_xyz = str_to_xyz("""N 0.91779059 0.51946178 0.00000000
765+
H 1.81402049 1.03819414 0.00000000
766+
H 0.00000000 0.00000000 0.00000000
767+
H 0.91779059 1.22790192 0.72426890""")
768+
769+
ts_spc = ARCSpecies(label='TS_test', is_ts=True, xyz=ts_xyz, multiplicity=1, charge=0,
770+
compute_thermo=False)
771+
# Create two TSGuess objects so determine_most_likely_ts_conformer can pick the 2nd after the 1st fails.
772+
ts_spc.ts_guesses = [
773+
TSGuess(index=0, method='heuristics', success=True, energy=100.0, xyz=ts_xyz,
774+
execution_time='0:00:01'),
775+
TSGuess(index=1, method='heuristics', success=True, energy=110.0, xyz=ts_xyz,
776+
execution_time='0:00:01'),
777+
]
778+
ts_spc.ts_guesses[0].opt_xyz = ts_xyz
779+
ts_spc.ts_guesses[0].imaginary_freqs = [-500.0]
780+
ts_spc.ts_guesses[1].opt_xyz = ts_xyz
781+
ts_spc.ts_guesses[1].imaginary_freqs = [-400.0]
782+
# Simulate guess 0 already tried.
783+
ts_spc.chosen_ts = 0
784+
ts_spc.chosen_ts_list = [0]
785+
ts_spc.ts_guesses_exhausted = False
786+
787+
sched = Scheduler(project='test_switch_ts', ess_settings=self.ess_settings,
788+
species_list=[ts_spc],
789+
opt_level=Level(repr=default_levels_of_theory['opt']),
790+
freq_level=Level(repr=default_levels_of_theory['freq']),
791+
sp_level=Level(repr=default_levels_of_theory['sp']),
792+
ts_guess_level=Level(repr=default_levels_of_theory['ts_guesses']),
793+
project_directory=os.path.join(ARC_PATH, 'Projects',
794+
'arc_project_for_testing_delete_after_usage4'),
795+
testing=True,
796+
job_types=self.job_types1,
797+
)
798+
799+
ts_label = 'TS_test'
800+
# Simulate state after guess 0 completed: freq/sp/opt marked done.
801+
sched.output[ts_label]['job_types']['opt'] = True
802+
sched.output[ts_label]['job_types']['freq'] = True
803+
sched.output[ts_label]['job_types']['sp'] = True
804+
sched.output[ts_label]['convergence'] = True
805+
sched.job_dict[ts_label] = {'opt': {}, 'freq': {}, 'sp': {}}
806+
sched.running_jobs[ts_label] = []
807+
808+
# Simulate IRC species spawned from guess 0.
809+
irc_label_1 = 'IRC_TS_test_1'
810+
irc_label_2 = 'IRC_TS_test_2'
811+
irc_spc_1 = ARCSpecies(label=irc_label_1, xyz=ts_xyz, compute_thermo=False,
812+
irc_label=ts_label)
813+
irc_spc_2 = ARCSpecies(label=irc_label_2, xyz=ts_xyz, compute_thermo=False,
814+
irc_label=ts_label)
815+
ts_spc.irc_label = f'{irc_label_1} {irc_label_2}'
816+
sched.species_dict[irc_label_1] = irc_spc_1
817+
sched.species_dict[irc_label_2] = irc_spc_2
818+
sched.species_list.extend([irc_spc_1, irc_spc_2])
819+
sched.unique_species_labels.extend([irc_label_1, irc_label_2])
820+
sched.running_jobs[irc_label_1] = ['opt_a100']
821+
sched.running_jobs[irc_label_2] = ['opt_a101']
822+
sched.job_dict[irc_label_1] = {'opt': {}}
823+
sched.job_dict[irc_label_2] = {'opt': {}}
824+
sched.initialize_output_dict(label=irc_label_1)
825+
sched.initialize_output_dict(label=irc_label_2)
826+
827+
# Simulate pending pipe entries from the old guess.
828+
sched._pending_pipe_sp.add(ts_label)
829+
sched._pending_pipe_freq.add(ts_label)
830+
sched._pending_pipe_irc.add((ts_label, 'forward'))
831+
sched._pending_pipe_irc.add((ts_label, 'reverse'))
832+
833+
# Call switch_ts — should pick guess 1 and clean up all state from guess 0.
834+
sched.switch_ts(ts_label)
835+
836+
# Verify guess 1 was selected.
837+
self.assertEqual(sched.species_dict[ts_label].chosen_ts, 1)
838+
self.assertIn(1, sched.species_dict[ts_label].chosen_ts_list)
839+
840+
# Verify IRC species from guess 0 fully removed.
841+
self.assertNotIn(irc_label_1, sched.species_dict)
842+
self.assertNotIn(irc_label_2, sched.species_dict)
843+
self.assertNotIn(irc_label_1, sched.running_jobs)
844+
self.assertNotIn(irc_label_2, sched.running_jobs)
845+
self.assertNotIn(irc_label_1, sched.job_dict)
846+
self.assertNotIn(irc_label_2, sched.job_dict)
847+
self.assertNotIn(irc_label_1, sched.output)
848+
self.assertNotIn(irc_label_2, sched.output)
849+
self.assertNotIn(irc_label_1, sched.unique_species_labels)
850+
self.assertNotIn(irc_label_2, sched.unique_species_labels)
851+
self.assertIsNone(sched.species_dict[ts_label].irc_label)
852+
853+
# Verify job_types reset and convergence cleared.
854+
self.assertFalse(sched.output[ts_label]['job_types']['opt'])
855+
self.assertFalse(sched.output[ts_label]['job_types']['freq'])
856+
self.assertFalse(sched.output[ts_label]['job_types']['sp'])
857+
self.assertIsNone(sched.output[ts_label]['convergence'])
858+
859+
# Verify pending pipe entries cleared.
860+
self.assertNotIn(ts_label, sched._pending_pipe_sp)
861+
self.assertNotIn(ts_label, sched._pending_pipe_freq)
862+
self.assertNotIn((ts_label, 'forward'), sched._pending_pipe_irc)
863+
self.assertNotIn((ts_label, 'reverse'), sched._pending_pipe_irc)
864+
865+
# Verify ts_checks were reset.
866+
self.assertIsNone(sched.species_dict[ts_label].ts_checks['freq'])
867+
self.assertIsNone(sched.species_dict[ts_label].ts_checks['NMD'])
868+
self.assertIsNone(sched.species_dict[ts_label].ts_checks['E0'])
869+
870+
# Clean up.
871+
shutil.rmtree(os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage4'),
872+
ignore_errors=True)
873+
760874
@classmethod
761875
def tearDownClass(cls):
762876
"""

0 commit comments

Comments
 (0)