Skip to content

Commit 9c4221a

Browse files
committed
Decouple species stage completion checks in the Scheduler
Refactor the scheduling logic to separate job termination processing from stage transition decisions. By moving conformer, TS guess, and species completion checks into dedicated methods called at the end of the polling loop, the state machine becomes more robust and ensures transitions are evaluated consistently across different execution modes.
1 parent d7b20a3 commit 9c4221a

2 files changed

Lines changed: 245 additions & 49 deletions

File tree

arc/scheduler.py

Lines changed: 114 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -622,55 +622,21 @@ def schedule_jobs(self):
622622
job = self.job_dict[label]['conf_opt'][i] if 'conf_opt' in job_name \
623623
else self.job_dict[label]['conf_sp'][i]
624624
if not (job.job_id in self.server_job_ids and job.job_id not in self.completed_incore_jobs):
625-
# this is a completed conformer job
626625
successful_server_termination = self.end_job(job=job, label=label, job_name=job_name)
627626
if successful_server_termination:
628627
troubleshooting_conformer = self.parse_conformer(job=job, label=label, i=i)
629628
if 'conf_opt' in job_name and self.job_types['conf_sp'] and not troubleshooting_conformer:
630629
# Accumulate for deferred pipe batching of conf_sp.
631630
self._pending_pipe_conf_sp.setdefault(label, set()).add(i)
632-
if troubleshooting_conformer:
633-
break
634-
# Just terminated a conformer job.
635-
# Are there additional conformer jobs currently running for this species?
636-
# Note: end_job already removed the current job from running_jobs,
637-
# so we don't need to exclude job_name.
638-
for spec_jobs in job_list:
639-
if 'conf_opt' in spec_jobs or 'conf_sp' in spec_jobs:
640-
break
641-
else:
642-
# All conformer jobs terminated.
643-
# Check isomorphism and run opt on most stable conformer geometry.
644-
logger.info(f'\nConformer jobs for {label} successfully terminated.\n')
645-
if self.species_dict[label].is_ts:
646-
self.determine_most_likely_ts_conformer(label)
647-
else:
648-
self.determine_most_stable_conformer(label, sp_flag=True if self.job_types['conf_sp'] else False) # also checks isomorphism
649-
if self.species_dict[label].initial_xyz is not None:
650-
# if initial_xyz is None, then we're probably troubleshooting conformers, don't opt
651-
if not self.composite_method:
652-
self.run_opt_job(label, fine=self.fine_only)
653-
else:
654-
self.run_composite_job(label)
655631
self.timer = False
656632
break
657633
if 'tsg' in job_name:
658634
job = self.job_dict[label]['tsg'][get_i_from_job_name(job_name)]
659635
if not (job.job_id in self.server_job_ids and job.job_id not in self.completed_incore_jobs):
660-
# This is a successfully completed tsg job. It may have resulted in several TSGuesses.
661636
self.end_job(job=job, label=label, job_name=job_name)
662637
if job.local_path_to_output_file.endswith('.yml') or job.local_path_to_output_file.endswith('.log'):
663638
for rxn in job.reactions:
664639
rxn.ts_species.process_completed_tsg_queue_jobs(path=job.local_path_to_output_file)
665-
# Just terminated a tsg job.
666-
# Are there additional tsg jobs currently running for this species?
667-
for spec_jobs in job_list:
668-
if 'tsg' in spec_jobs:
669-
break
670-
else:
671-
# All tsg jobs terminated. Spawn confs.
672-
logger.info(f'\nTS guess jobs for {label} successfully terminated.\n')
673-
self.run_conformer_jobs(labels=[label])
674640
self.timer = False
675641
break
676642
elif 'opt' in job_name and 'conf_opt' not in job_name:
@@ -803,20 +769,12 @@ def schedule_jobs(self):
803769
self.timer = False
804770
break
805771

806-
if not len(job_list):
807-
has_pending_pipe_work = (
808-
label in self._pending_pipe_sp
809-
or label in self._pending_pipe_freq
810-
or any(lbl == label for lbl, _ in self._pending_pipe_irc)
811-
or label in self._pending_pipe_conf_sp
812-
or any(label in {t.owner_key for t in p.tasks}
813-
for p in self.active_pipes.values())
814-
)
815-
if not has_pending_pipe_work:
816-
self.check_all_done(label)
817-
if not self.running_jobs[label]:
818-
# Delete the label only if it represents an empty entry.
819-
del self.running_jobs[label]
772+
for label in list(self.unique_species_labels):
773+
if label in self.output and self.output[label]['convergence'] is False:
774+
continue
775+
self._check_conformer_stage_complete(label)
776+
self._check_tsg_stage_complete(label)
777+
self._check_species_complete(label)
820778

