|
2 | 2 | A module for checking the quality of TS-related calculations, contains helper functions for Scheduler. |
3 | 3 | """ |
4 | 4 |
|
| 5 | +from itertools import product |
5 | 6 | import os |
6 | 7 |
|
7 | 8 | import numpy as np |
|
17 | 18 | read_yaml_file, |
18 | 19 | sum_list_entries, |
19 | 20 | ) |
20 | | -from arc.family.family import get_reaction_family_products |
21 | 21 | from arc.imports import settings |
22 | | -from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat |
| 22 | +from arc.species.converter import check_isomorphism, check_xyz_dict, xyz_from_data, xyz_to_dmat |
| 23 | +from arc.species.perceive import perceive_molecule_from_xyz |
23 | 24 | from arc.statmech.factory import statmech_factory |
24 | 25 |
|
25 | 26 | if TYPE_CHECKING: |
26 | 27 | from arc.job.adapter import JobAdapter |
27 | 28 | from arc.level import Level |
| 29 | + from arc.molecule.molecule import Molecule |
28 | 30 | from arc.species.species import ARCSpecies, TSGuess |
29 | 31 | from arc.reaction import ARCReaction |
30 | 32 |
|
31 | 33 | logger = get_logger() |
32 | 34 |
|
| 35 | +MAX_IRC_FRAGMENTS_FOR_CHARGE_SEARCH = 4 |
33 | 36 | LOWEST_MAJOR_TS_FREQ, HIGHEST_MAJOR_TS_FREQ = settings['LOWEST_MAJOR_TS_FREQ'], settings['HIGHEST_MAJOR_TS_FREQ'] |
34 | 37 |
|
35 | 38 |
|
@@ -498,28 +501,188 @@ def check_irc_species_and_rxn(xyz_1: dict, |
498 | 501 | Check that the two species that result from optimizing the outputs of two IRC runs |
499 | 502 | correspond to the desired reactants and products of the corresponding reaction. |
500 | 503 |
|
| 504 | + Uses molecular graph isomorphism (including bond orders and resonance structures) |
| 505 | + when molecule perception succeeds for both endpoints. Falls back to distance-matrix-based |
| 506 | + bond-list comparison if perception fails for either endpoint or if the expected |
| 507 | + reactant/product ``Molecule`` objects are unavailable. |
| 508 | +
|
501 | 509 | Args: |
502 | | - xyz_1 (dict): The coordinates of IRS species 1. |
503 | | - xyz_2 (dict): The coordinates of IRS species 2. |
| 510 | + xyz_1 (dict): The coordinates of IRC species 1. |
| 511 | + xyz_2 (dict): The coordinates of IRC species 2. |
504 | 512 | rxn (ARCReaction): The corresponding reaction object instance. |
505 | 513 | """ |
506 | 514 | if rxn is None: |
507 | 515 | return None |
508 | 516 | rxn.ts_species.ts_checks['IRC'] = False |
509 | 517 | xyz_1, xyz_2 = check_xyz_dict(xyz_1), check_xyz_dict(xyz_2) |
| 518 | + |
| 519 | + # Primary check: molecular graph isomorphism |
| 520 | + reactants, products = rxn.get_reactants_and_products(return_copies=True) |
| 521 | + r_mols = [r.mol for r in reactants] |
| 522 | + p_mols = [p.mol for p in products] |
| 523 | + |
| 524 | + if not any(m is None for m in r_mols + p_mols): |
| 525 | + charge = rxn.charge or 0 |
| 526 | + frags_1 = _perceive_irc_fragments(xyz_1, charge=charge) |
| 527 | + frags_2 = _perceive_irc_fragments(xyz_2, charge=charge) |
| 528 | + if frags_1 is not None and frags_2 is not None: |
| 529 | + if (_match_fragments_to_species(frags_1, r_mols) |
| 530 | + and _match_fragments_to_species(frags_2, p_mols)) \ |
| 531 | + or (_match_fragments_to_species(frags_1, p_mols) |
| 532 | + and _match_fragments_to_species(frags_2, r_mols)): |
| 533 | + rxn.ts_species.ts_checks['IRC'] = True |
| 534 | + return |
| 535 | + logger.debug('IRC isomorphism check failed, falling back to bond-list comparison.') |
| 536 | + else: |
| 537 | + logger.debug('IRC molecule perception failed for one or both endpoints, ' |
| 538 | + 'falling back to bond-list comparison.') |
| 539 | + |
| 540 | + # Fallback: bond-list connectivity comparison |
| 541 | + try: |
| 542 | + r_bonds, p_bonds = rxn.get_bonds() |
| 543 | + except Exception: |
| 544 | + logger.debug('Could not get reaction bonds for IRC fallback check.') |
| 545 | + return |
510 | 546 | dmat_1, dmat_2 = xyz_to_dmat(xyz_1), xyz_to_dmat(xyz_2) |
511 | | - dmat_bonds_1 = get_bonds_from_dmat(dmat=dmat_1, |
512 | | - elements=xyz_1['symbols'], |
513 | | - ) |
514 | | - dmat_bonds_2 = get_bonds_from_dmat(dmat=dmat_2, |
515 | | - elements=xyz_2['symbols'], |
516 | | - ) |
517 | | - r_bonds, p_bonds = rxn.get_bonds() |
| 547 | + dmat_bonds_1 = get_bonds_from_dmat(dmat=dmat_1, elements=xyz_1['symbols']) |
| 548 | + dmat_bonds_2 = get_bonds_from_dmat(dmat=dmat_2, elements=xyz_2['symbols']) |
518 | 549 | if _check_equal_bonds_list(dmat_bonds_1, r_bonds) and _check_equal_bonds_list(dmat_bonds_2, p_bonds) \ |
519 | 550 | or _check_equal_bonds_list(dmat_bonds_2, r_bonds) and _check_equal_bonds_list(dmat_bonds_1, p_bonds): |
520 | 551 | rxn.ts_species.ts_checks['IRC'] = True |
521 | 552 |
|
522 | 553 |
|
| 554 | +def _perceive_irc_fragments(xyz: dict, |
| 555 | + charge: int = 0, |
| 556 | + ) -> Optional[List['Molecule']]: |
| 557 | + """ |
| 558 | + Perceive individual molecular fragments from an IRC endpoint geometry. |
| 559 | +
|
| 560 | + Detects connected components from the distance-matrix-based bond list, |
| 561 | + then perceives fragments as ``Molecule`` objects. For multi-fragment systems, |
| 562 | + charge distribution across fragments is handled by brute-force search over |
| 563 | + charge splits that sum to the total charge, preferring minimal total absolute charge. |
| 564 | +
|
| 565 | + Args: |
| 566 | + xyz (dict): The Cartesian coordinates of the IRC endpoint. |
| 567 | + charge (int): The net charge of the full system. |
| 568 | +
|
| 569 | + Returns: |
| 570 | + Optional[List[Molecule]]: A list of perceived ``Molecule`` objects (one per fragment), |
| 571 | + or ``None`` if perception fails for any fragment. |
| 572 | + """ |
| 573 | + symbols = xyz['symbols'] |
| 574 | + coords = xyz['coords'] |
| 575 | + n_atoms = len(symbols) |
| 576 | + |
| 577 | + dmat = xyz_to_dmat(xyz) |
| 578 | + # Pass n_fragments != 1 to skip the heavy-atom bridging heuristic in get_bonds_from_dmat. |
| 579 | + bonds = get_bonds_from_dmat(dmat=dmat, elements=symbols, n_fragments=0) |
| 580 | + |
| 581 | + adj = {i: set() for i in range(n_atoms)} |
| 582 | + for a, b in bonds: |
| 583 | + adj[a].add(b) |
| 584 | + adj[b].add(a) |
| 585 | + |
| 586 | + visited = set() |
| 587 | + fragment_indices = [] |
| 588 | + for start in range(n_atoms): |
| 589 | + if start in visited: |
| 590 | + continue |
| 591 | + component = [] |
| 592 | + stack = [start] |
| 593 | + while stack: |
| 594 | + node = stack.pop() |
| 595 | + if node in visited: |
| 596 | + continue |
| 597 | + visited.add(node) |
| 598 | + component.append(node) |
| 599 | + for neighbor in adj[node]: |
| 600 | + if neighbor not in visited: |
| 601 | + stack.append(neighbor) |
| 602 | + fragment_indices.append(sorted(component)) |
| 603 | + |
| 604 | + n_frags = len(fragment_indices) |
| 605 | + frag_xyzs = [] |
| 606 | + for frag_idx in fragment_indices: |
| 607 | + frag_symbols = tuple(symbols[i] for i in frag_idx) |
| 608 | + frag_coords = tuple(coords[i] for i in frag_idx) |
| 609 | + frag_xyzs.append(xyz_from_data(coords=frag_coords, symbols=frag_symbols)) |
| 610 | + |
| 611 | + if n_frags == 1: |
| 612 | + mol = perceive_molecule_from_xyz(frag_xyzs[0], charge=charge, n_fragments=1) |
| 613 | + return [mol] if mol is not None else None |
| 614 | + |
| 615 | + # Prefer splits that minimize the total absolute charge (e.g., 0,0 over +1,-1). |
| 616 | + if n_frags > MAX_IRC_FRAGMENTS_FOR_CHARGE_SEARCH: |
| 617 | + return None |
| 618 | + max_abs_charge = max(2, abs(charge) + 1) |
| 619 | + charge_range = range(-max_abs_charge, max_abs_charge + 1) |
| 620 | + best_mols = None |
| 621 | + best_sep = float('inf') |
| 622 | + for charges in product(charge_range, repeat=n_frags): |
| 623 | + if sum(charges) != charge: |
| 624 | + continue |
| 625 | + sep = sum(abs(c) for c in charges) |
| 626 | + if sep >= best_sep: |
| 627 | + continue |
| 628 | + mols = [] |
| 629 | + ok = True |
| 630 | + for frag_xyz, frag_charge in zip(frag_xyzs, charges): |
| 631 | + mol = perceive_molecule_from_xyz(frag_xyz, charge=frag_charge, n_fragments=1) |
| 632 | + if mol is None or mol.get_net_charge() != frag_charge: |
| 633 | + ok = False |
| 634 | + break |
| 635 | + mols.append(mol) |
| 636 | + if ok: |
| 637 | + best_mols, best_sep = mols, sep |
| 638 | + if sep == 0: |
| 639 | + break |
| 640 | + return best_mols |
| 641 | + |
| 642 | + |
| 643 | +def _match_fragments_to_species(fragments: List['Molecule'], |
| 644 | + expected_mols: List['Molecule'], |
| 645 | + ) -> bool: |
| 646 | + """ |
| 647 | + Check whether a list of perceived molecular fragments matches a list of expected species |
| 648 | + via graph isomorphism. Handles multi-species reactions (e.g., A + B) by finding a |
| 649 | + one-to-one matching between fragments and expected species using backtracking with pruning. |
| 650 | +
|
| 651 | + Args: |
| 652 | + fragments (List[Molecule]): Perceived fragment molecules from an IRC endpoint. |
| 653 | + expected_mols (List[Molecule]): Expected species molecules from the reaction. |
| 654 | +
|
| 655 | + Returns: |
| 656 | + bool: Whether a valid one-to-one isomorphic matching exists. |
| 657 | + """ |
| 658 | + n = len(fragments) |
| 659 | + if n != len(expected_mols): |
| 660 | + return False |
| 661 | + if n == 0: |
| 662 | + return True |
| 663 | + frag_formulas = sorted(frag.get_formula() for frag in fragments) |
| 664 | + expected_formulas = sorted(mol.get_formula() for mol in expected_mols) |
| 665 | + if frag_formulas != expected_formulas: |
| 666 | + return False |
| 667 | + if n == 1: |
| 668 | + return check_isomorphism(fragments[0], expected_mols[0]) |
| 669 | + iso_matrix = [[check_isomorphism(fragments[i], expected_mols[j]) for j in range(n)] for i in range(n)] |
| 670 | + used = [False] * n |
| 671 | + |
| 672 | + def _backtrack(i: int) -> bool: |
| 673 | + if i == n: |
| 674 | + return True |
| 675 | + for j in range(n): |
| 676 | + if not used[j] and iso_matrix[i][j]: |
| 677 | + used[j] = True |
| 678 | + if _backtrack(i + 1): |
| 679 | + return True |
| 680 | + used[j] = False |
| 681 | + return False |
| 682 | + |
| 683 | + return _backtrack(0) |
| 684 | + |
| 685 | + |
523 | 686 | def _check_equal_bonds_list(bonds_1: List[Tuple[int, int]], |
524 | 687 | bonds_2: List[Tuple[int, int]], |
525 | 688 | ) -> bool: |
|
0 commit comments