Skip to content

Commit 03fdcba

Browse files
committed
Fix switch_ts to reset state & clean up IRC when switching TS guesses
When a TS guess fails validation (e.g., the NMD check), switch_ts picks the next guess, but it previously left stale state behind:
1. IRC species from the invalidated guess were never cleaned up. delete_all_species_jobs('TS0') only deletes jobs under the TS0 label, but IRC species such as IRC_TS0_1 are separate entries in running_jobs, species_dict, etc. These orphaned species kept running in parallel with the new guess, potentially interfering with job processing.
2. The job_types flags (freq, sp, opt) were never reset. After guess N's freq job completed, job_types['freq'] = True carried over to guess N+1, causing the scheduler to skip re-running freq for the new geometry.
3. convergence was never reset to None.
4. The old line self.output[label]['geo'] = ... wrote to the wrong dict level (top-level keys instead of self.output[label]['paths']), making it dead code.
5. Pending pipe batches from the old guess were never discarded.
1 parent 32a8ea1 commit 03fdcba

2 files changed

Lines changed: 148 additions & 3 deletions

File tree

arc/scheduler.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2763,7 +2763,6 @@ def switch_ts(self, label: str):
27632763
logger.info(f'Switching a TS guess for {label}...')
27642764
self.determine_most_likely_ts_conformer(label=label) # Look for a different TS guess.
27652765
self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS.
2766-
self.output[label]['geo'] = self.output[label]['freq'] = self.output[label]['sp'] = self.output[label]['composite'] = ''
27672766
freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out')
27682767
if os.path.isfile(freq_path):
27692768
os.remove(freq_path)
@@ -3555,7 +3554,14 @@ def troubleshoot_ess(self,
35553554
f'log file:\n"{job.job_status[1]["line"]}".'
35563555
logger.warning(warning_message)
35573556
if self.species_dict[label].is_ts and conformer is not None:
3558-
xyz = self.species_dict[label].ts_guesses[conformer].get_xyz()
3557+
tsg = next((t for t in self.species_dict[label].ts_guesses
3558+
if t.conformer_index == conformer), None)
3559+
if tsg is not None:
3560+
xyz = tsg.get_xyz()
3561+
else:
3562+
logger.warning(f'Could not find TS guess with index {conformer} for {label}; '
3563+
f'skipping troubleshooting for this conformer.')
3564+
return None
35593565
elif conformer is not None:
35603566
xyz = self.species_dict[label].conformers[conformer]
35613567
else:
@@ -3705,6 +3711,33 @@ def delete_all_species_jobs(self, label: str):
37053711
job.delete()
37063712
self.running_jobs[label] = list()
37073713
self.output[label]['paths'] = {key: '' if key != 'irc' else list() for key in self.output[label]['paths'].keys()}
3714+
for job_type in self.output[label]['job_types']:
3715+
self.output[label]['job_types'][job_type] = False
3716+
self.output[label]['convergence'] = None
3717+
self._pending_pipe_sp.discard(label)
3718+
self._pending_pipe_freq.discard(label)
3719+
self._pending_pipe_irc.discard((label, 'forward'))
3720+
self._pending_pipe_irc.discard((label, 'reverse'))
3721+
# Clean up any IRC species spawned from this TS.
3722+
if label in self.species_dict and self.species_dict[label].is_ts:
3723+
irc_labels_str = self.species_dict[label].irc_label
3724+
if irc_labels_str:
3725+
for irc_label in irc_labels_str.split():
3726+
if irc_label in self.job_dict and irc_label in self.output:
3727+
self.delete_all_species_jobs(irc_label)
3728+
if irc_label in self.running_jobs:
3729+
del self.running_jobs[irc_label]
3730+
if irc_label in self.job_dict:
3731+
del self.job_dict[irc_label]
3732+
if irc_label in self.output:
3733+
del self.output[irc_label]
3734+
if irc_label in self.species_dict:
3735+
self.species_list = [spc for spc in self.species_list if spc.label != irc_label]
3736+
del self.species_dict[irc_label]
3737+
if irc_label in self.unique_species_labels:
3738+
self.unique_species_labels.remove(irc_label)
3739+
logger.info(f'Deleted IRC species {irc_label}.')
3740+
self.species_dict[label].irc_label = None
37083741

37093742
def restore_running_jobs(self):
37103743
"""

arc/scheduler_test.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import unittest
9+
from unittest.mock import patch
910
import os
1011
import shutil
1112

@@ -19,7 +20,7 @@
1920
from arc.imports import settings
2021
from arc.reaction import ARCReaction
2122
from arc.species.converter import str_to_xyz
22-
from arc.species.species import ARCSpecies
23+
from arc.species.species import ARCSpecies, TSGuess
2324

2425

2526
default_levels_of_theory = settings['default_levels_of_theory']
@@ -757,6 +758,117 @@ def test_add_label_to_unique_species_labels(self):
757758
self.assertEqual(unique_label, 'new_species_15_1')
758759
self.assertEqual(self.sched2.unique_species_labels, ['methylamine', 'C2H6', 'CtripCO', 'new_species_15', 'new_species_15_0', 'new_species_15_1'])
759760

761+
@patch('arc.scheduler.Scheduler.run_opt_job')
762+
def test_switch_ts_cleanup(self, mock_run_opt):
763+
"""Test that switch_ts resets job_types, convergence, cleans up IRC species, and clears pending pipes."""
764+
ts_xyz = str_to_xyz("""N 0.91779059 0.51946178 0.00000000
765+
H 1.81402049 1.03819414 0.00000000
766+
H 0.00000000 0.00000000 0.00000000
767+
H 0.91779059 1.22790192 0.72426890""")
768+
769+
ts_spc = ARCSpecies(label='TS_test', is_ts=True, xyz=ts_xyz, multiplicity=1, charge=0,
770+
compute_thermo=False)
771+
# Create two TSGuess objects so determine_most_likely_ts_conformer can pick the 2nd after the 1st fails.
772+
ts_spc.ts_guesses = [
773+
TSGuess(index=0, method='heuristics', success=True, energy=100.0, xyz=ts_xyz,
774+
execution_time='0:00:01'),
775+
TSGuess(index=1, method='heuristics', success=True, energy=110.0, xyz=ts_xyz,
776+
execution_time='0:00:01'),
777+
]
778+
ts_spc.ts_guesses[0].opt_xyz = ts_xyz
779+
ts_spc.ts_guesses[0].imaginary_freqs = [-500.0]
780+
ts_spc.ts_guesses[1].opt_xyz = ts_xyz
781+
ts_spc.ts_guesses[1].imaginary_freqs = [-400.0]
782+
# Simulate guess 0 already tried.
783+
ts_spc.chosen_ts = 0
784+
ts_spc.chosen_ts_list = [0]
785+
ts_spc.ts_guesses_exhausted = False
786+
787+
project_directory = os.path.join(ARC_PATH, 'Projects',
788+
'arc_project_for_testing_delete_after_usage4')
789+
self.addCleanup(shutil.rmtree, project_directory, ignore_errors=True)
790+
sched = Scheduler(project='test_switch_ts', ess_settings=self.ess_settings,
791+
species_list=[ts_spc],
792+
opt_level=Level(repr=default_levels_of_theory['opt']),
793+
freq_level=Level(repr=default_levels_of_theory['freq']),
794+
sp_level=Level(repr=default_levels_of_theory['sp']),
795+
ts_guess_level=Level(repr=default_levels_of_theory['ts_guesses']),
796+
project_directory=project_directory,
797+
testing=True,
798+
job_types=self.job_types1,
799+
)
800+
801+
ts_label = 'TS_test'
802+
# Simulate state after guess 0 completed: freq/sp/opt marked done.
803+
sched.output[ts_label]['job_types']['opt'] = True
804+
sched.output[ts_label]['job_types']['freq'] = True
805+
sched.output[ts_label]['job_types']['sp'] = True
806+
sched.output[ts_label]['convergence'] = True
807+
sched.job_dict[ts_label] = {'opt': {}, 'freq': {}, 'sp': {}}
808+
sched.running_jobs[ts_label] = []
809+
810+
# Simulate IRC species spawned from guess 0.
811+
irc_label_1 = 'IRC_TS_test_1'
812+
irc_label_2 = 'IRC_TS_test_2'
813+
irc_spc_1 = ARCSpecies(label=irc_label_1, xyz=ts_xyz, compute_thermo=False,
814+
irc_label=ts_label)
815+
irc_spc_2 = ARCSpecies(label=irc_label_2, xyz=ts_xyz, compute_thermo=False,
816+
irc_label=ts_label)
817+
ts_spc.irc_label = f'{irc_label_1} {irc_label_2}'
818+
sched.species_dict[irc_label_1] = irc_spc_1
819+
sched.species_dict[irc_label_2] = irc_spc_2
820+
sched.species_list.extend([irc_spc_1, irc_spc_2])
821+
sched.unique_species_labels.extend([irc_label_1, irc_label_2])
822+
sched.running_jobs[irc_label_1] = ['opt_a100']
823+
sched.running_jobs[irc_label_2] = ['opt_a101']
824+
sched.job_dict[irc_label_1] = {'opt': {}}
825+
sched.job_dict[irc_label_2] = {'opt': {}}
826+
sched.initialize_output_dict(label=irc_label_1)
827+
sched.initialize_output_dict(label=irc_label_2)
828+
829+
# Simulate pending pipe entries from the old guess.
830+
sched._pending_pipe_sp.add(ts_label)
831+
sched._pending_pipe_freq.add(ts_label)
832+
sched._pending_pipe_irc.add((ts_label, 'forward'))
833+
sched._pending_pipe_irc.add((ts_label, 'reverse'))
834+
835+
# Call switch_ts — should pick guess 1 and clean up all state from guess 0.
836+
sched.switch_ts(ts_label)
837+
838+
# Verify guess 1 was selected.
839+
self.assertEqual(sched.species_dict[ts_label].chosen_ts, 1)
840+
self.assertIn(1, sched.species_dict[ts_label].chosen_ts_list)
841+
842+
# Verify IRC species from guess 0 fully removed.
843+
self.assertNotIn(irc_label_1, sched.species_dict)
844+
self.assertNotIn(irc_label_2, sched.species_dict)
845+
self.assertNotIn(irc_label_1, sched.running_jobs)
846+
self.assertNotIn(irc_label_2, sched.running_jobs)
847+
self.assertNotIn(irc_label_1, sched.job_dict)
848+
self.assertNotIn(irc_label_2, sched.job_dict)
849+
self.assertNotIn(irc_label_1, sched.output)
850+
self.assertNotIn(irc_label_2, sched.output)
851+
self.assertNotIn(irc_label_1, sched.unique_species_labels)
852+
self.assertNotIn(irc_label_2, sched.unique_species_labels)
853+
self.assertIsNone(sched.species_dict[ts_label].irc_label)
854+
855+
# Verify job_types reset and convergence cleared.
856+
self.assertFalse(sched.output[ts_label]['job_types']['opt'])
857+
self.assertFalse(sched.output[ts_label]['job_types']['freq'])
858+
self.assertFalse(sched.output[ts_label]['job_types']['sp'])
859+
self.assertIsNone(sched.output[ts_label]['convergence'])
860+
861+
# Verify pending pipe entries cleared.
862+
self.assertNotIn(ts_label, sched._pending_pipe_sp)
863+
self.assertNotIn(ts_label, sched._pending_pipe_freq)
864+
self.assertNotIn((ts_label, 'forward'), sched._pending_pipe_irc)
865+
self.assertNotIn((ts_label, 'reverse'), sched._pending_pipe_irc)
866+
867+
# Verify ts_checks were reset.
868+
self.assertIsNone(sched.species_dict[ts_label].ts_checks['freq'])
869+
self.assertIsNone(sched.species_dict[ts_label].ts_checks['NMD'])
870+
self.assertIsNone(sched.species_dict[ts_label].ts_checks['E0'])
871+
760872
@classmethod
761873
def tearDownClass(cls):
762874
"""

0 commit comments

Comments
 (0)