821779
# Poll active pipe runs (per-run failures are handled inside poll_pipes).
822780
if self.active_pipes:
@@ -840,6 +798,114 @@ def schedule_jobs(self):
840798
# Generate a TS report:
841799
self.generate_final_ts_guess_report()
842800

801+
def _check_conformer_stage_complete(self, label: str) -> None:
802+
"""
803+
Check whether all conformer jobs (conf_opt/conf_sp) for a species have
804+
finished. If so, select the best conformer and spawn the next job.
805+
806+
Called unconditionally after job event processing so that no break
807+
in the job-processing loop can skip the conformer-to-opt transition.
808+
"""
809+
if 'conf_opt' not in self.job_dict.get(label, {}):
810+
return
811+
if any('conf_opt' in j or 'conf_sp' in j
812+
for j in self.running_jobs.get(label, [])):
813+
return
814+
if label in self._pending_pipe_conf_sp:
815+
return
816+
if any(label in {t.owner_key for t in p.tasks}
817+
for p in self.active_pipes.values()
818+
if any(t.task_family in ('conf_opt', 'conf_sp', 'ts_opt') for t in p.tasks)):
819+
return
820+
if self.species_dict[label].initial_xyz is not None:
821+
return
822+
if self.output[label].get('job_types', {}).get('conf_opt'):
823+
return
824+
if self.species_dict[label].is_ts and self.species_dict[label].ts_guesses_exhausted:
825+
return
826+
827+
if self.species_dict[label].is_ts:
828+
has_successful_conformer = any(
829+
tsg.energy is not None for tsg in self.species_dict[label].ts_guesses)
830+
else:
831+
has_successful_conformer = any(
832+
e is not None for e in self.species_dict[label].conformer_energies)
833+
834+
if not has_successful_conformer:
835+
logger.error(f'All conformer jobs for {label} failed. '
836+
f'No conformer has a valid energy.')
837+
if self.species_dict[label].is_ts:
838+
self.species_dict[label].ts_guesses_exhausted = True
839+
return
840+
841+
logger.info(f'\nConformer jobs for {label} successfully terminated.\n')
842+
if self.species_dict[label].is_ts:
843+
self.determine_most_likely_ts_conformer(label)
844+
else:
845+
self.determine_most_stable_conformer(
846+
label, sp_flag=True if self.job_types.get('conf_sp') else False)
847+
if self.species_dict[label].initial_xyz is not None:
848+
if not self.composite_method:
849+
self.run_opt_job(label, fine=self.fine_only)
850+
else:
851+
self.run_composite_job(label)
852+
elif not any('conf_opt' in j or 'conf_sp' in j
853+
for j in self.running_jobs.get(label, [])):
854+
self.output[label]['job_types']['conf_opt'] = True
855+
856+
def _check_tsg_stage_complete(self, label: str) -> None:
857+
"""
858+
Check whether all TS guess jobs for a species have finished.
859+
If so, spawn conformer jobs for the TS.
860+
"""
861+
if 'tsg' not in self.job_dict.get(label, {}):
862+
return
863+
if any('tsg' in j for j in self.running_jobs.get(label, [])):
864+
return
865+
if not self.species_dict[label].is_ts:
866+
return
867+
if self.species_dict[label].ts_conf_spawned:
868+
return
869+
if self.species_dict[label].ts_guesses_exhausted:
870+
return
871+
if not all(tsg.success is not None for tsg in self.species_dict[label].ts_guesses):
872+
return
873+
874+
if not any(tsg.success for tsg in self.species_dict[label].ts_guesses):
875+
logger.error(f'All TS guess jobs for {label} failed. '
876+
f'No successful TS guess found.')
877+
self.species_dict[label].ts_guesses_exhausted = True
878+
return
879+
880+
logger.info(f'\nTS guess jobs for {label} successfully terminated.\n')
881+
self.run_conformer_jobs(labels=[label])
882+
883+
def _check_species_complete(self, label: str) -> None:
884+
"""
885+
Check whether all jobs for a species are complete and call
886+
check_all_done if so. Clean up empty running_jobs entries.
887+
"""
888+
if label in self.output and self.output[label]['convergence'] is not None:
889+
# Species already finalized (converged or failed); clean up and skip.
890+
if label in self.running_jobs and not self.running_jobs[label]:
891+
del self.running_jobs[label]
892+
return
893+
running = self.running_jobs.get(label, [])
894+
if running:
895+
return
896+
has_pending_pipe_work = (
897+
label in self._pending_pipe_sp
898+
or label in self._pending_pipe_freq
899+
or any(lbl == label for lbl, _ in self._pending_pipe_irc)
900+
or label in self._pending_pipe_conf_sp
901+
or any(label in {t.owner_key for t in p.tasks}
902+
for p in self.active_pipes.values())
903+
)
904+
if not has_pending_pipe_work:
905+
self.check_all_done(label)
906+
if label in self.running_jobs and not self.running_jobs[label]:
907+
del self.running_jobs[label]
908+
843909
def run_job(self,
844910
job_type: str,
845911
conformer: Optional[int] = None,

arc/scheduler_test.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import unittest
99
import os
1010
import shutil
11+
from unittest.mock import patch, MagicMock
1112

1213
import arc.parser.parser as parser
1314
from arc.checks.ts import check_ts
@@ -19,7 +20,7 @@
1920
from arc.imports import settings
2021
from arc.reaction import ARCReaction
2122
from arc.species.converter import str_to_xyz
22-
from arc.species.species import ARCSpecies
23+
from arc.species.species import ARCSpecies, TSGuess
2324

2425

2526
default_levels_of_theory = settings['default_levels_of_theory']
@@ -757,6 +758,135 @@ def test_add_label_to_unique_species_labels(self):
757758
self.assertEqual(unique_label, 'new_species_15_1')
758759
self.assertEqual(self.sched2.unique_species_labels, ['methylamine', 'C2H6', 'CtripCO', 'new_species_15', 'new_species_15_0', 'new_species_15_1'])
759760

761+
def _make_isolated_scheduler(self):
762+
"""Create a Scheduler with a fresh species object for tests that mutate species state."""
763+
spc = ARCSpecies(label='spc_test', smiles='CN',
764+
xyz=str_to_xyz("""C -0.57422867 -0.01669771 0.01229213
765+
N 0.82084044 0.08279104 -0.37769346
766+
H -1.05737005 -0.84067772 -0.52007494
767+
H -1.10211468 0.90879867 -0.23383011
768+
H -0.66133128 -0.19490562 1.08785111
769+
H 0.88047852 0.26966160 -1.37780789
770+
H 1.27889520 -0.81548721 -0.22940984"""))
771+
sched = Scheduler(
772+
project='project_test_stage_checks',
773+
ess_settings=self.ess_settings,
774+
species_list=[spc],
775+
composite_method=None,
776+
conformer_opt_level=Level(repr=default_levels_of_theory['conformer']),
777+
opt_level=Level(repr=default_levels_of_theory['opt']),
778+
freq_level=Level(repr=default_levels_of_theory['freq']),
779+
sp_level=Level(repr=default_levels_of_theory['sp']),
780+
scan_level=Level(repr=default_levels_of_theory['scan']),
781+
ts_guess_level=Level(repr=default_levels_of_theory['ts_guesses']),
782+
project_directory=os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage6'),
783+
testing=True,
784+
job_types=self.job_types1,
785+
orbitals_level=default_levels_of_theory['orbitals'],
786+
adaptive_levels=None,
787+
)
788+
return sched, spc.label
789+
790+
def test_check_conformer_stage_complete_spawns_opt_for_ts(self):
791+
"""Test that _check_conformer_stage_complete() calls determine_most_likely_ts_conformer() and
792+
spawns an opt job after all TS conformer jobs finish, even when the job-processing loop broke
793+
early due to troubleshooting."""
794+
sched, label = self._make_isolated_scheduler()
795+
# Set up species as a TS with completed conformer jobs.
796+
sched.species_dict[label].is_ts = True
797+
sched.species_dict[label].ts_conf_spawned = True
798+
sched.species_dict[label].ts_guesses_exhausted = False
799+
sched.species_dict[label].initial_xyz = None
800+
tsg = TSGuess(method='autotst', index=0, success=True, energy=10.0)
801+
sched.species_dict[label].ts_guesses = [tsg]
802+
sched.job_dict[label] = {'conf_opt': {0: MagicMock()}}
803+
sched.running_jobs[label] = [] # all conf_opt jobs done
804+
sched.output[label]['job_types']['conf_opt'] = False
805+
806+
with patch.object(sched, 'determine_most_likely_ts_conformer') as mock_det, \
807+
patch.object(sched, 'run_opt_job') as mock_opt:
808+
# Simulate determine_most_likely_ts_conformer setting initial_xyz.
809+
def set_xyz(lbl):
810+
sched.species_dict[lbl].initial_xyz = {'symbols': ('C',), 'isotopes': (12,), 'coords': ((0, 0, 0),)}
811+
mock_det.side_effect = set_xyz
812+
sched._check_conformer_stage_complete(label)
813+
mock_det.assert_called_once_with(label)
814+
mock_opt.assert_called_once_with(label, fine=sched.fine_only)
815+
816+
def test_check_tsg_stage_complete_all_failed(self):
817+
"""Test that _check_tsg_stage_complete() sets ts_guesses_exhausted when all TS guesses
818+
failed, and does not call run_conformer_jobs()."""
819+
sched, label = self._make_isolated_scheduler()
820+
sched.species_dict[label].is_ts = True
821+
sched.species_dict[label].ts_conf_spawned = False
822+
sched.species_dict[label].ts_guesses_exhausted = False
823+
tsg1 = TSGuess(method='autotst', index=0, success=False)
824+
tsg2 = TSGuess(method='gcn', index=1, success=False)
825+
sched.species_dict[label].ts_guesses = [tsg1, tsg2]
826+
sched.job_dict[label] = {'tsg': {0: MagicMock(), 1: MagicMock()}}
827+
sched.running_jobs[label] = [] # no tsg jobs running
828+
829+
with patch.object(sched, 'run_conformer_jobs') as mock_conf:
830+
sched._check_tsg_stage_complete(label)
831+
mock_conf.assert_not_called()
832+
self.assertTrue(sched.species_dict[label].ts_guesses_exhausted)
833+
834+
def test_check_tsg_stage_complete_no_repeat_after_exhausted(self):
835+
"""Test that _check_tsg_stage_complete() returns immediately when ts_guesses_exhausted
836+
is already True (does not re-log or re-call run_conformer_jobs)."""
837+
sched, label = self._make_isolated_scheduler()
838+
sched.species_dict[label].is_ts = True
839+
sched.species_dict[label].ts_conf_spawned = False
840+
sched.species_dict[label].ts_guesses_exhausted = True
841+
tsg = TSGuess(method='autotst', index=0, success=False)
842+
sched.species_dict[label].ts_guesses = [tsg]
843+
sched.job_dict[label] = {'tsg': {0: MagicMock()}}
844+
sched.running_jobs[label] = []
845+
846+
with patch.object(sched, 'run_conformer_jobs') as mock_conf:
847+
sched._check_tsg_stage_complete(label)
848+
mock_conf.assert_not_called()
849+
850+
def test_check_species_complete_no_repeat_after_converged(self):
851+
"""Test that _check_species_complete() does not call check_all_done()
852+
for a species whose convergence is already True."""
853+
sched, label = self._make_isolated_scheduler()
854+
sched.output[label]['convergence'] = True
855+
sched.running_jobs[label] = [] # empty entry left over
856+
857+
with patch.object(sched, 'check_all_done') as mock_cad:
858+
sched._check_species_complete(label)
859+
mock_cad.assert_not_called()
860+
# Also verify empty running_jobs entry was cleaned up.
861+
self.assertNotIn(label, sched.running_jobs)
862+
863+
def test_check_species_complete_no_repeat_after_failed(self):
864+
"""Test that _check_species_complete() does not call check_all_done()
865+
for a species whose convergence is already False."""
866+
sched, label = self._make_isolated_scheduler()
867+
sched.output[label]['convergence'] = False
868+
869+
with patch.object(sched, 'check_all_done') as mock_cad:
870+
sched._check_species_complete(label)
871+
mock_cad.assert_not_called()
872+
873+
def test_check_species_complete_calls_check_all_done_when_ready(self):
874+
"""Test that _check_species_complete() calls check_all_done() when running_jobs
875+
is empty and convergence is still None (not yet finalized)."""
876+
sched, label = self._make_isolated_scheduler()
877+
sched.output[label]['convergence'] = None
878+
sched.running_jobs[label] = []
879+
sched._pending_pipe_sp = set()
880+
sched._pending_pipe_freq = set()
881+
sched._pending_pipe_irc = set()
882+
sched._pending_pipe_conf_sp = {}
883+
884+
with patch.object(sched, 'check_all_done') as mock_cad:
885+
sched._check_species_complete(label)
886+
mock_cad.assert_called_once_with(label)
887+
# Empty running_jobs entry should be cleaned up.
888+
self.assertNotIn(label, sched.running_jobs)
889+
760890
@classmethod
761891
def tearDownClass(cls):
762892
"""

0 commit comments

Comments
 (0)