|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import unittest |
| 9 | +from unittest.mock import patch |
9 | 10 | import os |
10 | 11 | import shutil |
11 | 12 |
|
|
19 | 20 | from arc.imports import settings |
20 | 21 | from arc.reaction import ARCReaction |
21 | 22 | from arc.species.converter import str_to_xyz |
22 | | -from arc.species.species import ARCSpecies |
| 23 | +from arc.species.species import ARCSpecies, TSGuess |
23 | 24 |
|
24 | 25 |
|
25 | 26 | default_levels_of_theory = settings['default_levels_of_theory'] |
@@ -757,6 +758,119 @@ def test_add_label_to_unique_species_labels(self): |
757 | 758 | self.assertEqual(unique_label, 'new_species_15_1') |
758 | 759 | self.assertEqual(self.sched2.unique_species_labels, ['methylamine', 'C2H6', 'CtripCO', 'new_species_15', 'new_species_15_0', 'new_species_15_1']) |
759 | 760 |
|
| 761 | + @patch('arc.scheduler.Scheduler.run_opt_job') |
| 762 | + def test_switch_ts_cleanup(self, mock_run_opt): |
| 763 | + """Test that switch_ts resets job_types, convergence, cleans up IRC species, and clears pending pipes.""" |
| 764 | + ts_xyz = str_to_xyz("""N 0.91779059 0.51946178 0.00000000 |
| 765 | + H 1.81402049 1.03819414 0.00000000 |
| 766 | + H 0.00000000 0.00000000 0.00000000 |
| 767 | + H 0.91779059 1.22790192 0.72426890""") |
| 768 | + |
| 769 | + ts_spc = ARCSpecies(label='TS_test', is_ts=True, xyz=ts_xyz, multiplicity=1, charge=0, |
| 770 | + compute_thermo=False) |
| 771 | + # Create two TSGuess objects so determine_most_likely_ts_conformer can pick the 2nd after the 1st fails. |
| 772 | + ts_spc.ts_guesses = [ |
| 773 | + TSGuess(index=0, method='heuristics', success=True, energy=100.0, xyz=ts_xyz, |
| 774 | + execution_time='0:00:01'), |
| 775 | + TSGuess(index=1, method='heuristics', success=True, energy=110.0, xyz=ts_xyz, |
| 776 | + execution_time='0:00:01'), |
| 777 | + ] |
| 778 | + ts_spc.ts_guesses[0].opt_xyz = ts_xyz |
| 779 | + ts_spc.ts_guesses[0].imaginary_freqs = [-500.0] |
| 780 | + ts_spc.ts_guesses[1].opt_xyz = ts_xyz |
| 781 | + ts_spc.ts_guesses[1].imaginary_freqs = [-400.0] |
| 782 | + # Simulate guess 0 already tried. |
| 783 | + ts_spc.chosen_ts = 0 |
| 784 | + ts_spc.chosen_ts_list = [0] |
| 785 | + ts_spc.ts_guesses_exhausted = False |
| 786 | + |
| 787 | + sched = Scheduler(project='test_switch_ts', ess_settings=self.ess_settings, |
| 788 | + species_list=[ts_spc], |
| 789 | + opt_level=Level(repr=default_levels_of_theory['opt']), |
| 790 | + freq_level=Level(repr=default_levels_of_theory['freq']), |
| 791 | + sp_level=Level(repr=default_levels_of_theory['sp']), |
| 792 | + ts_guess_level=Level(repr=default_levels_of_theory['ts_guesses']), |
| 793 | + project_directory=os.path.join(ARC_PATH, 'Projects', |
| 794 | + 'arc_project_for_testing_delete_after_usage4'), |
| 795 | + testing=True, |
| 796 | + job_types=self.job_types1, |
| 797 | + ) |
| 798 | + |
| 799 | + ts_label = 'TS_test' |
| 800 | + # Simulate state after guess 0 completed: freq/sp/opt marked done. |
| 801 | + sched.output[ts_label]['job_types']['opt'] = True |
| 802 | + sched.output[ts_label]['job_types']['freq'] = True |
| 803 | + sched.output[ts_label]['job_types']['sp'] = True |
| 804 | + sched.output[ts_label]['convergence'] = True |
| 805 | + sched.job_dict[ts_label] = {'opt': {}, 'freq': {}, 'sp': {}} |
| 806 | + sched.running_jobs[ts_label] = [] |
| 807 | + |
| 808 | + # Simulate IRC species spawned from guess 0. |
| 809 | + irc_label_1 = 'IRC_TS_test_1' |
| 810 | + irc_label_2 = 'IRC_TS_test_2' |
| 811 | + irc_spc_1 = ARCSpecies(label=irc_label_1, xyz=ts_xyz, compute_thermo=False, |
| 812 | + irc_label=ts_label) |
| 813 | + irc_spc_2 = ARCSpecies(label=irc_label_2, xyz=ts_xyz, compute_thermo=False, |
| 814 | + irc_label=ts_label) |
| 815 | + ts_spc.irc_label = f'{irc_label_1} {irc_label_2}' |
| 816 | + sched.species_dict[irc_label_1] = irc_spc_1 |
| 817 | + sched.species_dict[irc_label_2] = irc_spc_2 |
| 818 | + sched.species_list.extend([irc_spc_1, irc_spc_2]) |
| 819 | + sched.unique_species_labels.extend([irc_label_1, irc_label_2]) |
| 820 | + sched.running_jobs[irc_label_1] = ['opt_a100'] |
| 821 | + sched.running_jobs[irc_label_2] = ['opt_a101'] |
| 822 | + sched.job_dict[irc_label_1] = {'opt': {}} |
| 823 | + sched.job_dict[irc_label_2] = {'opt': {}} |
| 824 | + sched.initialize_output_dict(label=irc_label_1) |
| 825 | + sched.initialize_output_dict(label=irc_label_2) |
| 826 | + |
| 827 | + # Simulate pending pipe entries from the old guess. |
| 828 | + sched._pending_pipe_sp.add(ts_label) |
| 829 | + sched._pending_pipe_freq.add(ts_label) |
| 830 | + sched._pending_pipe_irc.add((ts_label, 'forward')) |
| 831 | + sched._pending_pipe_irc.add((ts_label, 'reverse')) |
| 832 | + |
| 833 | + # Call switch_ts — should pick guess 1 and clean up all state from guess 0. |
| 834 | + sched.switch_ts(ts_label) |
| 835 | + |
| 836 | + # Verify guess 1 was selected. |
| 837 | + self.assertEqual(sched.species_dict[ts_label].chosen_ts, 1) |
| 838 | + self.assertIn(1, sched.species_dict[ts_label].chosen_ts_list) |
| 839 | + |
| 840 | + # Verify IRC species from guess 0 fully removed. |
| 841 | + self.assertNotIn(irc_label_1, sched.species_dict) |
| 842 | + self.assertNotIn(irc_label_2, sched.species_dict) |
| 843 | + self.assertNotIn(irc_label_1, sched.running_jobs) |
| 844 | + self.assertNotIn(irc_label_2, sched.running_jobs) |
| 845 | + self.assertNotIn(irc_label_1, sched.job_dict) |
| 846 | + self.assertNotIn(irc_label_2, sched.job_dict) |
| 847 | + self.assertNotIn(irc_label_1, sched.output) |
| 848 | + self.assertNotIn(irc_label_2, sched.output) |
| 849 | + self.assertNotIn(irc_label_1, sched.unique_species_labels) |
| 850 | + self.assertNotIn(irc_label_2, sched.unique_species_labels) |
| 851 | + self.assertIsNone(sched.species_dict[ts_label].irc_label) |
| 852 | + |
| 853 | + # Verify job_types reset and convergence cleared. |
| 854 | + self.assertFalse(sched.output[ts_label]['job_types']['opt']) |
| 855 | + self.assertFalse(sched.output[ts_label]['job_types']['freq']) |
| 856 | + self.assertFalse(sched.output[ts_label]['job_types']['sp']) |
| 857 | + self.assertIsNone(sched.output[ts_label]['convergence']) |
| 858 | + |
| 859 | + # Verify pending pipe entries cleared. |
| 860 | + self.assertNotIn(ts_label, sched._pending_pipe_sp) |
| 861 | + self.assertNotIn(ts_label, sched._pending_pipe_freq) |
| 862 | + self.assertNotIn((ts_label, 'forward'), sched._pending_pipe_irc) |
| 863 | + self.assertNotIn((ts_label, 'reverse'), sched._pending_pipe_irc) |
| 864 | + |
| 865 | + # Verify ts_checks were reset. |
| 866 | + self.assertIsNone(sched.species_dict[ts_label].ts_checks['freq']) |
| 867 | + self.assertIsNone(sched.species_dict[ts_label].ts_checks['NMD']) |
| 868 | + self.assertIsNone(sched.species_dict[ts_label].ts_checks['E0']) |
| 869 | + |
| 870 | + # Clean up. |
| 871 | + shutil.rmtree(os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage4'), |
| 872 | + ignore_errors=True) |
| 873 | + |
760 | 874 | @classmethod |
761 | 875 | def tearDownClass(cls): |
762 | 876 | """ |
|
0 commit comments