diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index f62fa21..58a646c 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -2,7 +2,7 @@ name: Test & Lint on: push: - branches: [ "main", "dev", "staging", "refractor" ] + branches: [ "main", "partialits", "staging", "refractor" ] pull_request: branches: [ "main" ] diff --git a/.gitignore b/.gitignore index 07ad7b7..40588fd 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ run.sh docs/* run_rdcanon.py Data/Fragment/* +test_partial.py +Data/Benchmark/synthesis/* diff --git a/README.md b/README.md index 47d57fe..c64736d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,15 @@ # SynKit +[![PyPI version](https://img.shields.io/pypi/v/synkit.svg)](https://pypi.org/project/synkit/) +[![Conda version](https://img.shields.io/conda/vn/tieulongphan/synkit.svg)](https://anaconda.org/tieulongphan/synkit) +[![Docker Pulls](https://img.shields.io/docker/pulls/tieulongphan/synkit.svg)](https://hub.docker.com/r/tieulongphan/synkit) +[![Docker Image Version](https://img.shields.io/docker/v/tieulongphan/synkit/latest?label=container)](https://hub.docker.com/r/tieulongphan/synkit) +[![License](https://img.shields.io/github/license/tieulongphan/synkit.svg)](https://github.com/tieulongphan/synkit/blob/main/LICENSE) +[![Release](https://img.shields.io/github/v/release/tieulongphan/synkit.svg)](https://github.com/tieulongphan/synkit/releases) +[![Last Commit](https://img.shields.io/github/last-commit/tieulongphan/synkit.svg)](https://github.com/tieulongphan/synkit/commits) +[![Zenodo](https://zenodo.org/badge/DOI/10.5281/zenodo.15269901.svg)](https://doi.org/10.5281/zenodo.15269901) +[![CI](https://github.com/tieulongphan/synkit/actions/workflows/test-and-lint.yml/badge.svg?branch=main)](https://github.com/tieulongphan/synkit/actions/workflows/test-and-lint.yml) +[![Dependency 
PRs](https://img.shields.io/github/issues-pr-raw/tieulongphan/synkit?label=dependency%20PRs)](https://github.com/tieulongphan/synkit/pulls?q=is%3Apr+label%3Adependencies) +[![Stars](https://img.shields.io/github/stars/tieulongphan/synkit.svg?style=social&label=Star)](https://github.com/tieulongphan/synkit/stargazers) **Toolkit for Synthesis Planning** diff --git a/Test/Graph/MTG/test_mtg.py b/Test/Graph/MTG/test_mtg.py index 2750293..0f8827d 100644 --- a/Test/Graph/MTG/test_mtg.py +++ b/Test/Graph/MTG/test_mtg.py @@ -20,19 +20,17 @@ def setUp(self) -> None: self.test_graph_2 = [get_rc(rsmi_to_its(var)) for var in test_2] def test_MTG_1(self): - grp = GroupComp(self.test_graph_1[0], self.test_graph_1[1]) - candidates = grp.get_mapping() - print(candidates) - mtg = MTG(self.test_graph_1[0], self.test_graph_1[1], candidates[0]) - self.assertEqual(len(mtg.get_nodes()), 6) - self.assertEqual(len(mtg.get_edges()), 7) + mtg = MTG(self.test_graph_1[0:2], mcs_mol=True) + self.assertEqual(mtg._graph.number_of_nodes(), 6) + self.assertEqual(mtg._graph.number_of_edges(), 7) def test_MTG_2(self): grp = GroupComp(self.test_graph_2[0], self.test_graph_2[1]) candidates = grp.get_mapping() - mtg = MTG(self.test_graph_2[0], self.test_graph_2[1], candidates[0]) - self.assertEqual(len(mtg.get_nodes()), 5) - self.assertEqual(len(mtg.get_edges()), 4) + # print(candidates) + mtg = MTG(self.test_graph_2[0:], candidates) + self.assertEqual(mtg._graph.number_of_nodes(), 5) + self.assertEqual(mtg._graph.number_of_edges(), 4) if __name__ == "__main__": diff --git a/Test/Graph/Wildcard/__init__.py b/Test/Graph/Wildcard/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/Graph/Wildcard/test_radwc.py b/Test/Graph/Wildcard/test_radwc.py new file mode 100644 index 0000000..9df29bb --- /dev/null +++ b/Test/Graph/Wildcard/test_radwc.py @@ -0,0 +1,55 @@ +import unittest +from synkit.Graph.Wildcard.radwc import RadWC + + +class TestRadWC(unittest.TestCase): + def 
test_no_product_radicals(self): + """If product has no radicals, output should be unchanged.""" + rxn = "[CH3:1][OH:2]>>[CH3:1][OH:2]" + self.assertEqual(RadWC.transform(rxn), rxn) + + def test_single_radical_in_product(self): + """A single radical in product gets a wildcard.""" + rxn = "[CH3:1][OH:2]>>[CH2:1].[OH:2]" + out = RadWC.transform(rxn) + # Check [*:3] is attached to [CH2:1] + self.assertIn("[CH2:1]([*:3])", out) # Atom-maps: 1,2 exist, so 3 is next + + def test_multiple_radicals_in_product(self): + """Multiple radicals in product get multiple wildcards.""" + rxn = "[CH3:1][OH:2]>>[CH2:1].[O:2]" + out = RadWC.transform(rxn) + # [CH2:1] has *:3 and *:4, [O:2] has *:5 + self.assertIn("[CH2:1]([*:3])", out) + self.assertIn("[O:2]([*:5])", out) + + def test_radical_and_nonradical_mixture(self): + """Mixed radical/non-radical product fragments, only radicals get wildcard.""" + rxn = "[CH3:1][OH:2]>>[CH2:1].[OH:2]" + out = RadWC.transform(rxn) + # [CH2:1] gets *:3, [OH:2] is unchanged + self.assertIn("[CH2:1]([*:3])", out) + self.assertIn("[OH:2]", out) + + def test_user_start_map(self): + """User-supplied map index is used for wildcards.""" + rxn = "[CH3:7][OH:8]>>[CH2:7].[OH:8]" + out = RadWC.transform(rxn, start_map=50) + self.assertIn("[CH2:7]([*:50])", out) + + def test_empty_reaction(self): + """Empty input should raise ValueError.""" + with self.assertRaises(ValueError): + RadWC.transform("") + + def test_three_component(self): + """Agent block is preserved.""" + rxn = "[CH3:1][OH:2]>[Na+]>[CH2:1].[OH:2]" + out = RadWC.transform(rxn) + self.assertTrue(out.startswith("[CH3:1][OH:2]>[Na+]>")) + self.assertIn("[CH2:1]([*:3])", out) + self.assertIn("[OH:2]", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/Graph/Wildcard/test_wildcard.py b/Test/Graph/Wildcard/test_wildcard.py new file mode 100644 index 0000000..d403f8d --- /dev/null +++ b/Test/Graph/Wildcard/test_wildcard.py @@ -0,0 +1,81 @@ +import unittest +from synkit.IO import 
rsmi_to_graph +from synkit.Graph.Wildcard.wildcard import WildCard + + +class TestWildCard(unittest.TestCase): + def setUp(self): + # The main, complex test case with atom mapping + self.rsmi_main = ( + "[cH:1]1[cH:14][c:10]2[c:23]([cH:11][n:25]1)[cH:17][cH:12][cH:4][c:31]2[NH2:28]." + "[cH:2]1[c:20]([C:22]([OH:7])=[O:21])[s:18][c:24]([S:6][c:29]2[c:15]([Cl:26])[cH:8]" + "[n:19][cH:9][c:16]2[Cl:27])[c:30]1[N+:5]([O-:3])=[O:13]>>" + "[cH:1]1[cH:14][c:10]2[c:23]([cH:11][n:25]1)[cH:17][cH:12][cH:4][c:31]2[NH:28]" + "[C:22]([c:20]1[cH:2][c:30]([N+:5]([O-:3])=[O:13])[c:24]([S:6][c:29]2[c:15]([Cl:26])" + "[cH:8][n:19][cH:9][c:16]2[Cl:27])[s:18]1)=[O:21]" + ) + # No atoms lost: R == P, should not add wildcards + self.rsmi_no_loss = "CCO>>CCO" + # All atoms lost: RSMI that loses everything (nonsense, but good test) + self.rsmi_all_lost = "CCO>>" + # Empty + self.rsmi_empty = "" + # Wildcard already present + self.rsmi_existing_wildcard = "[CH3:1][CH2:2][OH:3]>>[CH2:1][CH2:2].[*:4][OH:3]" + # No atom map (should raise error) + self.rsmi_no_atom_map = "C(C)Cl>>CC" + + def test_main_case_wildcard_added(self): + """Complex case: output product contains wildcard and roundtrip is valid.""" + out_rsmi = WildCard.rsmi_with_wildcards(self.rsmi_main) + _, product = out_rsmi.split(">>") + self.assertIsInstance(out_rsmi, str) + self.assertIn( + "*", product, "Wildcard '*' should be present in the product side." 
+ ) + # Roundtrip: should parse back without error + r, p = rsmi_to_graph(out_rsmi) + self.assertTrue(r.number_of_nodes() > 0) + self.assertTrue(p.number_of_nodes() > 0) + + def test_no_atoms_lost(self): + """No atoms lost: should raise ValueError if input is not atom-mapped.""" + with self.assertRaises(ValueError): + WildCard.rsmi_with_wildcards(self.rsmi_no_loss) + + def test_all_atoms_lost(self): + """All atoms lost: should raise ValueError if input is not atom-mapped.""" + with self.assertRaises(ValueError): + WildCard.rsmi_with_wildcards(self.rsmi_all_lost) + + def test_empty_input(self): + """Empty input: should raise ValueError.""" + with self.assertRaises(ValueError): + WildCard.rsmi_with_wildcards(self.rsmi_empty) + + def test_wildcard_not_duplicated(self): + """Existing wildcards: should not create duplicate wildcards for same lost bond.""" + out_rsmi = WildCard.rsmi_with_wildcards(self.rsmi_existing_wildcard) + _, product = out_rsmi.split(">>") + # At least one '*' in the product SMILES string + self.assertIn("*", product) + + def test_no_false_positive_wildcards(self): + """Wildcards are only added if there are truly lost subgraphs; non-atom-mapped input raises.""" + rsmi = "C>>C" + with self.assertRaises(ValueError): + WildCard.rsmi_with_wildcards(rsmi) + + def test_output_is_str_and_split(self): + """Should raise ValueError if input is not atom-mapped.""" + with self.assertRaises(ValueError): + WildCard.rsmi_with_wildcards(self.rsmi_no_loss) + + def test_missing_atom_map_raises(self): + """Should raise ValueError if atom_map attributes are missing.""" + with self.assertRaises(ValueError): + WildCard.rsmi_with_wildcards(self.rsmi_no_atom_map) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/IO/combinatorial/__init__.py b/Test/IO/combinatorial/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/IO/combinatorial/test_smarts_expander.py b/Test/IO/combinatorial/test_smarts_expander.py new file mode 100644 index 
0000000..61d6990 --- /dev/null +++ b/Test/IO/combinatorial/test_smarts_expander.py @@ -0,0 +1,38 @@ +import unittest +from synkit.IO.combinatorial.smarts_expander import SMARTSExpander + + +class TestSMARTSExpander(unittest.TestCase): + + def test_no_placeholders(self): + s = "CCO" + self.assertEqual(list(SMARTSExpander.expand_iter(s)), ["CCO"]) + self.assertEqual(SMARTSExpander.expand(s), ["CCO"]) + + def test_simple_expansion(self): + s = "[C,N:1][O,P:2]" + result = SMARTSExpander.expand(s) + self.assertEqual( + set(result), {"[C:1][O:2]", "[C:1][P:2]", "[N:1][O:2]", "[N:1][P:2]"} + ) + + def test_disjoint_constraint(self): + s = "[C,N:1][O:1]" + with self.assertRaises(ValueError): + list(SMARTSExpander.expand_iter(s)) + + def test_realistic_reaction(self): + rxn = ( + "[H+:6].[C:7](-[O:8](-[H:12]))(-[C,N,O,P,S:9])(-[C,N,O,P,S:10])(-[H:11])." + "[C:2](-[S:4](-[C,N,O,P,S:5]))(-[C,N,O,P,S:1])(=[O:3])>>" + "[S:4](-[H:6])(-[C,N,O,P,S:5]).[H+:12]." + "[C:7](-[O:8](-[C:2](-[C,N,O,P,S:1])(=[O:3])))(-[C,N,O,P,S:9])(-[C,N,O,P,S:10])(-[H:11])" + ) + ex_list = list(SMARTSExpander.expand_iter(rxn)) + self.assertEqual(len(ex_list), 625) + # Optional: Just check format or count, not endswith + self.assertTrue(ex_list[0].startswith("[H+:6]")) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/IO/combinatorial/test_smarts_generalizer.py b/Test/IO/combinatorial/test_smarts_generalizer.py new file mode 100644 index 0000000..0f2e907 --- /dev/null +++ b/Test/IO/combinatorial/test_smarts_generalizer.py @@ -0,0 +1,62 @@ +import unittest +from rdkit import Chem +from synkit.IO.combinatorial.smarts_generalizer import SMARTSGeneralizer + + +class TestSMARTSGeneralizer(unittest.TestCase): + + def setUp(self): + self.gen = SMARTSGeneralizer(sanity_check=True) + + def test_basic_generalization(self): + inputs = [ + "[C:1]-[N:2]>>[N:1]-[C:2]", + "[N:1]-[N:2]>>[N:1]-[N:2]", + "[O:1]-[N:2]>>[N:1]-[N:2]", + ] + output = self.gen.generalize(inputs) + # Instead of strict string 
match, check correct mapped elements + self.assertIn("[C,N,O:1]", output) + self.assertIn("[N:2]", output) + self.assertIn(">>", output) + + def test_single_smarts(self): + inputs = ["[C:1]-[N:2]>>[N:1]-[C:2]"] + output = self.gen.generalize(inputs) + # Should match input exactly + self.assertEqual(output, "[C:1]-[N:2]>>[N:1]-[C:2]") + + def test_different_topology_raises(self): + inputs = ["[C:1]-[N:2]>>[N:1]-[C:2]", "[N:1]-[N:2]-[C:3]>>[N:1]-[N:2]-[C:3]"] + with self.assertRaises(ValueError): + self.gen.generalize(inputs) + + def test_empty_input_raises(self): + with self.assertRaises(ValueError): + self.gen.generalize([]) + + def test_molecule_smarts(self): + gen = SMARTSGeneralizer(sanity_check=True) + inputs = ["[C:1]-[N:2]", "[N:1]-[N:2]", "[O:1]-[N:2]"] + out = gen.generalize(inputs) + self.assertEqual(out, "[C,N,O:1]-[N:2]") + + mol = Chem.MolFromSmarts(out) + self.assertIsNotNone(mol) + + def test_invalid_sanity_check(self): + gen = SMARTSGeneralizer(sanity_check=True) + # Using an obviously broken SMARTS (bad bracket placement) + _ = [ + "[C:1]-[N:2]>>[N:1]-[X:2]" + ] # 'X' is a valid SMARTS wildcard! 
Use a real error + with self.assertRaises(ValueError): + gen.generalize(["[C:1][C:2]>>[N:1][C:2]["]) # broken SMARTS + + def test_repr(self): + gen = SMARTSGeneralizer() + self.assertIn("sanity_check", repr(gen)) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/IO/combinatorial/test_smarts_to_graph.py b/Test/IO/combinatorial/test_smarts_to_graph.py new file mode 100644 index 0000000..999066a --- /dev/null +++ b/Test/IO/combinatorial/test_smarts_to_graph.py @@ -0,0 +1,85 @@ +import unittest +import networkx as nx + +from synkit.IO.combinatorial.smarts_to_graph import SMARTSToGraph + + +class TestSMARTSToGraph(unittest.TestCase): + + def setUp(self): + self.stg = SMARTSToGraph() + + def test_smarts_to_graph_simple(self): + g = self.stg.smarts_to_graph("[C:1]-[O:2]") + self.assertIsInstance(g, nx.Graph) + self.assertEqual(set(g.nodes), {1, 2}) + self.assertEqual(g.nodes[1]["element"], "C") + self.assertEqual(g.nodes[2]["element"], "O") + self.assertIsNone(g.nodes[1]["constraint"]) + + def test_smarts_to_graph_constraint(self): + g = self.stg.smarts_to_graph("[C,N,O:1]-[N:2]") + # Node 1 should be placeholder + self.assertEqual(g.nodes[1]["element"], "*") + self.assertIsInstance(g.nodes[1]["constraint"], list) + self.assertIn("C", g.nodes[1]["constraint"]) + self.assertEqual(g.nodes[2]["element"], "N") + self.assertIsNone(g.nodes[2]["constraint"]) + + def test_smarts_to_graph_hcount(self): + g = self.stg.smarts_to_graph("[CH3:1]-[O:2]") + # For SMARTS as written, RDKit returns 0 hydrogens for both + self.assertEqual(g.nodes[1]["hcount"], 0) + self.assertEqual(g.nodes[2]["hcount"], 0) + + def test_invalid_smarts(self): + with self.assertRaises(ValueError): + self.stg.smarts_to_graph("[C:1]-[N") + + def test_missing_atom_map(self): + with self.assertRaises(ValueError): + self.stg.smarts_to_graph("[C]-[O:2]") + + def test_rxn_smarts_to_graphs(self): + rxn = ( + "[H+:6].[C:7](-[O:8](-[H:12]))(-[C,N,O,P,S:9])(-[C,N,O,P,S:10])(-[H:11])." 
+ "[C:2](-[S:4](-[C,N,O,P,S:5]))(-[C,N,O,P,S:1])(=[O:3])>>" + "[S:4](-[H:6])(-[C,N,O,P,S:5]).[H+:12]." + "[C:7](-[O:8](-[C:2](-[C,N,O,P,S:1])(=[O:3])))(-[C,N,O,P,S:9])(-[C,N,O,P,S:10])(-[H:11])" + ) + g_react, _ = self.stg.rxn_smarts_to_graphs(rxn) + + # These are the atom_map indices that should have constraint (from SMARTS [C,N,O,P,S:idx]) + expected_constraint_nodes = {1, 5, 9, 10} + for idx in expected_constraint_nodes: + self.assertIn(idx, g_react.nodes) + self.assertIsNotNone( + g_react.nodes[idx]["constraint"], + f"Node {idx} should have a constraint list but does not", + ) + self.assertEqual( + set(g_react.nodes[idx]["constraint"]), + {"C", "N", "O", "P", "S"}, + f"Node {idx} has incorrect constraint list", + ) + # All other nodes should NOT have a constraint + for idx in set(g_react.nodes) - expected_constraint_nodes: + self.assertIsNone( + g_react.nodes[idx]["constraint"], + f"Node {idx} should NOT have a constraint list", + ) + + def test_rxn_separator(self): + with self.assertRaises(ValueError): + self.stg.rxn_smarts_to_graphs("[C:1]-[O:2]") # no '>>' + + def test_repr_and_describe(self): + r = repr(self.stg) + self.assertIn("placeholders", r) + desc = self.stg.describe() + self.assertIn("smarts_to_graph", desc) + self.assertIn("rxn_smarts_to_graphs", desc) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/Synthesis/Reactor/test_core_engine.py b/Test/Synthesis/Reactor/test_core_engine.py deleted file mode 100644 index d30cb9d..0000000 --- a/Test/Synthesis/Reactor/test_core_engine.py +++ /dev/null @@ -1,112 +0,0 @@ -# import os -# import unittest -# import tempfile -# from synkit.Synthesis.Reactor.core_engine import CoreEngine - - -# class TestCoreEngine(unittest.TestCase): -# def setUp(self): -# # Create a temporary directory -# self.temp_dir = tempfile.TemporaryDirectory() - -# # Path for the rule file -# self.rule_file_path = os.path.join(self.temp_dir.name, "test_rule.gml") - -# # Define rule content -# self.rule_content = """ -# 
rule [ -# ruleID "1" -# left [ -# edge [ source 1 target 2 label "=" ] -# edge [ source 3 target 4 label "-" ] -# ] -# context [ -# node [ id 1 label "C" ] -# node [ id 2 label "C" ] -# node [ id 3 label "H" ] -# node [ id 4 label "H" ] -# ] -# right [ -# edge [ source 1 target 2 label "-" ] -# edge [ source 1 target 3 label "-" ] -# edge [ source 2 target 4 label "-" ] -# ] -# ] -# """ - -# # Write rule content to the temporary file -# with open(self.rule_file_path, "w") as rule_file: -# rule_file.write(self.rule_content) - -# # Initialize SMILES strings for testing -# self.initial_smiles_fw = ["CC=CC", "[HH]"] -# self.initial_smiles_bw = ["CCCC"] - -# def tearDown(self): -# # Clean up temporary directory -# self.temp_dir.cleanup() - -# def test_perform_reaction_forward(self): -# # Test the perform_reaction method with forward reaction type -# result = CoreEngine._inference( -# rule_file_path=self.rule_file_path, -# initial_smiles=self.initial_smiles_fw, -# prediction_type="forward", -# print_results=False, -# verbosity=0, -# ) -# print(result) -# # Check if result is a list of strings and has content -# self.assertIsInstance( -# result, list, "Expected a list of reaction SMILES strings." -# ) -# self.assertTrue( -# len(result) > 0, "Result should contain reaction SMILES strings." 
-# ) - -# self.assertEqual(result[0], "CC=CC.[HH]>>CCCC") - -# # Check if the result SMILES format matches expected output format -# for reaction_smiles in result: -# self.assertIn(">>", reaction_smiles, "Reaction SMILES format is incorrect.") -# parts = reaction_smiles.split(">>") -# self.assertEqual( -# parts[0], -# ".".join(self.initial_smiles_fw), -# "Base SMILES are not correctly formatted.", -# ) -# self.assertTrue(len(parts[1]) > 0, "Product SMILES should be non-empty.") - -# def test_perform_reaction_backward(self): -# # Test the perform_reaction method with backward reaction type -# result = CoreEngine._inference( -# rule_file_path=self.rule_file_path, -# initial_smiles=self.initial_smiles_bw, -# prediction_type="backward", -# print_results=False, -# verbosity=0, -# ) -# # Check if result is a list of strings and has content -# self.assertIsInstance( -# result, list, "Expected a list of reaction SMILES strings." -# ) -# self.assertTrue( -# len(result) > 0, "Result should contain reaction SMILES strings." 
-# ) -# self.assertEqual(result[0], "C=CCC.[H][H]>>CCCC") -# self.assertEqual(result[1], "[H][H].C(C)=CC>>CCCC") - -# # Check if the result SMILES format matches expected output format -# for reaction_smiles in result: -# self.assertIn(">>", reaction_smiles, "Reaction SMILES format is incorrect.") -# parts = reaction_smiles.split(">>") -# self.assertTrue(len(parts[0]) > 0, "Product SMILES should be non-empty.") -# self.assertEqual( -# parts[1], -# ".".join(self.initial_smiles_bw), -# "Base SMILES are not correctly formatted.", -# ) - - -# if __name__ == "__main__": -# unittest.main() diff --git a/Test/Synthesis/Reactor/test_imba_engine.py b/Test/Synthesis/Reactor/test_imba_engine.py new file mode 100644 index 0000000..65dd508 --- /dev/null +++ b/Test/Synthesis/Reactor/test_imba_engine.py @@ -0,0 +1,94 @@ +import unittest +from synkit.IO import rsmi_to_its +from synkit.Graph.Wildcard.wildcard import WildCard +from synkit.Chem.Reaction.standardize import Standardize +from synkit.Synthesis.Reactor.imba_engine import ImbaEngine + + +class TestImbaEngine(unittest.TestCase): + def setUp(self): + # A complex standardized RSMI from your example + self.smart = ( + "[cH:1]1[cH:14][c:10]2[c:23]([cH:11][n:25]1)[cH:17][cH:12][cH:4][c:31]2[NH2:28]." 
+ "[cH:2]1[c:20]([C:22]([OH:7])=[O:21])[s:18][c:24]([S:6][c:29]2[c:15]" + "([Cl:26])[cH:8][n:19][cH:9][c:16]2[Cl:27])[c:30]1[N+:5]([O-:3])=[O:13]>>" + "[cH:1]1[cH:14][c:10]2[c:23]([cH:11][n:25]1)[cH:17][cH:12][cH:4]" + "[c:31]2[NH:28][C:22]([c:20]1[cH:2][c:30]([N+:5]([O-:3])=[O:13])[c:24]([S:6]" + "[c:29]2[c:15]([Cl:26])[cH:8][n:19][cH:9][c:16]2[Cl:27])[s:18]1)=[O:21]" + ) + # Standardize removes AAM + self.rsmi = Standardize().fit(self.smart, remove_aam=True) + + def test_pipeline_forward(self): + """Test forward ImbaEngine pipeline end-to-end.""" + # Apply wildcard insertion + wild_smart = WildCard().rsmi_with_wildcards(self.smart) + # Build ITS graphs + temp = rsmi_to_its(wild_smart, core=True) + # substrate split from standardized RSMI + substrate_r, _ = self.rsmi.split(">>") + # Run engine forward with template without cleaning fragments + engine = ImbaEngine(substrate_r, temp, add_wildcard=True, clean_fragments=False) + out = engine.smarts_list + self.assertEqual(len(out), 1) + out_rsmi = Standardize().fit(out[0], remove_aam=True) + + self.assertIn("*", out_rsmi) + self.assertNotEqual(out_rsmi, self.rsmi) + + # Run engine forward with template with cleaning fragments + engine = ImbaEngine(substrate_r, temp, add_wildcard=True, clean_fragments=True) + out = engine.smarts_list + outs = [Standardize().fit(o, remove_aam=True) for o in out] + self.assertIn(self.rsmi, outs) + + def test_pipeline_backward(self): + """Test backward ImbaEngine pipeline end-to-end with and without fragment cleaning.""" + # Prepare wildcard and ITS template + wild_rsmi = WildCard().rsmi_with_wildcards(self.smart) + its = rsmi_to_its(wild_rsmi, core=True) + + _, substrate_p = self.rsmi.split(">>") + + # 1. 
Without fragment cleaning + engine = ImbaEngine( + substrate_p, + its, + add_wildcard=True, + clean_fragments=False, + invert=True, + partial=True, + ) + out = engine.smarts_list + self.assertEqual(len(out), 2) + out_rsmi = Standardize().fit(out[0], remove_aam=True) + self.assertIn("*", out_rsmi) + self.assertNotEqual(out_rsmi, self.rsmi) + + # 2. With fragment cleaning + engine_clean = ImbaEngine( + substrate_p, + its, + add_wildcard=True, + clean_fragments=True, + invert=True, + partial=True, + ) + + out_clean = engine_clean.smarts_list + self.assertEqual(len(out_clean), 2) + outs = [Standardize().fit(o, remove_aam=True) for o in out_clean] + self.assertIn(self.rsmi, outs) + + def test_invalid_rsmi(self): + """Invalid RSMI pipeline should raise an exception at Standardize or ITS step.""" + # Standardize should fail for invalid RSMI + with self.assertRaises(Exception): + # Attempt full pipeline + rsmi = Standardize().fit("not_a_rsmi", remove_aam=True) + wild = WildCard().rsmi_with_wildcards(rsmi) + _ = rsmi_to_its(wild, core=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/doc/figures/mtg_mechanism.png b/doc/figures/mtg_mechanism.png new file mode 100644 index 0000000..578932f Binary files /dev/null and b/doc/figures/mtg_mechanism.png differ diff --git a/doc/graph.rst b/doc/graph.rst index 1c6cca6..043884e 100644 --- a/doc/graph.rst +++ b/doc/graph.rst @@ -200,42 +200,77 @@ This example builds two reaction-center ITS graphs, computes their MCS mapping, :caption: Building and visualizing an MTG with composite ITS :linenos: - from synkit.IO.chem_converter import rsmi_to_its - from synkit.Graph.MTG.mcs_matcher import MCSMatcher from synkit.Graph.MTG.mtg import MTG - from synkit.Graph import clean_graph_keep_largest_component - from synkit.Vis import GraphVisualizer + from synkit.Graph.ITS.its_decompose import get_rc + from synkit.examples import list_examples, load_example import matplotlib.pyplot as plt + from synkit.Vis.graph_visualizer import 
GraphVisualizer - # 1) Define two related reaction SMILES and build their reaction-center ITS graphs - rsmi_list = [ - '[CH:4]([H:7])([H:8])[CH:5]=[O:6]>>[CH:4]([H:8])=[CH:5][O:6]([H:7])', # tautomerization - '[CH3:1][C:2]=[O:3].[CH:4]([H:8])=[CH:5][O:6]([H:7])>>' - '[CH3:1][C:2]([O:3][H:7])[CH:4]([H:8])[CH:5]=[O:6]' # nucleophilic addition - ] - rc_graphs = [rsmi_to_its(r, core=True) for r in rsmi_list] - # 2) Find MCS mapping between the two ITS graphs - mcs = MCSMatcher(node_label_names=['element', 'charge'], edge_attribute='order') - mcs.find_rc_mapping(rc_graphs[0], rc_graphs[1], mcs=True) - mapping = mcs.get_mappings()[0] + data = load_example("aldol") - # 3) Build the Mechanistic Transition Graph (MTG) - mtg = MTG(rc_graphs[0], rc_graphs[1], mapping) - mtg_graph = mtg.get_graph() + mech_neutral = data[0]['mechanisms'][1]['steps'] + smart_neutral = [i['smart_string'] for i in mech_neutral] - # 4) Also build the composite ITS by directly gluing the two RC graphs - its_composite = clean_graph_keep_largest_component(mtg_graph) + mech_acid = data[0]['mechanisms'][2]['steps'] + smart_acid = [i['smart_string'] for i in mech_acid] + # neutral + mtg = MTG(smart_neutral, mcs_mol=True) + mtg_its_neutral = mtg.get_compose_its() + mtg_rc_neutral = get_rc(mtg_its_neutral, keep_mtg=True) + rc_neutral = get_rc(mtg_its_neutral, keep_mtg=False) - # 5) Visualize all four graphs: two RCs, the composite ITS, and the MTG - fig, axes = plt.subplots(2, 2, figsize=(14, 6)) + # acid + mtg = MTG(smart_acid, mcs_mol=True) + mtg_its_acid = mtg.get_compose_its() + mtg_rc_acid = get_rc(mtg_its_acid, keep_mtg=True) + rc_acid = get_rc(mtg_its_acid, keep_mtg=False) + + # Visualize + fig, ax = plt.subplots(2, 2, figsize=(16, 8)) vis = GraphVisualizer() - vis.plot_its(rc_graphs[0], axes[0, 0], use_edge_color=True, title='A. Tautomerization RC') - vis.plot_its(rc_graphs[1], axes[0, 1], use_edge_color=True, title='B. 
Nucleophilic Addition RC') - vis.plot_its(its_composite, axes[1, 0], use_edge_color=True, title='C. Composite ITS') - vis.plot_its(mtg_graph, axes[1, 1], use_edge_color=True, title='D. Mechanistic TG', og=True) + vis.plot_its( + mtg_rc_neutral, + ax=ax[0, 0], + use_edge_color=True, + og=True, + title='A. MTG for aldol addition (neutral)', + title_font_size=20, + title_font_weight='medium', + title_font_style='normal' + ) + vis.plot_its( + rc_neutral, + ax=ax[0, 1], + use_edge_color=True, + og=True, + title='B. Reaction center (neutral)', + title_font_size=20, + title_font_weight='medium', + title_font_style='normal' + ) + vis.plot_its( + mtg_rc_acid, + ax=ax[1, 0], + use_edge_color=True, + og=True, + title='C. MTG for aldol addition (acid)', + title_font_size=20, + title_font_weight='medium', + title_font_style='normal' + ) + vis.plot_its( + rc_acid, + ax=ax[1, 1], + use_edge_color=True, + og=True, + title='D. Reaction center (acid)', + title_font_size=20, + title_font_weight='medium', + title_font_style='normal' + ) plt.tight_layout() plt.show() @@ -243,16 +278,13 @@ This example builds two reaction-center ITS graphs, computes their MCS mapping, .. container:: figure - .. image:: ./figures/mtg.png + .. image:: ./figures/mtg_mechanism.png :alt: Composite ITS and MTG visualization :align: center :width: 1000px *Figure:* - (A) Reaction‐center graph for the tautomerization step - (B) Reaction‐center graph for the nucleophilic addition step - (C) Composite ITS graph "gluing" both transformations - (D) Mechanistic Transition Graph (MTG) showing step-wise mechanism + Composition of the mechanistic sequences for aldol addition under neutral and acidic conditions, showing the composite MTG (left column) and the reaction center (right column). 
Context graph ------------- diff --git a/lint.sh b/lint.sh index 085068d..ba8fdd7 100755 --- a/lint.sh +++ b/lint.sh @@ -18,7 +18,9 @@ benchmark_reactor.py:W292,C901,\ syn_reactor.py:C901,\ sing.py:C901,\ turbo_iso.py:C901,\ -rule_vis.py:C901" \ +rule_vis.py:C901, +gml_to_graph.py:C901, +wildcard.py:C901" \ --exclude=venv,\ core_engine.py,\ rule_apply.py,\ diff --git a/pyproject.toml b/pyproject.toml index 657e6cf..9f3382e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synkit" -version = "0.0.13" +version = "0.0.14" license = { text = "MIT" } license-files = ["LICENSE"] authors = [ diff --git a/recipe/meta.yaml b/recipe/meta.yaml index 4705fba..d4c6256 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -1,6 +1,6 @@ package: name: synkit - version: 0.0.13 + version: 0.0.14 source: path: .. diff --git a/synkit/Chem/Reaction/radical_wildcard.py b/synkit/Chem/Reaction/radical_wildcard.py index 166d981..7fe1f1e 100644 --- a/synkit/Chem/Reaction/radical_wildcard.py +++ b/synkit/Chem/Reaction/radical_wildcard.py @@ -4,6 +4,63 @@ from typing import Tuple, List, Optional, Dict +def clean_wc( + rsmi: str, invert: bool = False, max_frag: bool = False, wild_card: bool = True +) -> str: + """ + Clean wildcard-containing fragments from one side of a reaction SMILES, + optionally selecting the largest remaining fragment. + + :param rsmi: Reaction SMILES string in the form 'R>>P'. + :type rsmi: str + :param invert: If True, process the reactant side; otherwise the product side. + :type invert: bool + :param max_frag: If True, force fragment selection (implies wild_card=True). + :type max_frag: bool + :param wild_card: If True, remove fragments containing '*' before selection. + :type wild_card: bool + :returns: The processed reaction SMILES. + :rtype: str + :raises ValueError: If input does not split into reactant and product. 
+ + Example + ------- + >>> clean_wc('A.B>>C.*', invert=False, wild_card=True) + 'A.B>>C' + >>> clean_wc('A.B>>C.D', invert=False, max_frag=True) + 'A.B>>C' + """ + # Ensure max_frag implies wild_card + if max_frag: + wild_card = True + + # Split into reactant and product + parts = rsmi.split(">>") + if len(parts) != 2: + raise ValueError("Reaction SMILES must contain exactly one '>>'.") + react, prod = parts + + # Select side to process + side = react if invert else prod + + processed = side + if wild_card: + frags = side.split(".") + # Filter out fragments containing wildcards + filtered = [frag for frag in frags if "*" not in frag] + if len(filtered) > 1: + # select the longest fragment + processed = max(filtered, key=len) + elif len(filtered) == 1: + processed = filtered[0] + # if no filtered fragments or single fragment, keep original side + + # Reconstruct and return + if invert: + return f"{processed}>>{prod}" + return f"{react}>>{processed}" + + class RadicalWildcardAdder: """A utility for adding wildcard dummy atoms ([*]) to radical centers in reaction SMILES, with unique incremental atom-map indices and correct diff --git a/synkit/Graph/Feature/Fingerprint/__init__.py b/synkit/Graph/Feature/Fingerprint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synkit/Graph/Feature/Fingerprint/wl_rxn_fps.py b/synkit/Graph/Feature/Fingerprint/wl_rxn_fps.py new file mode 100644 index 0000000..a1cb4e2 --- /dev/null +++ b/synkit/Graph/Feature/Fingerprint/wl_rxn_fps.py @@ -0,0 +1,231 @@ +from __future__ import annotations +import networkx as nx +from collections import Counter +from hashlib import blake2b +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Any, Union +import numpy as np +from synkit.IO import rsmi_to_graph + + +@dataclass +class WLRnxFps: + """ + Weisfeiler-Lehman Reaction Fingerprint Sketch (NetworkX). 
+ + :param radius: number of WL refinement iterations + :type radius: int + :param size: bit budget for the parity sketch + :type size: int + :param to_array: whether to return fingerprint as NumPy array + :type to_array: bool + + Usage example: + >>> rsmi = ('COc1cc(NC(N)=S)ccc1-n1cnc(C)c1.O=C1C(Br)CCCC1c1ccc(Cl)cc1Cl' + >>> +'>>Br.COc1cc(Nc2nc3c(s2)CCCC3c2ccc(Cl)cc2Cl)ccc1-n1cnc(C)c1.O') + >>> react, prod = rsmi_to_graph(rsmi, drop_non_aam=False, use_index_as_atom_map=False) + >>> fps = WLRnxFps(radius=2, size=1024, to_array=False).fit(react, prod) + >>> bits = fps.fingerprint + """ + + radius: int = 2 + size: int = 1024 + to_array: bool = False + + _tokens_R: Optional[Counter] = field(init=False, default=None) + _tokens_P: Optional[Counter] = field(init=False, default=None) + _delta: Optional[Counter] = field(init=False, default=None) + _support: Optional[List[int]] = field(init=False, default=None) + _fingerprint: Optional[Union[List[int], np.ndarray]] = field( + init=False, default=None + ) + + def fit(self, react: nx.Graph, prod: nx.Graph) -> WLRnxFps: + """ + Compute WL tokens for reactant and product graphs, then build parity sketch on Δ-support. 
+ + :param react: reactant graph with node attrs 'element','aromatic','hcount','charge' + :type react: nx.Graph + :param prod: product graph with same node/edge attrs + :type prod: nx.Graph + :returns: self + :rtype: WLRnxFps + :raises ValueError: if size is not positive + """ + if self.size <= 0: + raise ValueError("size must be a positive integer") + + def wl_tokens(G: nx.Graph) -> Counter: + labels: Dict[int, int] = {} + for n, attrs in G.nodes(data=True): + atom_tuple = ( + attrs.get("element"), + bool(attrs.get("aromatic", False)), + int(attrs.get("charge", 0)), + int(attrs.get("hcount", 0)), + G.degree(n), + ) + labels[n] = _h64(("init", atom_tuple)) + cnt = Counter(labels.values()) + for k in range(1, self.radius + 1): + new_labels: Dict[int, int] = {} + for n in G.nodes(): + neigh = [] + for m in G.neighbors(n): + bond_order = float(G.edges[n, m].get("order", 1.0)) + neigh.append((_h64(("bond", bond_order)), labels[m])) + neigh.sort() + new_labels[n] = _h64(("wl", k, labels[n], tuple(neigh))) + labels = new_labels + cnt.update(labels.values()) + return cnt + + TR = wl_tokens(react) + TP = wl_tokens(prod) + + Delta = Counter(TP) + for h, v in TR.items(): + Delta[h] -= v + if Delta[h] == 0: + del Delta[h] + + support = list(Delta.keys()) + self._tokens_R = TR + self._tokens_P = TP + self._delta = Delta + self._support = support + + bits = np.zeros(self.size, dtype=int) if self.to_array else [0] * self.size + for h in support: + idx = h % self.size + if self.to_array: + bits[idx] ^= 1 + else: + bits[idx] = bits[idx] ^ 1 + self._fingerprint = bits + + return self + + @classmethod + def from_rsmi( + cls, + rsmi: str, + radius: int = 2, + size: int = 1024, + to_array: bool = False, + drop_non_aam: bool = False, + use_index_as_atom_map: bool = False, + ) -> WLRnxFps: + """ + Build WLRnxFps directly from a reaction SMILES string. 
+ + :param rsmi: reaction SMILES string + :type rsmi: str + :param radius: number of WL refinement iterations + :type radius: int + :param size: bit budget for the parity sketch + :type size: int + :param to_array: return fingerprint as NumPy array if True + :type to_array: bool + :param drop_non_aam: drop atoms without atom-atom mapping + :type drop_non_aam: bool + :param use_index_as_atom_map: interpret node indices as atom map numbers + :type use_index_as_atom_map: bool + :returns: fitted WLRnxFps instance + :rtype: WLRnxFps + :raises ValueError: on invalid SMILES parsing + """ + try: + react, prod = rsmi_to_graph( + rsmi, + drop_non_aam=drop_non_aam, + use_index_as_atom_map=use_index_as_atom_map, + ) + except Exception as e: + raise ValueError(f"Failed to parse rsmi: {e}") + + return cls(radius=radius, size=size, to_array=to_array).fit(react, prod) + + @property + def tokens_R(self) -> Counter: + """ + :returns: WL token counts for reactant + :rtype: Counter + :raises AttributeError: if fit() has not been called + """ + if self._tokens_R is None: + raise AttributeError("Call fit() before accessing tokens_R") + return self._tokens_R + + @property + def tokens_P(self) -> Counter: + """ + :returns: WL token counts for product + :rtype: Counter + :raises AttributeError: if fit() has not been called + """ + if self._tokens_P is None: + raise AttributeError("Call fit() before accessing tokens_P") + return self._tokens_P + + @property + def delta(self) -> Counter: + """ + :returns: Signed token difference (product - reactant) + :rtype: Counter + :raises AttributeError: if fit() has not been called + """ + if self._delta is None: + raise AttributeError("Call fit() before accessing delta") + return self._delta + + @property + def support(self) -> List[int]: + """ + :returns: Tokens with non-zero delta + :rtype: List[int] + :raises AttributeError: if fit() has not been called + """ + if self._support is None: + raise AttributeError("Call fit() before accessing support") 
+ return self._support + + @property + def fingerprint(self) -> Union[List[int], np.ndarray]: + """ + :returns: Parity sketch bit vector (0/1) + :rtype: Union[List[int], numpy.ndarray] + :raises AttributeError: if fit() has not been called + """ + if self._fingerprint is None: + raise AttributeError("Call fit() before accessing fingerprint") + return self._fingerprint + + def __repr__(self) -> str: + support_len = len(self._support) if self._support is not None else 0 + return ( + f"" + ) + + def help(self) -> None: + """ + Print usage examples and class docstring. + + :returns: None + """ + print(self.__doc__) + + +def _h64(obj: Any) -> int: + """ + Compute a stable 64-bit hash of an object. + + :param obj: any hashable representation + :type obj: Any + :returns: 64-bit integer hash + :rtype: int + """ + h = blake2b(digest_size=8) + h.update(repr(obj).encode("utf-8")) + return int.from_bytes(h.digest(), "little") diff --git a/synkit/Graph/ITS/its_decompose.py b/synkit/Graph/ITS/its_decompose.py index df9a1fd..416fd33 100644 --- a/synkit/Graph/ITS/its_decompose.py +++ b/synkit/Graph/ITS/its_decompose.py @@ -1,6 +1,6 @@ import re import networkx as nx -from typing import Optional, List +from typing import Optional, List, Any from rdkit import Chem from rdkit.Chem.MolStandardize import rdMolStandardize @@ -8,75 +8,300 @@ __all__ = ["get_rc", "its_decompose"] +# def get_rc( +# ITS: nx.Graph, +# element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], +# bond_key: str = "order", +# standard_key: str = "standard_order", +# disconnected: bool = False, +# ) -> nx.Graph: +# """Extract the reaction-center (RC) subgraph from an ITS graph. + +# This function identifies: +# 1. All bonds whose standard order (difference between ITS orders) is non-zero. +# 2. All H–H bonds, ensuring they are included even if no order change is detected. +# 3. (Optional) Additional nodes with charge changes and reconnection of edges +# if `disconnected=True`. 
+ +# :param ITS: The integrated transition-state graph with composite node/edge attributes. +# :type ITS: nx.Graph +# :param element_key: List of node‐attribute keys to copy into the RC graph. +# :type element_key: List[str] +# :param bond_key: Edge attribute key representing the tuple of bond orders. +# :type bond_key: str +# :param standard_key: Edge attribute key for the computed standard_order. +# :type standard_key: str +# :param disconnected: If True, also include nodes with charge changes and +# reconnect any ITS edges between RC nodes. +# :type disconnected: bool +# :returns: A new graph containing only the reaction-center nodes and edges. +# :rtype: nx.Graph + +# :example: +# >>> ITS = nx.Graph() +# >>> # ... populate ITS with 'order', 'standard_order', 'typesGH', etc. ... +# >>> RC = get_rc(ITS, disconnected=True) +# >>> isinstance(RC, nx.Graph) +# True +# """ +# rc = nx.Graph() +# _add_bond_order_changes(ITS, rc, element_key, bond_key, standard_key) + +# # 1.5) H-H bonds (force inclusion, with fallback typesGH) +# for u, v, data in ITS.edges(data=True): +# elem_u = ITS.nodes[u].get("element") +# elem_v = ITS.nodes[v].get("element") +# if elem_u == "H" and elem_v == "H": +# for n in (u, v): +# node_data = dict(ITS.nodes[n]) +# if "typesGH" not in node_data: +# node_data["typesGH"] = ( +# ("H", False, 0, 0, []), +# ("*", False, 0, 0, []), +# ) +# # Ensure typesGH is available even if not in original element_key +# final_attrs = {k: node_data[k] for k in element_key if k in node_data} +# final_attrs["typesGH"] = node_data["typesGH"] +# rc.add_node(n, **final_attrs) + +# rc.add_edge( +# u, +# v, +# **{ +# bond_key: data.get(bond_key), +# standard_key: data.get(standard_key), +# }, +# ) +# if disconnected: +# _add_charge_change_nodes(ITS, rc, element_key) +# _reconnect_rc_edges(ITS, rc, bond_key, standard_key) + +# return rc + + +# def get_rc( +# ITS: nx.Graph, +# element_key: list[str] = ["element", "charge", "typesGH", "atom_map"], +# bond_key: str = 
"order", +# standard_key: str = "standard_order", +# disconnected: bool = False, +# keep_mtg: bool = False, +# ) -> nx.Graph: +# """ +# Extract the reaction-center (RC) subgraph from an ITS graph. + +# This function identifies: +# 1. All bonds whose standard order is non-zero. +# 2. (Optional) All bonds labeled with 'is_mtg=True' if keep_mtg is True. +# 3. All H-H bonds, ensuring they are included even if no order change is detected. +# 4. (Optional) Additional nodes with charge changes and reconnection of edges +# if `disconnected=True`. + +# :param ITS: The integrated transition-state graph with composite node/edge attributes. +# :type ITS: nx.Graph +# :param element_key: List of node-attribute keys to copy into the RC graph. +# :type element_key: List[str] +# :param bond_key: Edge attribute key representing the tuple of bond orders. +# :type bond_key: str +# :param standard_key: Edge attribute key for the computed standard_order. +# :type standard_key: str +# :param disconnected: If True, also include nodes with charge changes and +# reconnect any ITS edges between RC nodes. +# :type disconnected: bool +# :param keep_mtg: If True, also include edges where 'is_mtg' attribute is True. +# :type keep_mtg: bool +# :returns: A new graph containing only the reaction-center nodes and edges. 
+# :rtype: nx.Graph +# """ +# rc = nx.Graph() +# # 1) Bonds with standard order change or mechanistic transition +# for u, v, data in ITS.edges(data=True): +# std = data.get(standard_key) +# is_mtg_attr = data.get("is_mtg", False) +# include = False +# if isinstance(std, (int, float)) and std != 0: +# include = True +# if keep_mtg and is_mtg_attr: +# include = True +# if not include: +# continue +# # add nodes +# for n in (u, v): +# if not rc.has_node(n): +# node_data = dict(ITS.nodes[n]) +# final_attrs = {k: node_data[k] for k in element_key if k in node_data} +# rc.add_node(n, **final_attrs) +# # add edge +# edge_attrs = { +# bond_key: data.get(bond_key), +# standard_key: std, +# "is_mtg": is_mtg_attr, +# } +# rc.add_edge(u, v, **edge_attrs) + +# # 2) H-H bonds (force inclusion, with fallback typesGH) +# for u, v, data in ITS.edges(data=True): +# elem_u = ITS.nodes[u].get("element") +# elem_v = ITS.nodes[v].get("element") +# if elem_u == "H" and elem_v == "H": +# for n in (u, v): +# if not rc.has_node(n): +# node_data = dict(ITS.nodes[n]) +# if "typesGH" not in node_data: +# node_data["typesGH"] = ( +# ("H", False, 0, 0, []), +# ("*", False, 0, 0, []), +# ) +# final_attrs = { +# k: node_data[k] for k in element_key if k in node_data +# } +# final_attrs["typesGH"] = node_data["typesGH"] +# rc.add_node(n, **final_attrs) +# if not rc.has_edge(u, v): +# rc.add_edge( +# u, +# v, +# **{ +# bond_key: data.get(bond_key), +# standard_key: data.get(standard_key), +# "is_mtg": data.get("is_mtg", False), +# }, +# ) + +# if disconnected: +# _add_charge_change_nodes(ITS, rc, element_key) +# _reconnect_rc_edges(ITS, rc, bond_key, standard_key) + +# return rc + +# import networkx as nx +# from typing import List, Any + + def get_rc( ITS: nx.Graph, element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], bond_key: str = "order", standard_key: str = "standard_order", disconnected: bool = False, + keep_mtg: bool = False, ) -> nx.Graph: - """Extract the reaction-center 
(RC) subgraph from an ITS graph. - - This function identifies: - 1. All bonds whose standard order (difference between ITS orders) is non-zero. - 2. All H–H bonds, ensuring they are included even if no order change is detected. - 3. (Optional) Additional nodes with charge changes and reconnection of edges - if `disconnected=True`. - - :param ITS: The integrated transition-state graph with composite node/edge attributes. - :type ITS: nx.Graph - :param element_key: List of node‐attribute keys to copy into the RC graph. - :type element_key: List[str] - :param bond_key: Edge attribute key representing the tuple of bond orders. - :type bond_key: str - :param standard_key: Edge attribute key for the computed standard_order. - :type standard_key: str - :param disconnected: If True, also include nodes with charge changes and - reconnect any ITS edges between RC nodes. - :type disconnected: bool - :returns: A new graph containing only the reaction-center nodes and edges. - :rtype: nx.Graph - - :example: - >>> ITS = nx.Graph() - >>> # ... populate ITS with 'order', 'standard_order', 'typesGH', etc. ... - >>> RC = get_rc(ITS, disconnected=True) - >>> isinstance(RC, nx.Graph) - True + """ + Extract the reaction-center (RC) subgraph from an ITS graph. """ rc = nx.Graph() - _add_bond_order_changes(ITS, rc, element_key, bond_key, standard_key) + _add_changed_bonds(ITS, rc, element_key, bond_key, standard_key, keep_mtg) + _add_hh_bonds(ITS, rc, element_key, bond_key, standard_key) + if disconnected: + _add_charge_change_nodes(ITS, rc, element_key) + _reconnect_rc_edges(ITS, rc, bond_key, standard_key) + return rc - # 1.5) H-H bonds (force inclusion, with fallback typesGH) + +def _add_changed_bonds( + ITS: nx.Graph, + rc: nx.Graph, + element_key: List[str], + bond_key: str, + standard_key: str, + keep_mtg: bool, +) -> None: + """ + Add bonds with non-zero standard order or mechanistic transitions. 
+ """ + for u, v, data in ITS.edges(data=True): + std = data.get(standard_key) + is_mtg_attr = data.get("is_mtg", False) + if not _should_include_edge(std, is_mtg_attr, keep_mtg): + continue + _ensure_node(rc, ITS, u, element_key) + _ensure_node(rc, ITS, v, element_key) + rc.add_edge( + u, + v, + **{bond_key: data.get(bond_key), standard_key: std, "is_mtg": is_mtg_attr}, + ) + + +def _add_hh_bonds( + ITS: nx.Graph, + rc: nx.Graph, + element_key: List[str], + bond_key: str, + standard_key: str, +) -> None: + """ + Force inclusion of H-H bonds, with fallback for typesGH. + """ for u, v, data in ITS.edges(data=True): - elem_u = ITS.nodes[u].get("element") - elem_v = ITS.nodes[v].get("element") - if elem_u == "H" and elem_v == "H": + if _is_hh_pair(ITS, u, v): for n in (u, v): - node_data = dict(ITS.nodes[n]) - if "typesGH" not in node_data: - node_data["typesGH"] = ( - ("H", False, 0, 0, []), - ("*", False, 0, 0, []), - ) - # Ensure typesGH is available even if not in original element_key - final_attrs = {k: node_data[k] for k in element_key if k in node_data} - final_attrs["typesGH"] = node_data["typesGH"] - rc.add_node(n, **final_attrs) + _ensure_node_hh(rc, ITS, n, element_key) + if not rc.has_edge(u, v): + rc.add_edge( + u, + v, + **{ + bond_key: data.get(bond_key), + standard_key: data.get(standard_key), + "is_mtg": data.get("is_mtg", False), + }, + ) - rc.add_edge( - u, - v, - **{ - bond_key: data.get(bond_key), - standard_key: data.get(standard_key), - }, - ) - if disconnected: - _add_charge_change_nodes(ITS, rc, element_key) - _reconnect_rc_edges(ITS, rc, bond_key, standard_key) - return rc +def _should_include_edge( + std: Any, + is_mtg_attr: bool, + keep_mtg: bool, +) -> bool: + """ + Determine if an edge should be included based on standard order and mechanistic flag. 
+ """ + if isinstance(std, (int, float)) and std != 0: + return True + if keep_mtg and is_mtg_attr: + return True + return False + + +def _is_hh_pair(ITS: nx.Graph, u: Any, v: Any) -> bool: + """ + Check if both nodes of an edge are hydrogen. + """ + return ITS.nodes[u].get("element") == "H" and ITS.nodes[v].get("element") == "H" + + +def _ensure_node( + rc: nx.Graph, + ITS: nx.Graph, + node: Any, + element_key: List[str], +) -> None: + """ + Add a node to RC with selected attributes if not already present. + """ + if not rc.has_node(node): + node_data = ITS.nodes[node] + final_attrs = {k: node_data[k] for k in element_key if k in node_data} + rc.add_node(node, **final_attrs) + + +def _ensure_node_hh( + rc: nx.Graph, + ITS: nx.Graph, + node: Any, + element_key: List[str], +) -> None: + """ + Add H node to RC, ensuring typesGH fallback if missing. + """ + if not rc.has_node(node): + node_data = dict(ITS.nodes[node]) + if "typesGH" not in node_data: + node_data["typesGH"] = (("H", False, 0, 0, []), ("*", False, 0, 0, [])) + final_attrs = {k: node_data[k] for k in element_key if k in node_data} + final_attrs["typesGH"] = node_data["typesGH"] + rc.add_node(node, **final_attrs) def _carry_node_attrs(src: nx.Graph, dst: nx.Graph, n: int, keys: List[str]) -> None: @@ -139,172 +364,6 @@ def _add_bond_order_changes( ) -# def get_rc( -# ITS: nx.Graph, -# element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], -# bond_key: str = "order", -# standard_key: str = "standard_order", -# disconnected: bool = False, -# ) -> nx.Graph: -# """ -# Extract the reaction center (RC) from ITS graph. - -# Enhancements: -# - Adds nodes and edges where bond order changes (core logic). -# - If disconnected=True: -# - Adds nodes with charge change based on typesGH. -# - Reconnects any ITS edge between two RC nodes. -# - NEW: Always includes H-H bonds in RC. Adds default typesGH if missing. 
-# """ -# rc = nx.Graph() - -# # 1) edges with bond-order change -# for u, v, data in ITS.edges(data=True): -# old, new = data.get(bond_key, [None, None]) -# if old != new: -# for n in (u, v): -# if not rc.has_node(n): -# rc.add_node( -# n, -# **{ -# k: ITS.nodes[n][k] for k in element_key if k in ITS.nodes[n] -# }, -# ) -# rc.add_edge( -# u, -# v, -# **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, -# ) - -# # 1.5) H-H bonds (force inclusion, with fallback typesGH) -# for u, v, data in ITS.edges(data=True): -# elem_u = ITS.nodes[u].get("element") -# elem_v = ITS.nodes[v].get("element") -# if elem_u == "H" and elem_v == "H": -# for n in (u, v): -# node_data = dict(ITS.nodes[n]) -# if "typesGH" not in node_data: -# node_data["typesGH"] = ( -# ("H", False, 0, 0, []), -# ("*", False, 0, 0, []), -# ) -# # Ensure typesGH is available even if not in original element_key -# final_attrs = {k: node_data[k] for k in element_key if k in node_data} -# final_attrs["typesGH"] = node_data["typesGH"] -# rc.add_node(n, **final_attrs) - -# rc.add_edge( -# u, -# v, -# **{ -# bond_key: data.get(bond_key), -# standard_key: data.get(standard_key), -# }, -# ) - -# if disconnected: -# # 2) nodes with typesGH-based charge change -# for n, data in ITS.nodes(data=True): -# gh = data.get("typesGH") -# if ( -# isinstance(gh, (list, tuple)) -# and len(gh) >= 2 -# and len(gh[0]) > 3 -# and len(gh[1]) > 3 -# and gh[0][3] != gh[1][3] -# ): -# if not rc.has_node(n): -# rc.add_node(n, **{k: data[k] for k in element_key if k in data}) - -# # 3) reconnect RC nodes -# for u, v, data in ITS.edges(data=True): -# if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): -# rc.add_edge( -# u, -# v, -# **{ -# bond_key: data.get(bond_key), -# standard_key: data.get(standard_key), -# }, -# ) - -# return rc - - -# def get_rc( -# ITS: nx.Graph, -# element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], -# bond_key: str = "order", -# standard_key: str = 
"standard_order", -# disconnected: bool = False, -# ) -> nx.Graph: -# """ -# Extract the reaction center (RC) from ITS by: - -# 1. Always adding any edge whose bond order changes -# (bond_key[0] != bond_key[1]), plus its two end-nodes. -# 2. [if disconnected=True] Adding any node whose 'typesGH' record shows a charge change -# (typesGH[0][3] != typesGH[1][3]), even if isolated. -# 3. [if disconnected=True] Re-adding any ITS edge between two nodes already in RC -# (to preserve connectivity), carrying over bond_key & standard_key. - -# Parameters: -# - ITS (nx.Graph): input ITS graph. -# - element_key (List[str]): node attrs to carry over. -# - bond_key (str): edge attr key for bond order. -# - standard_key (str): edge attr key for standard order. -# - disconnected (bool): if True, include “charge-change” nodes (step 2) and -# reconnect any edges among RC nodes (step 3). If False, only performs step 1. -# """ -# rc = nx.Graph() - -# # 1) edges with bond-order change -# for u, v, data in ITS.edges(data=True): -# old, new = data.get(bond_key, [None, None]) -# if old != new: -# for n in (u, v): -# if not rc.has_node(n): -# rc.add_node( -# n, -# **{ -# k: ITS.nodes[n][k] for k in element_key if k in ITS.nodes[n] -# }, -# ) -# rc.add_edge( -# u, -# v, -# **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, -# ) - -# if disconnected: -# # 2) nodes with a typesGH-based charge change -# for n, data in ITS.nodes(data=True): -# gh = data.get("typesGH") -# if ( -# isinstance(gh, (list, tuple)) -# and len(gh) >= 2 -# and len(gh[0]) > 3 -# and len(gh[1]) > 3 -# and gh[0][3] != gh[1][3] -# ): -# if not rc.has_node(n): -# rc.add_node(n, **{k: data[k] for k in element_key if k in data}) - -# # 3) re-add any ITS edge between RC nodes to preserve connectivity -# for u, v, data in ITS.edges(data=True): -# if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): -# rc.add_edge( -# u, -# v, -# **{ -# bond_key: data.get(bond_key), -# standard_key: 
data.get(standard_key), -# }, -# ) - -# return rc - - def its_decompose(its_graph: nx.Graph, nodes_share="typesGH", edges_share="order"): """Decompose an ITS graph into two separate reactant (G) and product (H) graphs. diff --git a/synkit/Graph/MTG/mcs_matcher.py b/synkit/Graph/MTG/mcs_matcher.py index e8a6228..ea27d83 100644 --- a/synkit/Graph/MTG/mcs_matcher.py +++ b/synkit/Graph/MTG/mcs_matcher.py @@ -2,7 +2,7 @@ ================================================= A convenience wrapper around ``networkx.algorithms.isomorphism.GraphMatcher`` -that finds *all* common‑subgraph (or maximum‑common‑subgraph) node mappings +that finds *all* common-subgraph (or maximum-common-subgraph) node mappings between two molecular graphs. Highlights @@ -17,8 +17,9 @@ ``MCSMatcher(node_label_names, node_label_defaults, edge_attribute='order', allow_shift=True)`` Construct a matcher instance. -``matcher.find_common_subgraph(G1, G2, mcs=False)`` - Run the search (stores but does *not* return mappings). +``matcher.find_common_subgraph(G1, G2, mcs=False, mcs_mol=False)`` + Run the search (stores but does *not* return mappings). If ``mcs_mol=True``, + find mappings by matching entire connected components (largest molecules). ``matcher.get_mappings()`` Retrieve the stored mapping list. @@ -64,9 +65,6 @@ class MCSMatcher: Placeholder for future asymmetric rules (ignored for scalars). 
""" - # --------------------------------------------------------------------- - # Construction - # --------------------------------------------------------------------- def __init__( self, node_label_names: Optional[List[str]] | None = None, @@ -93,9 +91,6 @@ def __init__( self._mappings: List[Dict[int, int]] = [] self._last_size: int = 0 - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ def _edge_match( self, host_attrs: Dict[str, Any], pat_attrs: Dict[str, Any] ) -> bool: @@ -109,27 +104,74 @@ def _edge_match( @staticmethod def _invert_mapping(gm_mapping: Dict[int, int]) -> Dict[int, int]: - """Convert *host → pattern* dict to *pattern → host*.""" + """Convert *host→pattern* dict to *pattern→host*.""" return {pat: host for host, pat in gm_mapping.items()} - # ------------------------------------------------------------------ - # Public runners - # ------------------------------------------------------------------ + def _find_mcs_mol(self, G1: nx.Graph, G2: nx.Graph) -> Dict[int, int]: + """ + Match connected components of G1 to G2 of the same size, combining + each component's isomorphic mapping into one dict. 
+ """ + # sort components by size descending + comps1 = sorted(nx.connected_components(G1), key=len, reverse=True) + comps2 = sorted(nx.connected_components(G2), key=len, reverse=True) + + used2: Set[frozenset[int]] = set() + combined: Dict[int, int] = {} + + for comp1 in comps1: + size = len(comp1) + sub1 = G1.subgraph(comp1) + + for comp2 in comps2: + if len(comp2) != size: + continue + key2 = frozenset(comp2) + if key2 in used2: + continue + + sub2 = G2.subgraph(comp2) + gm = GraphMatcher( + sub1, + sub2, + node_match=self.node_match, + edge_match=self._edge_match, + ) + if gm.is_isomorphic(): + combined.update(gm.mapping) + used2.add(key2) + break + + return combined + def find_common_subgraph( - self, G1: nx.Graph, G2: nx.Graph, *, mcs: bool = False + self, + G1: nx.Graph, + G2: nx.Graph, + *, + mcs: bool = False, + mcs_mol: bool = False, ) -> None: """Search for subgraph isomorphisms and cache the mappings. Parameters ---------- - G1 : nx.Graph – *pattern* graph (searched as a subgraph) - G2 : nx.Graph – *host* graph + G1 : nx.Graph - *pattern* graph (searched as a subgraph) + G2 : nx.Graph - *host* graph mcs : bool, optional If *True*, keep only mappings of maximum size. + mcs_mol : bool, optional + If *True*, match entire connected components (largest molecules). 
""" self._mappings.clear() self._last_size = 0 + if mcs_mol: + combined = self._find_mcs_mol(G1, G2) + self._mappings = [combined] + self._last_size = len(combined) + return + max_k = min(len(G1), len(G2)) sizes = range(max_k, 0, -1) seen: Set[tuple] = set() @@ -167,21 +209,22 @@ def find_common_subgraph( # final ordering – largest first then lexicographic self._mappings.sort(key=lambda d: (-len(d), tuple(sorted(d.items())))) - # ------------------------------------------------------------------ - # Convenience wrapper for ITS reaction‑centres - # ------------------------------------------------------------------ - def find_rc_mapping(self, rc1, rc2, *, mcs: bool = False) -> None: # type: ignore[override] + def find_rc_mapping( + self, + rc1, + rc2, + *, + mcs: bool = False, + mcs_mol: bool = False, + ) -> None: # type: ignore[override] if its_decompose is None: raise ImportError( "synkit is not available; cannot decompose reaction centres." ) _, r1 = its_decompose(rc1) l2, _ = its_decompose(rc2) - self.find_common_subgraph(r1, l2, mcs=mcs) + self.find_common_subgraph(r1, l2, mcs=mcs, mcs_mol=mcs_mol) - # ------------------------------------------------------------------ - # Properties and dunders - # ------------------------------------------------------------------ def get_mappings(self) -> List[Dict[int, int]]: """Return the cached mapping list (empty if `find_*` not yet called).""" @@ -192,7 +235,6 @@ def last_size(self) -> int: """Number of nodes in the most recent mapping set (0 if none).""" return self._last_size - # Pretty representations def __repr__(self) -> str: # noqa: D401 return ( f"MCSMatcher(mappings={len(self._mappings)}, last_size={self._last_size})" diff --git a/synkit/Graph/MTG/mtg.py b/synkit/Graph/MTG/mtg.py index 613ce68..6cc773d 100644 --- a/synkit/Graph/MTG/mtg.py +++ b/synkit/Graph/MTG/mtg.py @@ -1,208 +1,886 @@ +from __future__ import annotations + +"""MTG – Mechanistic Transition Graph fusion utility. 
+ +This module exposes :class:`~MTG`, a helper that merges a chronological +sequence of **Intermediate Transition State** (ITS) graphs – or their RSMI +string representations – into a single *product* graph capturing the entire +bond-order history across the reaction trajectory. + +The implementation is self-contained except for the external *synkit* helpers +used for RSMI⇒ITS inter-conversion and canonicalisation. +""" + +from collections.abc import Iterator +from typing import Any, Dict, List, Mapping, MutableMapping, Set, Tuple, Union + import networkx as nx -from collections import defaultdict -from typing import Dict, List, Tuple, Any, Set, Union -# ----------------------------------------------------------------------------- -# Type aliases -# ----------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Optional dependencies +# --------------------------------------------------------------------------- +try: + import pandas as pd # type: ignore +except ImportError: # pragma: no cover – pandas is only required for to_dataframe() + pd = None # noqa: N816 + +from synkit.Graph.Hyrogen._misc import h_to_explicit +from synkit.Graph.ITS.normalize_aam import NormalizeAAM +from synkit.Graph.MTG.mcs_matcher import MCSMatcher +from synkit.Graph.MTG.utils import ( + normalize_hcount_and_typesGH, + normalize_order, + label_mtg_edges, + compute_standard_order, +) +from synkit.Graph.canon_graph import GraphCanonicaliser +from synkit.IO import its_to_rsmi, rsmi_to_its + NodeID = int -Order = Tuple[float, float] -Node = Tuple[NodeID, Dict[str, Any]] -Edge = Tuple[NodeID, NodeID, Dict[str, Any]] +OrderPair = Tuple[float, float] +MissingOrder = Tuple[Set[float], Set[float]] +GraphMapping = Dict[NodeID, NodeID] + +_PLACEHOLDER: MissingOrder = (set(), set()) +_PLACEHOLDER_TYPESGH = (set(), set(), set(), set(), set()) __all__ = ["MTG"] class MTG: - """Fuse two molecular graphs via a 
pair‑groupoid edge‑composition rule. - - Parameters - ---------- - G1, G2 - Input :class:`networkx.Graph` (or *DiGraph*) objects. Nodes must carry an - ``"element"`` attribute; edges carry an ``"order"`` 2‑tuple *(x, y)*. - mapping - A partial node map **G1 → G2** indicating which atoms are chemically - identical (intersection). Keys are node IDs in *G1*, values in *G2*. - - Notes - ----- - 1. ``intersection_ids`` are created where mapping ``G1[i] → G2[j]``. - 2. Edges are inserted in two passes: - * *Pass 1* – all edges from *G1* are copied unchanged. - * *Pass 2* – edges from *G2* are remapped; when both endpoints are in - ``intersection_ids`` **and** their bond orders satisfy the *pair‐ - groupoid* condition - - ``(a₁, a₂) + (b₁, b₂) with a₂ == b₁ → (a₁, b₂)``, - - the edges are *composed* instead of duplicated. - - Examples - -------- - >>> mtg = MTG(G1, G2, {1: 3, 4: 6, 5: 1}) - >>> fused = mtg.get_graph() - >>> fused.nodes(data=True) - ... + """Fuse a chronological series of ITS graphs into a Mechanistic Transition Graph. + + :param sequences: A list of ITS-format NetworkX graphs or RSMI strings. + :param mappings: Optional list of precomputed mappings; computed via MCS if None. + :param node_label_names: Keys for node-label matching. + :param canonicaliser: Optional GraphCanonicaliser for snapshot canonicalisation. + :raises ValueError: On invalid sequence or mapping lengths. + :raises RuntimeError: On mapping failures. 
""" - # ------------------------------------------------------------------ - # Construction helpers - # ------------------------------------------------------------------ def __init__( self, - G1: Union[nx.Graph, nx.DiGraph], - G2: Union[nx.Graph, nx.DiGraph], - mapping: Dict[NodeID, NodeID], + sequences: Union[List[nx.Graph], List[str]], + mappings: List[GraphMapping] | None = None, + *, + node_label_names: List[str] | None = None, + canonicaliser: GraphCanonicaliser | None = None, + mcs_mol: bool = False, + mcs: bool = False, ) -> None: - # Store originals - self.G1 = G1 - self.G2 = G2 - self.mapping12 = mapping # G1 → G2 - - # ---- 1. Build fused node set --------------------------------- - ( - self.product_nodes, # list[(id, attrs)] in fused graph - self.map1, # G1 id → fused id - self.map2, # G2 id → fused id - self.intersection_ids, # list[fused id] - ) = self._fuse_nodes() - - # ---- 2. Fuse edges with groupoid rule ------------------------ - fused_edges_step1 = self._insert_edges_from(self.G1.edges(data=True), self.map1) - self.product_edges = self._insert_edges_from( - self.G2.edges(data=True), self.map2, existing=fused_edges_step1 + if len(sequences) < 2: + raise ValueError("Need at least two snapshots.") + + self._node_label_names = node_label_names or ["element", "charge", "hcount"] + self._canonicaliser = canonicaliser + self.mcs_mol = mcs_mol + self.mcs = mcs + + self._graphs = self._prepare_graph_sequence(sequences) + self._k = len(self._graphs) + + self._mappings = ( + mappings if mappings is not None else self._compute_mappings(self._graphs) ) + if len(self._mappings) != self._k - 1: + raise ValueError("Mappings must match snapshot pairs.") - # ------------------------------------------------------------------ - # Node fusion - # ------------------------------------------------------------------ - def _fuse_nodes(self): - merged: Dict[NodeID, Dict[str, Any]] = {} - map1: Dict[NodeID, NodeID] = {} - map2: Dict[NodeID, NodeID] = {} - used: 
Set[NodeID] = set() - - # --- copy G1 directly into fused graph ------------------------ - for v, attrs in self.G1.nodes(data=True): - merged[v] = attrs.copy() - map1[v] = v - used.add(v) - - # inverse mapping: G2 node → G1 node it merges to - inv_map = {g2: g1 for g1, g2 in self.mapping12.items()} - intersection: List[NodeID] = [] - - # --- process G2 nodes ----------------------------------------- - next_id = max(used) + 1 if used else 0 - for v, attrs in self.G2.nodes(data=True): - if v in inv_map: # merged node - tgt = inv_map[v] - map2[v] = tgt - intersection.append(tgt) - else: # unique node from G2 - while next_id in used: - next_id += 1 - merged[next_id] = attrs.copy() - map2[v] = next_id - used.add(next_id) - next_id += 1 - - nodes_sorted = sorted(merged.items()) # list[(id, attrs)] - return nodes_sorted, map1, map2, intersection - - # ------------------------------------------------------------------ - # Edge insertion & groupoid composition - # ------------------------------------------------------------------ - def _insert_edges_from( - self, edge_iter, node_map: Dict[NodeID, NodeID], existing: List[Edge] = None - ) -> List[Edge]: - """Insert edges into *existing* applying the groupoid rule when - possible.""" - existing = [] if existing is None else existing.copy() - - # Remap and append new edges - for u, v, attrs in edge_iter: - u3 = node_map[u] - v3 = node_map[v] - existing.append((u3, v3, attrs.copy())) - - # Canonicalize keys for undirected graphs - def key(u, v): - return (u, v) if isinstance(self.G1, nx.DiGraph) else tuple(sorted((u, v))) - - # Group edges by (u,v) - buckets: Dict[Tuple[NodeID, NodeID], List[Order]] = defaultdict(list) - bucket_src: Dict[Tuple[NodeID, NodeID], List[str]] = defaultdict(list) - for idx, (u, v, attrs) in enumerate(existing): - buckets[key(u, v)].append(tuple(attrs["order"])) - bucket_src[key(u, v)].append("G1" if idx < len(self.G1.edges) else "G2") - - fused_edges: List[Edge] = [] - for (u, v), orders in 
buckets.items(): - # src = bucket_src[(u, v)] - if ( - u in self.intersection_ids - and v in self.intersection_ids - and len(orders) >= 2 - ): - # Attempt pair‑wise composition between G1 (first) and any G2 edge - made_composite = False - for idx2, ord2 in enumerate(orders[1:], start=1): - a1, a2 = orders[0] - b1, b2 = ord2 - if a2 == b1: - fused_edges.append((u, v, {"order": (a1, b2)})) - made_composite = True - break - if not made_composite: - # fall back to *all* distinct orders - for ord_ in orders: - fused_edges.append((u, v, {"order": ord_})) - else: - for ord_ in orders: - fused_edges.append((u, v, {"order": ord_})) + self._prod_nodes: Dict[int, Dict[str, Any]] + self._node_map: Dict[Tuple[int, NodeID], int] + self._graph: nx.Graph + + self._build_node_map_and_attributes() + self._build_edge_history_and_graph() + + def __repr__(self) -> str: + return f"" + + def __len__(self) -> int: + return self._graph.number_of_nodes() + + def __iter__(self) -> Iterator[int]: + return iter(self._graph.nodes) + + def __getitem__(self, node_id: int) -> Dict[str, Any]: + return self._graph.nodes[node_id] - return self._dedupe_edges(fused_edges) + @staticmethod + def describe() -> str: + return ( + "# Usage example\n" + "mtg = MTG([G0, G1, G2])\n" + "mg = mtg.get_mtg()\n" + "rsmi = mtg.get_aam()\n" + ) + + def get_mtg(self, *, directed: bool = False) -> nx.Graph: + return self._graph.to_directed() if directed else self._graph + + def get_compose_its(self, *, directed: bool = False) -> nx.Graph: + g = self.get_mtg(directed=directed) + g = label_mtg_edges(g, inplace=False) + g = normalize_order(g) + g = normalize_hcount_and_typesGH(g) + return compute_standard_order(g) + + def get_aam(self, *, directed: bool = False, explicit_h: bool = False) -> str: + g = self.get_compose_its(directed=directed) + rsmi = its_to_rsmi(g, explicit_hydrogen=True) + return ( + NormalizeAAM().fit(rsmi, fix_aam_indice=False) if not explicit_h else rsmi + ) + + def to_dataframe(self): + if pd is None: 
+ raise RuntimeError("pandas required for DataFrame export.") + return pd.DataFrame.from_dict( + dict(self._graph.nodes(data=True)), orient="index" + ) - # ------------------------------------------------------------------ @staticmethod - def _dedupe_edges(edges: List[Edge]) -> List[Edge]: - seen: Set[Tuple[int, int, Order]] = set() - out: List[Edge] = [] - for u, v, attrs in edges: - key = (min(u, v), max(u, v), tuple(attrs["order"])) - if key not in seen: - seen.add(key) - out.append((u, v, attrs)) + def _merge_attrs(lhs: MutableMapping[str, Any], rhs: Mapping[str, Any]) -> None: + for k, v in rhs.items(): + if not lhs.get(k) and v is not None: + lhs[k] = v + + def _build_node_map_and_attributes(self) -> None: + prod, node_map = {}, {} + last = self._graphs[-1] + for nid, attrs in last.nodes(data=True): + prod[nid] = attrs.copy() + node_map[(self._k - 1, nid)] = nid + pid_counter = max(prod, default=-1) + 1 + + # merge attributes backwards + for i in range(self._k - 2, -1, -1): + G, mp = self._graphs[i], self._mappings[i] + for nid, attrs in G.nodes(data=True): + tgt = mp.get(nid) + if tgt is not None and (i + 1, tgt) in node_map: + pid = node_map[(i + 1, tgt)] + self._merge_attrs(prod[pid], attrs) + else: + pid = pid_counter + prod[pid] = attrs.copy() + pid_counter += 1 + node_map[(i, nid)] = pid + + # assemble typesGH history per pid + first_idx: Dict[int, int] = {} + for (gi, n), p in node_map.items(): + # track the earliest snapshot index where pid appears + if p in first_idx: + first_idx[p] = min(first_idx[p], gi) + else: + first_idx[p] = gi + + for p, attrs in prod.items(): + hist: List[Any] = [] + fi = first_idx[p] + for i in range(self._k): + if i < fi: + hist.append(_PLACEHOLDER_TYPESGH) + elif i == fi: + val = ( + self._graphs[i] + .nodes[ + next( + n + for (gi, n), pp in node_map.items() + if gi == i and pp == p + ) + ] + .get("typesGH", (_PLACEHOLDER_TYPESGH, _PLACEHOLDER_TYPESGH)) + ) + hist.append(val) + else: + originals = [ + n for (gi, n), pp in 
node_map.items() if gi == i and pp == p + ] + if originals: + val = ( + self._graphs[i] + .nodes[originals[0]] + .get( + "typesGH", (_PLACEHOLDER_TYPESGH, _PLACEHOLDER_TYPESGH) + )[-1] + ) + hist.append(val) + else: + hist.append(_PLACEHOLDER_TYPESGH) + attrs["typesGH_history"] = tuple(hist) + attrs["typesGH"] = attrs["typesGH_history"] + + self._prod_nodes = prod + self._node_map = node_map + + def _build_edge_history_and_graph(self) -> None: + hist: Dict[Tuple[int, int], List[MissingOrder]] = {} + for i, G in enumerate(self._graphs): + for u, v, a in G.edges(data=True): + pu, pv = self._node_map[(i, u)], self._node_map[(i, v)] + key = tuple(sorted((pu, pv))) + lst = hist.setdefault(key, [_PLACEHOLDER] * self._k) + lst[i] = a.get("order", _PLACEHOLDER) + g = nx.Graph() + g.add_nodes_from(self._prod_nodes.items()) + for (u, v), lst in hist.items(): + g.add_edge(u, v, order=tuple(lst)) + if g.number_of_nodes() != len(self._prod_nodes): + raise RuntimeError("Node count mismatch.") + self._graph = g + + def _prepare_graph_sequence( + self, seq: List[nx.Graph] | List[str] + ) -> List[nx.Graph]: + out: List[nx.Graph] = [] + for item in seq: + g = rsmi_to_its(item, core=False) if isinstance(item, str) else item + if self._canonicaliser: + g = self._canonicaliser.canonicalise_graph(g).canonical_graph + g = h_to_explicit(g, its=True) + # out.append(g) + out.append(normalize_hcount_and_typesGH(g)) return out - # ------------------------------------------------------------------ - # Public helpers - # ------------------------------------------------------------------ - def get_nodes(self) -> List[Node]: - """List of `(id, attrs)` for fused graph.""" - return self.product_nodes - - def get_edges(self) -> List[Edge]: - """List of `(u, v, attrs)` for fused graph.""" - return self.product_edges - - def get_map1(self) -> Dict[NodeID, NodeID]: - return self.map1 - - def get_map2(self) -> Dict[NodeID, NodeID]: - return self.map2 - - def get_graph(self, *, directed: bool = False): - 
G = nx.DiGraph() if directed else nx.Graph() - G.add_nodes_from(self.product_nodes) - for u, v, attrs in self.product_edges: - o = attrs["order"] - attrs["standard_order"] = o[0] - o[1] if None not in o else None - G.add_edge(u, v, **attrs) - return G - - # ------------------------------------------------------------------ - def __repr__(self): - return f"MTG(|V|={len(self.product_nodes)}, |E|={len(self.product_edges)})" + def _compute_mappings(self, graphs: List[nx.Graph]) -> List[GraphMapping]: + mappings: List[GraphMapping] = [] + for i in range(len(graphs) - 1): + m = MCSMatcher(node_label_names=self._node_label_names) + m.find_rc_mapping( + graphs[i], graphs[i + 1], mcs=self.mcs, mcs_mol=self.mcs_mol + ) + if not m._mappings: + raise RuntimeError(f"No mapping between {i} and {i+1}") + mappings.append(m._mappings[0]) + return mappings + + @property + def node_mapping(self) -> Dict[Tuple[int, NodeID], int]: + return dict(self._node_map) + + @property + def k(self) -> int: + return self._k + + +# from __future__ import annotations + +# """MTG – Mechanistic Transition Graph fusion utility. + +# This module exposes :class:`~MTG`, a helper that merges a chronological +# sequence of **Intermediate Transition State** (ITS) graphs – or their RSMI +# string representations – into a single *product* graph capturing the entire +# bond-order history across the reaction trajectory. + +# The implementation is self-contained except for the external *synkit* helpers +# used for RSMI⇆ITS inter-conversion and canonicalisation. 
+# """ + +# from collections.abc import Iterator +# from typing import Any, Dict, List, Mapping, MutableMapping, Set, Tuple, Union + +# import networkx as nx + +# # --------------------------------------------------------------------------- +# # Optional dependencies +# # --------------------------------------------------------------------------- +# try: +# import pandas as pd # type: ignore +# except ImportError: # pragma: no cover – pandas is only required for to_dataframe() +# pd = None # noqa: N816 – keep lowercase alias even if stubbed + +# from synkit.Graph.Hyrogen._misc import h_to_explicit # noqa: WPS433 – external import +# from synkit.Graph.ITS.normalize_aam import NormalizeAAM # noqa: WPS433 +# from synkit.Graph.MTG.mcs_matcher import MCSMatcher # noqa: WPS433 +# from synkit.Graph.MTG.utils import ( +# normalize_hcount_and_typesGH, +# normalize_order, +# label_mtg_edges, +# compute_standard_order, +# ) # noqa: WPS433 +# from synkit.Graph.canon_graph import GraphCanonicaliser # noqa: WPS433 +# from synkit.IO import its_to_rsmi, rsmi_to_its # noqa: WPS433 + +# NodeID = int +# OrderPair = Tuple[float, float] +# MissingOrder = Tuple[Set[float], Set[float]] +# GraphMapping = Dict[NodeID, NodeID] + +# # A placeholder for a *missing* edge-order in a particular snapshot. Using +# # `set()` makes the value clearly distinguishable from genuine numeric orders. +# _PLACEHOLDER: MissingOrder = (set(), set()) + +# __all__ = [ +# "MTG", +# ] + + +# class MTG: # pylint: disable=too-many-instance-attributes +# """Fuse a chronological series of ITS graphs into a Mechanistic Transition Graph. + +# :param sequences: Either a list of ITS-format NetworkX graphs or a list of RSMI +# strings in chronological order. +# :type sequences: List[nx.Graph] or List[str] +# :param mappings: Pre-computed node mappings between each consecutive pair of graphs. +# If None, mappings are computed via MCSMatcher. 
+# :type mappings: List[GraphMapping] or None +# :param node_label_names: Node attribute keys used for MCS-based matching. +# :type node_label_names: List[str] or None +# :param canonicaliser: Optional GraphCanonicaliser to canonicalise each +# snapshot before fusion. +# :type canonicaliser: GraphCanonicaliser or None +# :raises ValueError: If fewer than two sequences are provided or mapping count mismatches. +# :raises TypeError: If sequence elements are neither NetworkX graphs nor RSMI strings. +# :raises RuntimeError: If automatic mapping fails for any adjacent graph pair. +# """ + +# # --------------------------------------------------------------------- +# # Construction helpers +# # --------------------------------------------------------------------- + +# def __init__( +# self, +# sequences: Union[List[nx.Graph], List[str]], +# mappings: List[GraphMapping] | None = None, +# *, +# node_label_names: List[str] | None = None, +# canonicaliser: GraphCanonicaliser | None = None, +# mcs_mol: bool = False, +# mcs: bool = False, +# ) -> None: # noqa: D401 – docstring handled above +# # --- Basic validation ------------------------------------------------ +# if len(sequences) < 2: # also covers non-list via __len__ check raising +# raise ValueError( +# "At least two ITS snapshots are required to construct an MTG.", +# ) + +# self._node_label_names: List[str] = node_label_names or [ +# "element", +# "charge", +# "hcount", +# ] +# self._canonicaliser = canonicaliser +# self.mcs_mol: bool = mcs_mol +# self.mcs: bool = mcs + +# # --- Input normalisation ------------------------------------------- +# self._graphs: List[nx.Graph] = self._prepare_graph_sequence(sequences) +# self._k: int = len(self._graphs) + +# # --- Graph-to-graph mappings --------------------------------------- +# self._mappings: List[GraphMapping] = ( +# self._compute_mappings(self._graphs) if mappings is None else mappings +# ) +# if len(self._mappings) != self._k - 1: +# raise ValueError( +# "Need 
exactly one mapping per pair of adjacent snapshots.", +# ) + +# # --- Core fusion machinery ----------------------------------------- +# self._prod_nodes: Dict[int, Dict[str, Any]] +# self._node_map: Dict[Tuple[int, NodeID], int] +# self._graph: nx.Graph # final fused graph – populated below + +# self._build_node_map_and_attributes() +# self._build_edge_history_and_graph() + +# # --------------------------------------------------------------------- +# # Python dunder & public helpers +# # --------------------------------------------------------------------- + +# def __repr__(self) -> str: # noqa: D401 – simple representation +# """Return a summary representation including snapshot count and graph size.""" +# return ( +# f"" +# ) + +# # Collection-like API --------------------------------------------------- + +# def __len__(self) -> int: +# """Return the number of fused nodes in the product graph.""" +# return self._graph.number_of_nodes() + +# def __iter__(self) -> Iterator[int]: +# """Iterate over fused node identifiers.""" +# return iter(self._graph.nodes) + +# def __getitem__(self, node_id: int) -> Dict[str, Any]: +# """Access the attribute dictionary of a fused node by its ID. 
+ +# :param node_id: Fused node identifier +# :type node_id: int +# :returns: Node attribute mapping +# :rtype: Dict[str, Any] +# """ +# return self._graph.nodes[node_id] + +# # --------------------------------------------------------------------- +# # Public / user-facing API +# # --------------------------------------------------------------------- + +# @staticmethod +# def describe() -> str: # noqa: D401 – simple helper +# """Return an inline usage example for quick reference.""" + +# return ( +# "# Example usage\n" +# "mtg = MTG([G0, G1, G2])\n" +# "fused_graph = mtg.get_mtg()\n" +# "rsmi_with_aam = mtg.get_aam()\n" +# ) + +# # ------------------------------------------------------------------ +# # Graph export helpers +# # ------------------------------------------------------------------ + +# def get_mtg(self, *, directed: bool = False) -> nx.Graph: +# """Return the fused product graph. + +# :param directed: If True, return a directed copy of the fused graph +# :type directed: bool +# :returns: Fused product graph +# :rtype: networkx.Graph or networkx.DiGraph +# """ +# return self._graph.to_directed() if directed else self._graph + +# def get_compose_its(self, *, directed: bool = False) -> nx.Graph: +# """Return a graph with normalized edge orders for ITS export. + +# :param directed: If True, normalize a directed version +# :type directed: bool +# :returns: Graph with collapsed (order_G, order_H) tuples +# :rtype: networkx.Graph or networkx.DiGraph +# """ +# fused = self.get_mtg(directed=directed) +# fused = label_mtg_edges(fused, inplace=False) +# fused = normalize_order(fused) +# return compute_standard_order(fused) + +# def get_aam( +# self, +# *, +# directed: bool = False, +# explicit_h: bool = False, +# ) -> str: +# """Export fused graph to an RSMI string with atom-atom mapping. 
+ +# :param directed: If True, use a directed ITS representation +# :type directed: bool +# :param explicit_h: If True, include explicit hydrogens; otherwise normalize AAM +# :type explicit_h: bool +# :returns: RSMI string with AAM +# :rtype: str +# """ + +# its_graph = self.get_compose_its(directed=directed) +# rsmi = its_to_rsmi(its_graph, explicit_hydrogen=True) +# if not explicit_h: +# rsmi = NormalizeAAM().fit(rsmi, fix_aam_indice=False) +# return rsmi + +# def to_dataframe(self): +# """Return a pandas DataFrame of fused node attributes. + +# :returns: DataFrame indexed by fused node IDs with attribute columns +# :rtype: pandas.DataFrame +# :raises RuntimeError: If pandas is not installed +# """ +# if pd is None: +# raise RuntimeError( +# "pandas is required for `to_dataframe()` but is not installed." +# ) +# return pd.DataFrame.from_dict( +# dict(self._graph.nodes(data=True)), orient="index" +# ) + +# # ------------------------------------------------------------------ +# # Node & edge fusion internals +# # ------------------------------------------------------------------ + +# @staticmethod +# def _merge_attrs(lhs: MutableMapping[str, Any], rhs: Mapping[str, Any]) -> None: +# """Update in-place, preferring non-empty or non-None values from rhs. + +# :param lhs: Target attribute dict to update +# :type lhs: MutableMapping[str, Any] +# :param rhs: Source attribute dict +# :type rhs: Mapping[str, Any] +# """ +# for key, value in rhs.items(): +# if ( +# not lhs.get(key) and value is not None +# ): # noqa: WPS501 – explicitly allow False/0 +# lhs[key] = value + +# # ................................................................. + +# def _build_node_map_and_attributes(self) -> None: +# """Construct fused nodes by merging snapshots backwards. 
+ +# Builds: +# - self._prod_nodes: pid → attribute dict +# - self._node_map: (snapshot_index, original_node_id) → pid +# """ + +# prod: Dict[int, Dict[str, Any]] = {} +# node_map: Dict[Tuple[int, NodeID], int] = {} + +# # --- Seed with last snapshot ------------------------------------- +# last_graph = self._graphs[-1] +# for nid, attrs in last_graph.nodes(data=True): +# prod[nid] = attrs.copy() +# node_map[(self._k - 1, nid)] = nid +# next_pid: int = (max(prod) if prod else -1) + 1 + +# # --- Walk backwards and merge ------------------------------------ +# for idx in range(self._k - 2, -1, -1): +# G = self._graphs[idx] +# mapping = self._mappings[idx] +# for nid, attrs in G.nodes(data=True): +# target = mapping.get(nid) +# if target is not None and (idx + 1, target) in node_map: +# pid = node_map[(idx + 1, target)] +# self._merge_attrs(prod[pid], attrs) +# else: # new (unmapped) node – assign fresh pid +# while next_pid in prod: # safeguard although unlikely +# next_pid += 1 +# pid = next_pid +# prod[pid] = attrs.copy() +# next_pid += 1 +# node_map[(idx, nid)] = pid + +# self._prod_nodes = prod +# self._node_map = node_map + +# # ................................................................. + +# def _build_edge_history_and_graph( +# self, +# ) -> None: # noqa: C901 – complex but contained +# """Assemble the fused graph with per-edge order histores. + +# Each edge in the result has an `order` attribute: a tuple of +# length `k`, where each element is either an order-pair or a placeholder. 
+# """ + +# history: Dict[Tuple[int, int], List[MissingOrder]] = {} + +# # Collect order trajectories ----------------------------------------------------- +# for gi, G in enumerate(self._graphs): +# for u, v, attrs in G.edges(data=True): +# pu, pv = ( +# self._node_map[(gi, u)], +# self._node_map[(gi, v)], +# ) +# key = tuple(sorted((pu, pv))) # undirected canonical ordering +# orders = history.setdefault(key, [_PLACEHOLDER] * self._k) +# orders[gi] = attrs.get("order", _PLACEHOLDER) # type: ignore[arg-type] + +# # Build fused NetworkX graph ----------------------------------------------------- +# graph = nx.Graph() +# graph.add_nodes_from(self._prod_nodes.items()) +# for (u, v), orders in history.items(): +# graph.add_edge(u, v, order=tuple(orders)) + +# # Sanity check ------------------------------------------------------------------ +# if graph.number_of_nodes() != len(self._prod_nodes): +# raise RuntimeError("Node count mismatch during MTG assembly.") + +# self._graph = graph + +# # ------------------------------------------------------------------ +# # Mapping helpers +# # ------------------------------------------------------------------ + +# def _prepare_graph_sequence( +# self, seq: List[nx.Graph] | List[str] +# ) -> List[nx.Graph]: +# """Convert input list to a cleaned sequence of ITS graphs. 
+ +# :param seq: Raw sequence of graphs or RSMI strings +# :type seq: List[nx.Graph] or List[str] +# :returns: List of normalized ITS graphs +# :rtype: List[nx.Graph] +# :raises TypeError: If an element is neither nx.Graph nor str +# """ + +# prepared: List[nx.Graph] = [] +# for item in seq: +# if isinstance(item, str): +# graph = rsmi_to_its(item, core=False) +# elif isinstance(item, nx.Graph): +# graph = item +# else: # pragma: no cover – guard against future unsupported types +# raise TypeError( +# "Sequences must contain either NetworkX graphs or RSMI strings.", +# ) + +# # Canonicalise (optional) --------------------------------------------------- +# if self._canonicaliser is not None: +# graph = self._canonicaliser.canonicalise_graph(graph).canonical_graph # type: ignore[attr-defined] + +# # Ensure explicit hydrogens & normalised hcount / typesGH ---------- +# graph = h_to_explicit(graph, its=True) +# graph = normalize_hcount_and_typesGH(graph) +# prepared.append(graph) + +# return prepared + +# # .................................................................. + +# def _compute_mappings(self, graphs: List[nx.Graph]) -> List[GraphMapping]: +# """Compute atom mappings via MCS matching for each adjacent pair. 
+ +# :param graphs: ITS graphs in chronological order +# :type graphs: List[nx.Graph] +# :returns: List of mappings of length k-1 +# :rtype: List[GraphMapping] +# :raises RuntimeError: If no mapping found for a pair +# """ + +# mappings: List[GraphMapping] = [] +# for idx in range(len(graphs) - 1): +# matcher = MCSMatcher(node_label_names=self._node_label_names) +# matcher.find_rc_mapping( +# graphs[idx], graphs[idx + 1], mcs=self.mcs, mcs_mol=self.mcs_mol +# ) +# if not matcher._mappings: # pylint: disable=protected-access +# raise RuntimeError( +# f"No MCS mapping found between snapshots {idx} and {idx + 1}.", +# ) +# mappings.append(matcher._mappings[0]) # pylint: disable=protected-access +# return mappings + +# # ------------------------------------------------------------------ +# # Convenience accessors (mostly for unit tests) +# # ------------------------------------------------------------------ + +# @property +# def node_mapping(self) -> Dict[Tuple[int, NodeID], int]: +# """Return the internal mapping from (snapshot_index, original_node_id) to fused pid. + +# :returns: Mapping dictionary +# :rtype: Dict[Tuple[int, NodeID], int] +# """ +# return dict(self._node_map) + +# @property +# def k(self) -> int: +# """Return the number of snapshots fused. + +# :returns: Snapshot count +# :rtype: int +# """ +# return self._k + + +# import networkx as nx +# from collections import defaultdict +# from typing import Dict, List, Tuple, Any, Set, Union + +# # ----------------------------------------------------------------------------- +# # Type aliases +# # ----------------------------------------------------------------------------- +# NodeID = int +# Order = Tuple[float, float] +# Node = Tuple[NodeID, Dict[str, Any]] +# Edge = Tuple[NodeID, NodeID, Dict[str, Any]] + +# __all__ = ["MTG"] + + +# class MTG: +# """Fuse two molecular graphs via a pair‑groupoid edge‑composition rule. 
+ +# Parameters +# ---------- +# G1, G2 +# Input :class:`networkx.Graph` (or *DiGraph*) objects. Nodes must carry an +# ``"element"`` attribute; edges carry an ``"order"`` 2‑tuple *(x, y)*. +# mapping +# A partial node map **G1 → G2** indicating which atoms are chemically +# identical (intersection). Keys are node IDs in *G1*, values in *G2*. + +# Notes +# ----- +# 1. ``intersection_ids`` are created where mapping ``G1[i] → G2[j]``. +# 2. Edges are inserted in two passes: +# * *Pass 1* – all edges from *G1* are copied unchanged. +# * *Pass 2* – edges from *G2* are remapped; when both endpoints are in +# ``intersection_ids`` **and** their bond orders satisfy the *pair‐ +# groupoid* condition + +# ``(a₁, a₂) + (b₁, b₂) with a₂ == b₁ → (a₁, b₂)``, + +# the edges are *composed* instead of duplicated. + +# Examples +# -------- +# >>> mtg = MTG(G1, G2, {1: 3, 4: 6, 5: 1}) +# >>> fused = mtg.get_graph() +# >>> fused.nodes(data=True) +# ... +# """ + +# # ------------------------------------------------------------------ +# # Construction helpers +# # ------------------------------------------------------------------ +# def __init__( +# self, +# G1: Union[nx.Graph, nx.DiGraph], +# G2: Union[nx.Graph, nx.DiGraph], +# mapping: Dict[NodeID, NodeID], +# ) -> None: +# # Store originals +# self.G1 = G1 +# self.G2 = G2 +# self.mapping12 = mapping # G1 → G2 + +# # ---- 1. Build fused node set --------------------------------- +# ( +# self.product_nodes, # list[(id, attrs)] in fused graph +# self.map1, # G1 id → fused id +# self.map2, # G2 id → fused id +# self.intersection_ids, # list[fused id] +# ) = self._fuse_nodes() + +# # ---- 2. 
Fuse edges with groupoid rule ------------------------ +# fused_edges_step1 = self._insert_edges_from(self.G1.edges(data=True), self.map1) +# self.product_edges = self._insert_edges_from( +# self.G2.edges(data=True), self.map2, existing=fused_edges_step1 +# ) + +# # ------------------------------------------------------------------ +# # Node fusion +# # ------------------------------------------------------------------ +# def _fuse_nodes(self): +# merged: Dict[NodeID, Dict[str, Any]] = {} +# map1: Dict[NodeID, NodeID] = {} +# map2: Dict[NodeID, NodeID] = {} +# used: Set[NodeID] = set() + +# # --- copy G1 directly into fused graph ------------------------ +# for v, attrs in self.G1.nodes(data=True): +# merged[v] = attrs.copy() +# map1[v] = v +# used.add(v) + +# # inverse mapping: G2 node → G1 node it merges to +# inv_map = {g2: g1 for g1, g2 in self.mapping12.items()} +# intersection: List[NodeID] = [] + +# # --- process G2 nodes ----------------------------------------- +# next_id = max(used) + 1 if used else 0 +# for v, attrs in self.G2.nodes(data=True): +# if v in inv_map: # merged node +# tgt = inv_map[v] +# map2[v] = tgt +# intersection.append(tgt) +# else: # unique node from G2 +# while next_id in used: +# next_id += 1 +# merged[next_id] = attrs.copy() +# map2[v] = next_id +# used.add(next_id) +# next_id += 1 + +# nodes_sorted = sorted(merged.items()) # list[(id, attrs)] +# return nodes_sorted, map1, map2, intersection + +# # ------------------------------------------------------------------ +# # Edge insertion & groupoid composition +# # ------------------------------------------------------------------ +# def _insert_edges_from( +# self, edge_iter, node_map: Dict[NodeID, NodeID], existing: List[Edge] = None +# ) -> List[Edge]: +# """Insert edges into *existing* applying the groupoid rule when +# possible.""" +# existing = [] if existing is None else existing.copy() + +# # Remap and append new edges +# for u, v, attrs in edge_iter: +# u3 = node_map[u] +# v3 = 
node_map[v] +# existing.append((u3, v3, attrs.copy())) + +# # Canonicalize keys for undirected graphs +# def key(u, v): +# return (u, v) if isinstance(self.G1, nx.DiGraph) else tuple(sorted((u, v))) + +# # Group edges by (u,v) +# buckets: Dict[Tuple[NodeID, NodeID], List[Order]] = defaultdict(list) +# bucket_src: Dict[Tuple[NodeID, NodeID], List[str]] = defaultdict(list) +# for idx, (u, v, attrs) in enumerate(existing): +# buckets[key(u, v)].append(tuple(attrs["order"])) +# bucket_src[key(u, v)].append("G1" if idx < len(self.G1.edges) else "G2") + +# fused_edges: List[Edge] = [] +# for (u, v), orders in buckets.items(): +# # src = bucket_src[(u, v)] +# if ( +# u in self.intersection_ids +# and v in self.intersection_ids +# and len(orders) >= 2 +# ): +# # Attempt pair‑wise composition between G1 (first) and any G2 edge +# made_composite = False +# for idx2, ord2 in enumerate(orders[1:], start=1): +# a1, a2 = orders[0] +# b1, b2 = ord2 +# if a2 == b1: +# fused_edges.append((u, v, {"order": (a1, b2)})) +# made_composite = True +# break +# if not made_composite: +# # fall back to *all* distinct orders +# for ord_ in orders: +# fused_edges.append((u, v, {"order": ord_})) +# else: +# for ord_ in orders: +# fused_edges.append((u, v, {"order": ord_})) + +# return self._dedupe_edges(fused_edges) + +# # ------------------------------------------------------------------ +# @staticmethod +# def _dedupe_edges(edges: List[Edge]) -> List[Edge]: +# seen: Set[Tuple[int, int, Order]] = set() +# out: List[Edge] = [] +# for u, v, attrs in edges: +# key = (min(u, v), max(u, v), tuple(attrs["order"])) +# if key not in seen: +# seen.add(key) +# out.append((u, v, attrs)) +# return out + +# # ------------------------------------------------------------------ +# # Public helpers +# # ------------------------------------------------------------------ +# def get_nodes(self) -> List[Node]: +# """List of `(id, attrs)` for fused graph.""" +# return self.product_nodes + +# def get_edges(self) -> 
List[Edge]: +# """List of `(u, v, attrs)` for fused graph.""" +# return self.product_edges + +# def get_map1(self) -> Dict[NodeID, NodeID]: +# return self.map1 + +# def get_map2(self) -> Dict[NodeID, NodeID]: +# return self.map2 + +# def get_graph(self, *, directed: bool = False): +# G = nx.DiGraph() if directed else nx.Graph() +# G.add_nodes_from(self.product_nodes) +# for u, v, attrs in self.product_edges: +# o = attrs["order"] +# attrs["standard_order"] = o[0] - o[1] if None not in o else None +# G.add_edge(u, v, **attrs) +# return G + +# # ------------------------------------------------------------------ +# def __repr__(self): +# return f"MTG(|V|={len(self.product_nodes)}, |E|={len(self.product_edges)})" diff --git a/synkit/Graph/MTG/mtg_explore.py b/synkit/Graph/MTG/mtg_explore.py new file mode 100644 index 0000000..3895fb5 --- /dev/null +++ b/synkit/Graph/MTG/mtg_explore.py @@ -0,0 +1,74 @@ +from typing import Optional, List +from synkit.Graph.MTG.mtg import MTG +from synkit.Rule.Apply.rule_matcher import RuleMatcher +from synkit.Graph.MTG.mcs_matcher import MCSMatcher +from networkx import Graph + + +def find_mtg( + g1: Graph, + g2: Graph, + ground_truth: str, + node_label_names: Optional[List[str]] = None, +) -> Optional[MTG]: + """ + Attempt to construct a Mapping Transformation Graph (MTG) for two input graphs + by finding maximum common substructure mappings and validating against a ground truth. + + :param g1: The first input graph to match. + :type g1: networkx.Graph + :param g2: The second input graph to match. + :type g2: networkx.Graph + :param ground_truth: A string representation of the expected atom-atom mapping (AAM) + used to validate candidate mappings. + :type ground_truth: str + :param node_label_names: List of node attribute names to use for MCS matching. + Defaults to ["element", "charge", "hcount"]. 
+ :type node_label_names: list of str, optional + :returns: An MTG instance if a valid mapping satisfying the ground truth is found; + otherwise, None. + :rtype: MTG or None + :raises ValueError: If input graphs are empty or ground_truth is invalid format. + + :example: + >>> from networkx import Graph + >>> g1, g2 = Graph(), Graph() + >>> # populate g1 and g2 with nodes/edges + >>> mtg = find_mtg( + ... g1, + ... g2, + ... ground_truth="{0:1, 1:0}", + ... node_label_names=["element", "charge", "hcount"] + ... ) + >>> if mtg: + ... print(mtg) + """ + # Validate inputs + if not g1 or not g2: + raise ValueError("Input graphs g1 and g2 must be non-empty.") + if not isinstance(ground_truth, str) or not ground_truth.strip(): + raise ValueError("Ground truth mapping must be a non-empty string.") + + # Set default node_label_names if not provided + if node_label_names is None: + node_label_names = ["element", "charge", "hcount"] + + # Initialize maximum common substructure matcher + mcs = MCSMatcher(node_label_names=node_label_names) + mcs.find_rc_mapping(g1, g2, mcs=False) + mappings = mcs._mappings + + for mapping in mappings: + # Construct MTG using current mapping + mtg = MTG([g1, g2], mappings=[mapping]) + aam = mtg.get_aam() + try: + # Validate generated AAM against ground truth + RuleMatcher(ground_truth, aam, explicit_h=False) + return mtg + except AssertionError: + # Mapping did not satisfy ground truth, try next + continue + + # No valid mapping found + return None diff --git a/synkit/Graph/MTG/utils.py b/synkit/Graph/MTG/utils.py new file mode 100644 index 0000000..5d78ebb --- /dev/null +++ b/synkit/Graph/MTG/utils.py @@ -0,0 +1,425 @@ +import networkx as nx +import copy +from typing import Tuple, Any, Optional, Sequence, Set, Union + +OrderPair = Tuple[float, float] +MissingOrder = Tuple[Set[float], Set[float]] + +GraphType = Union[nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] +TypesGHTuple = Tuple[Any, ...] 
# e.g., ('H', 0, 0, 0, ['C']) + + +def _extract_leaf_candidates(orig_th: Tuple[Any, ...]) -> Tuple[TypesGHTuple, ...]: + """ + Flatten one level: if an element is a sequence whose elements are themselves sequences, + treat those inner sequences as candidates; otherwise the element itself is a candidate. + """ + candidates: list[TypesGHTuple] = [] + for item in orig_th: + if ( + isinstance(item, (list, tuple)) + and item + and all(isinstance(inner, (list, tuple)) for inner in item) + ): + # e.g., (('H',...), ('H',...)) -> take inner tuples + for inner in item: + if isinstance(inner, (list, tuple)): + candidates.append(tuple(inner)) + else: + if isinstance(item, (list, tuple)): + candidates.append(tuple(item)) + else: + # non-sequence fallback, wrap to tuple for uniformity + candidates.append((item,)) + return tuple(candidates) + + +def normalize_hcount_and_typesGH(G: GraphType) -> GraphType: + """ + Return a fresh copy of G where: + * each node's `hcount` attribute is set to 0 + * each node's `typesGH` is processed as follows: + 1. Flatten one level so that nested tuples-of-tuples are expanded. + 2. Drop any tuple that contains a `set` anywhere. + 3. From the remaining tuples, keep only the first and last (if more than two). + 4. Zero indices 1 and 2 in each kept tuple (if they exist). + 5. If nothing remains after dropping, result is an empty tuple. + + :param G: input NetworkX graph + :type G: nx.Graph or nx.DiGraph or nx.MultiGraph or nx.MultiDiGraph + :returns: a new graph with normalized hcount and typesGH + :rtype: same type as G + :raises: TypeError if G is not a supported NetworkX graph or if typesGH is malformed. 
def normalize_hcount_and_typesGH(G: GraphType) -> GraphType:
    """
    Return a fresh copy of G where:
      * each node's ``hcount`` attribute is set to 0
      * each node's ``typesGH`` is processed as follows:
        1. Flatten one level so that nested tuples-of-tuples are expanded.
        2. Drop any tuple that has a ``set`` among its (top-level) elements.
        3. From the remaining tuples, keep only the first and last (if more than two).
        4. Zero indices 1 and 2 in each kept tuple (if they exist).
        5. If nothing remains after dropping, the result is an empty tuple.

    Graph-, node- and edge-level attributes are all deep-copied, so the
    returned graph shares no mutable attribute values with ``G``.

    :param G: input NetworkX graph
    :type G: nx.Graph or nx.DiGraph or nx.MultiGraph or nx.MultiDiGraph
    :returns: a new graph with normalized hcount and typesGH
    :rtype: same type as G
    :raises TypeError: if G is not a supported NetworkX graph or a node's
        typesGH is not a list/tuple.
    """
    if not isinstance(G, (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)):
        raise TypeError(f"Expected a NetworkX graph, got {type(G)!r}")

    H = G.__class__()
    H.graph.update(copy.deepcopy(G.graph))

    for node, data in G.nodes(data=True):
        orig_th = data.get("typesGH", ())
        if not isinstance(orig_th, (list, tuple)):
            raise TypeError(
                f"Node {node} has typesGH of unexpected type {type(orig_th)}; expected sequence"
            )

        # Deep copy so nested containers (e.g. neighbour lists inside typesGH)
        # are not shared with G, matching how graph and edge data are copied.
        new_data = copy.deepcopy(data)
        new_data["hcount"] = 0

        # Step 1: flatten one level of nested tuples-of-tuples
        candidates = _extract_leaf_candidates(tuple(new_data.get("typesGH", ())))

        # Step 2: drop any tuple that has a set among its elements
        kept = [t for t in candidates if not any(isinstance(elem, set) for elem in t)]

        # Step 3: keep first and last when more than two remain
        if len(kept) > 2:
            kept = [kept[0], kept[-1]]

        # Step 4: zero indices 1 and 2 (when present) of every kept tuple
        new_data["typesGH"] = tuple(
            tuple(0 if idx in (1, 2) else value for idx, value in enumerate(entry))
            for entry in kept
        )
        H.add_node(node, **new_data)

    # Copy edges; multigraphs must carry their edge keys across.
    if G.is_multigraph():
        for u, v, key, edata in G.edges(keys=True, data=True):
            H.add_edge(u, v, key=key, **copy.deepcopy(edata))
    else:
        for u, v, edata in G.edges(data=True):
            H.add_edge(u, v, **copy.deepcopy(edata))

    return H
def extract_order_norm(
    order_sequence: Sequence[
        Union[Tuple[float, float], Tuple[Set[float], Set[float]]]
    ],
) -> Optional[Tuple[float, float]]:
    """
    Collapse a history of bond-order pairs into one normalized 2-tuple.

    An entry whose two parts are *both* sets is a placeholder and is skipped.
      - left: first element of the first non-placeholder entry
      - right: second element of the last non-placeholder entry

    A flat pair of plain numbers (e.g. ``(1.0, 2.0)``) is treated as already
    normalized and returned unchanged, so the function is safe to apply to its
    own output.

    :param order_sequence: A sequence (length >= 2) of order pairs and/or placeholders
    :type order_sequence: Sequence[tuple[float, float]] or Sequence[MissingOrder]
    :returns: ``(left, right)`` if both could be determined; otherwise None
    :rtype: tuple[float, float] or None
    :raises ValueError: If the sequence is shorter than 2, or an inspected
        entry cannot be unpacked into exactly two parts

    :example:
    >>> seq = [({1}, {2}), (3.0, 4.0), ({5}, {6}), (7.0, 8.0)]
    >>> extract_order_norm(seq)
    (3.0, 8.0)
    >>> extract_order_norm((3.0, 8.0))
    (3.0, 8.0)
    """
    if not isinstance(order_sequence, Sequence) or len(order_sequence) < 2:
        raise ValueError("order_sequence must be a sequence of length >= 2")

    # Already-normalized input: a flat pair of numbers, not a history of pairs.
    if len(order_sequence) == 2 and all(
        isinstance(part, (int, float)) for part in order_sequence
    ):
        return (order_sequence[0], order_sequence[1])

    def _first_concrete(entries: Any) -> Optional[Tuple[Any, Any]]:
        """Return the first entry of *entries* that is not a set/set placeholder."""
        for entry in entries:
            try:
                a, b = entry
            except (TypeError, ValueError):
                raise ValueError(
                    f"order_sequence entries must be 2-element pairs, got {entry!r}"
                ) from None
            if not (isinstance(a, set) and isinstance(b, set)):
                return (a, b)
        return None

    head = _first_concrete(order_sequence)
    tail = _first_concrete(reversed(order_sequence))
    if head is None or tail is None:
        return None

    left, right = head[0], tail[1]
    if left is None or right is None:
        return None
    return (left, right)
def normalize_order(G: nx.Graph) -> nx.Graph:
    """
    Return a deep copy of the graph with each edge's 'order' attribute normalized.

    If an edge has an 'order' attribute that is a list/tuple of length >= 2
    containing at least one nested list/tuple (i.e. a history of order pairs),
    it is replaced by the 2-tuple returned by :func:`extract_order_norm`,
    provided that function returns a non-None result. An 'order' that is
    already a flat pair (no nested sequences) is left untouched, so the
    function can safely be applied to its own output.

    :param G: Input NetworkX graph
    :type G: nx.Graph, nx.DiGraph, nx.MultiGraph, or nx.MultiDiGraph
    :returns: A new graph of the same type with normalized edge 'order'
    :rtype: same as G
    :raises TypeError: If G is not a NetworkX graph

    :example:
    >>> import networkx as nx
    >>> G = nx.Graph()
    >>> G.add_edge(1, 2, order=[(1,2), ({3},{4}), (5,6)])
    >>> H = normalize_order(G)
    >>> H.edges[1,2]['order']
    (1, 6)
    """
    from copy import deepcopy

    if not isinstance(G, (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)):
        raise TypeError(f"Expected a NetworkX graph, got {type(G)}")

    H = deepcopy(G)
    # edges(data=True) yields (u, v, attr) for simple and multi graphs alike
    # (one triple per parallel edge), so a single loop covers both.
    for _, _, attr in H.edges(data=True):
        order = attr.get("order")
        if not isinstance(order, (list, tuple)) or len(order) < 2:
            continue
        if not any(isinstance(entry, (list, tuple)) for entry in order):
            # Flat pair such as (1.0, 2.0): already normalized, nothing to do.
            continue
        norm = extract_order_norm(order)
        if norm is not None:
            attr["order"] = norm
    return H
def label_mtg_edges(G: nx.Graph, inplace: bool = False) -> nx.Graph:
    """
    Label each edge with a boolean 'is_mtg' attribute based on two criteria:
      1. There are at least two steps where the standard order
         (order[0] - order[1]) is non-zero.
      2. The sum of all numeric standard orders is zero.

    Only steps that are 2-element lists/tuples of numbers contribute; any
    other step (e.g. a set/set placeholder) is ignored. An edge whose 'order'
    is missing, is not a list/tuple, or has fewer than two steps is labelled
    ``False``.

    :param G: Input MTG graph with 'order' history per edge
    :type G: nx.Graph or nx.DiGraph
    :param inplace: If True, modify G in place; otherwise work on ``G.copy()``
    :type inplace: bool
    :returns: Graph with 'is_mtg' boolean attribute on each edge
    :rtype: same type as G
    :raises TypeError: If G is not a NetworkX Graph

    :example:
    >>> import networkx as nx
    >>> G = nx.Graph()
    >>> # Single change only -> less than 2 non-zero steps => False
    >>> G.add_edge(7,3, order=((1.0,1.0),(1.0,0)))
    >>> H = label_mtg_edges(G)
    >>> H.edges[7,3]['is_mtg']
    False
    >>> # Two-step equal but opposite changes -> sum zero and count>=2 => True
    >>> G = nx.Graph()
    >>> G.add_edge(2,1, order=((1.0,2.0),(2.0,1.0)))
    >>> H = label_mtg_edges(G)
    >>> H.edges[2,1]['is_mtg']
    True
    """
    if not isinstance(G, (nx.Graph, nx.DiGraph)):
        raise TypeError(f"Expected a NetworkX Graph, got {type(G)}")

    graph = G if inplace else G.copy()
    for _, _, attr in graph.edges(data=True):
        history = attr.get("order")
        deltas = []
        if isinstance(history, (list, tuple)) and len(history) >= 2:
            # Steps may be tuples or lists; the history itself may be either,
            # so list-encoded steps are accepted as well.
            deltas = [
                step[0] - step[1]
                for step in history
                if isinstance(step, (list, tuple))
                and len(step) == 2
                and all(isinstance(x, (int, float)) for x in step)
            ]
        changed_steps = sum(1 for delta in deltas if delta != 0)
        attr["is_mtg"] = changed_steps >= 2 and sum(deltas) == 0
    return graph
def compute_standard_order(G: nx.Graph, inplace: bool = False) -> nx.Graph:
    """
    Compute and assign the 'standard_order' attribute for each edge in the graph.

    'standard_order' is ``order[0] - order[1]`` for edges whose 'order'
    attribute is a 2-tuple. If the subtraction is not supported by the two
    operands (``TypeError``), 'standard_order' is set to ``None``. Edges whose
    'order' is not a 2-tuple are left untouched.

    :param G: Input NetworkX graph
    :type G: nx.Graph, nx.DiGraph, nx.MultiGraph, or nx.MultiDiGraph
    :param inplace: If True, modify G in-place; otherwise operate on ``G.copy()``
    :type inplace: bool
    :returns: Graph with 'standard_order' attributes set
    :rtype: same type as G
    :raises TypeError: If G is not a NetworkX graph

    :example:
    >>> import networkx as nx
    >>> G = nx.Graph()
    >>> G.add_edge(7, 3, order=(1.0, 0))
    >>> H = compute_standard_order(G)
    >>> H.edges[7,3]['standard_order']
    1.0
    """
    if not isinstance(G, (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)):
        raise TypeError(f"Expected a NetworkX graph, got {type(G)}")

    graph = G if inplace else G.copy()
    # edges(data=True) yields (u, v, attr) for simple and multi graphs alike,
    # so no separate multigraph branch is needed.
    for _, _, attr in graph.edges(data=True):
        order = attr.get("order")
        if not (isinstance(order, tuple) and len(order) == 2):
            continue
        before, after = order
        try:
            attr["standard_order"] = before - after
        except TypeError:
            # Operands that cannot be subtracted (e.g. None) have no standard order.
            attr["standard_order"] = None
    return graph
left = None +# right = None +# # Find first non-all-set tuple for left +# for a, b in order_tuple: +# if not (isinstance(a, set) and isinstance(b, set)): +# left = a +# break +# # Find last non-all-set tuple for right +# for a, b in reversed(order_tuple): +# if not (isinstance(a, set) and isinstance(b, set)): +# right = b +# break +# return (left, right) if (left is not None and right is not None) else None + + +# def normalize_order(G): +# """ +# Return a copy of G with edge attribute 'order' normalized. +# For each edge, if the 'order' attribute is a 4-tuple, replace it with the +# normalized 2-tuple returned by extract_order_norm. + +# :param G: input NetworkX graph +# :type G: nx.Graph or nx.DiGraph or nx.MultiGraph or nx.MultiDiGraph +# :returns: a new graph with normalized edge orders +# :rtype: same type as G +# :raises: TypeError if G is not a NetworkX graph + +# :example: +# >>> G = nx.Graph() +# >>> G.add_edge(1, 2, order=((1,2), ({3},{4}), ({5},{6}), (7,8))) +# >>> H = copy_and_normalize_order(G) +# >>> H.edges[1,2]['order'] +# (1, 8) +# """ +# if not isinstance(G, (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)): +# raise TypeError(f"Expected NetworkX graph, got {type(G)}") +# H = G.copy() +# for _, _, _, attr in ( +# H.edges(keys=True, data=True) +# if H.is_multigraph() +# else [(u, v, None, attr) for u, v, attr in H.edges(data=True)] +# ): +# order = attr.get("order") +# if isinstance(order, tuple) and len(order) == 4: +# norm = extract_order_norm(order) +# if norm is not None: +# attr["order"] = norm +# return H diff --git a/synkit/Graph/Matcher/subgraph_matcher.py b/synkit/Graph/Matcher/subgraph_matcher.py index 1eb5678..4c73810 100644 --- a/synkit/Graph/Matcher/subgraph_matcher.py +++ b/synkit/Graph/Matcher/subgraph_matcher.py @@ -548,3 +548,330 @@ def help(self) -> str: # noqa: D401 – property for convenience """Return the full module docstring.""" return __doc__ + + +# --------------------------------------------------------------------- 
+# Sub-graph search engine +# --------------------------------------------------------------------- +# class SubgraphSearchEngine: +# """Efficient sub-graph matching helpers (static, stateless). + +# Call :py:meth:`find_subgraph_mappings` as the single public entry-point. +# """ + +# # ------------------------------------------------------------------ +# # Public dispatcher ------------------------------------------------- +# # ------------------------------------------------------------------ +# @staticmethod +# def find_subgraph_mappings( +# host: nx.Graph, +# pattern: nx.Graph, +# *, +# node_attrs: List[str], +# edge_attrs: List[str], +# strategy: Strategy = Strategy.COMPONENT, +# max_results: Optional[int] = None, +# strict_cc_count: bool = True, +# unique: bool = True, +# ) -> List[MappingDict]: +# """Return *all* pattern→host embeddings. + +# Parameters +# ---------- +# host : nx.Graph +# The larger “substrate” graph. +# pattern : nx.Graph +# The smaller “template” graph. +# node_attrs : list[str] +# Attributes that must match exactly on nodes. +# edge_attrs : list[str] +# Attributes that must match exactly on edges. +# strategy : Strategy, default=COMPONENT +# Matching strategy (ALL | COMPONENT | BACKTRACK). +# max_results : int | None, default=None +# Stop after collecting this many embeddings (before duplicate +# suppression unless *unique* is False). +# strict_cc_count : bool, default=True +# Reject immediately if host has more CCs than pattern. +# unique : bool, default=False +# If True, collapse embeddings that differ only by an +# automorphism of *pattern*. + +# Returns +# ------- +# list[dict[int, int]] +# A list of node-ID maps *pattern → host*. 
+# """ + +# # ── defensive copy (do not mutate caller’s graphs) ──────────── +# host = host.copy() +# pattern = pattern.copy() + +# # ── run selected strategy ───────────────────────────────────── +# if strategy is Strategy.ALL: +# results = SubgraphSearchEngine._find_all_subgraph_mappings( +# host, +# pattern, +# node_attrs=node_attrs, +# edge_attrs=edge_attrs, +# max_results=max_results if not unique else None, +# ) +# elif strategy is Strategy.COMPONENT: +# results = ( +# SubgraphSearchEngine._find_component_aware_subgraph_mappings( +# host, +# pattern, +# node_attrs=node_attrs, +# edge_attrs=edge_attrs, +# max_results=max_results if not unique else None, +# strict_cc_count=strict_cc_count, +# ) +# ) +# elif strategy is Strategy.BACKTRACK: +# results = SubgraphSearchEngine._find_bt_subgraph_mappings( +# host, +# pattern, +# node_attrs=node_attrs, +# edge_attrs=edge_attrs, +# max_results=max_results if not unique else None, +# strict_cc_count=strict_cc_count, +# ) +# else: # defensive +# raise ValueError(f"Unsupported strategy: {strategy}") + +# # ── post-filter duplicates coming from automorphisms ────────── +# if unique and results: +# def node_match(nh: EdgeAttr, np: EdgeAttr) -> bool: +# return all(nh.get(k) == np.get(k) for k in node_attrs) and ( +# nh.get("hcount", 0) >= np.get("hcount", 0) +# ) + +# def edge_match(eh: EdgeAttr, ep: EdgeAttr) -> bool: +# return all(eh.get(k) == ep.get(k) for k in edge_attrs) + +# results = SubgraphSearchEngine._filter_automorphic_duplicates( +# pattern, +# results, +# node_match=node_match, +# edge_match=edge_match, +# ) + +# # ── honour *max_results* after uniqueness filter ────────────── +# if max_results is not None and len(results) > max_results: +# results = results[:max_results] + +# return results + +# # ------------------------------------------------------------------ +# # Strategy: ALL – classical VF2 on full host ----------------------- +# # ------------------------------------------------------------------ +# 
@staticmethod +# def _find_all_subgraph_mappings( +# host: nx.Graph, +# pattern: nx.Graph, +# *, +# node_attrs: List[str], +# edge_attrs: List[str], +# max_results: Optional[int] = None, +# ) -> List[MappingDict]: +# """Pure VF2 search without CC awareness (baseline).""" + +# def node_match(nh: EdgeAttr, np: EdgeAttr) -> bool: +# return all(nh.get(k) == np.get(k) for k in node_attrs) and nh.get( +# "hcount", 0 +# ) >= np.get("hcount", 0) + +# def edge_match(eh: EdgeAttr, ep: EdgeAttr) -> bool: +# return all(eh.get(k) == ep.get(k) for k in edge_attrs) + +# gm = GraphMatcher(host, pattern, node_match=node_match, edge_match=edge_match) +# results: List[MappingDict] = [] +# for iso in gm.subgraph_monomorphisms_iter(): +# results.append({p: h for h, p in iso.items()}) +# if max_results is not None and len(results) >= max_results: +# break +# return results + +# # ------------------------------------------------------------------ +# # Strategy: COMPONENT – improved component-aware matcher ----------- +# # ------------------------------------------------------------------ +# @staticmethod +# def _find_component_aware_subgraph_mappings( +# host: nx.Graph, +# pattern: nx.Graph, +# *, +# node_attrs: List[str], +# edge_attrs: List[str], +# max_results: Optional[int] = None, +# strict_cc_count: bool = False, +# ) -> List[MappingDict]: +# """Component-aware VF2 without attribute/degree/WL-1 pre-filters.""" + +# # 1) split into connected components +# host_ccs = [host.subgraph(c).copy() for c in nx.connected_components(host)] +# pat_ccs = [pattern.subgraph(c).copy() for c in nx.connected_components(pattern)] +# hcc, pcc = len(host_ccs), len(pat_ccs) + +# # empty pattern ⇒ single empty mapping +# if pcc == 0: +# return [{}] + +# # fallback to full VF2 if host has fewer CCs +# if hcc < pcc: +# return SubgraphSearchEngine._find_all_subgraph_mappings( +# host, +# pattern, +# node_attrs=node_attrs, +# edge_attrs=edge_attrs, +# max_results=max_results, +# ) + +# # strict count: reject 
when host has more CCs than pattern +# if hcc > pcc and strict_cc_count: +# return [] + +# # 2) define VF2 predicates +# def node_match(nh: EdgeAttr, np: EdgeAttr) -> bool: +# if any(nh.get(a) != np.get(a) for a in node_attrs): +# return False +# return nh.get("hcount", 0) >= np.get("hcount", 0) + +# def edge_match(eh: EdgeAttr, ep: EdgeAttr) -> bool: +# return all(eh.get(a) == ep.get(a) for a in edge_attrs) + +# # 3) collect embeddings for each pattern-CC +# per_cc: List[List[Tuple[int, MappingDict]]] = [] +# for pc in pat_ccs: +# sz = pc.number_of_nodes() +# cand_cc_idx = [ +# i for i, hc in enumerate(host_ccs) if hc.number_of_nodes() >= sz +# ] +# if not cand_cc_idx: +# return [] # impossible – no room for this component + +# cc_maps: List[Tuple[int, MappingDict]] = [] +# for hi in cand_cc_idx: +# gm = GraphMatcher( +# host_ccs[hi], pc, node_match=node_match, edge_match=edge_match +# ) +# for iso in gm.subgraph_monomorphisms_iter(): +# cc_maps.append((hi, {p: h for h, p in iso.items()})) +# if max_results and len(cc_maps) >= max_results: +# break +# if max_results and len(cc_maps) >= max_results: +# break + +# if not cc_maps: # this pattern-CC embeds nowhere +# return [] +# per_cc.append(cc_maps) + +# # 4) order pattern-CCs by fewest embeddings → best pruning +# order = sorted(range(pcc), key=lambda i: len(per_cc[i])) +# ordered = [per_cc[i] for i in order] + +# # 5) backtrack to build full-pattern mappings +# results: List[MappingDict] = [] +# used_host: Set[int] = set() + +# def backtrack(level: int, accum: MappingDict): +# if max_results and len(results) >= max_results: +# return +# if level == pcc: +# results.append(accum.copy()) +# return +# for hi, mapping in ordered[level]: +# if hi in used_host or any(p in accum for p in mapping): +# continue +# used_host.add(hi) +# accum.update(mapping) +# backtrack(level + 1, accum) +# for p in mapping: +# accum.pop(p) +# used_host.remove(hi) +# if max_results and len(results) >= max_results: +# return + +# 
backtrack(0, {}) +# return results + +# # ------------------------------------------------------------------ +# # Strategy: BACKTRACK – component first, fallback to VF2 ----------- +# # ------------------------------------------------------------------ +# @staticmethod +# def _find_bt_subgraph_mappings( +# host: nx.Graph, +# pattern: nx.Graph, +# *, +# node_attrs: List[str], +# edge_attrs: List[str], +# max_results: Optional[int] = None, +# strict_cc_count: bool = False, +# ) -> List[MappingDict]: +# """Component-aware search *with* classic fallback if any CC fails.""" + +# primary = SubgraphSearchEngine._find_component_aware_subgraph_mappings( +# host, +# pattern, +# node_attrs=node_attrs, +# edge_attrs=edge_attrs, +# max_results=max_results, +# strict_cc_count=strict_cc_count, +# ) +# if primary: +# return primary +# return SubgraphSearchEngine._find_all_subgraph_mappings( +# host, +# pattern, +# node_attrs=node_attrs, +# edge_attrs=edge_attrs, +# max_results=max_results, +# ) + +# # ------------------------------------------------------------------ +# # Duplicate suppression via automorphisms -------------------------- +# # ------------------------------------------------------------------ +# @staticmethod +# def _filter_automorphic_duplicates( +# pattern: nx.Graph, +# mappings: List[MappingDict], +# *, +# node_match, +# edge_match, +# ) -> List[MappingDict]: +# """Throw away mappings that differ only by a pattern automorphism.""" + +# # 1) compute automorphism group of *pattern* +# gm_self = GraphMatcher( +# pattern, pattern, node_match=node_match, edge_match=edge_match +# ) +# autos = list(gm_self.subgraph_isomorphisms_iter()) +# if len(autos) <= 1: +# return mappings # no non-trivial symmetry + +# sorted_nodes = tuple(sorted(pattern.nodes())) +# seen: Set[Tuple[int, ...]] = set() +# unique: List[MappingDict] = [] + +# # 2) canonical fingerprint for each mapping +# for m in mappings: +# canon = min( +# tuple(m[phi[n]] for n in sorted_nodes) for phi in autos +# 
class RadWC:
    """
    Static utility for appending wildcard dummy atoms ([*]) with atom-map
    indices to all radical centers **in the product block** of a reaction SMILES.

    - Reactant and agent blocks are not modified.
    - Only product atoms reporting unpaired (radical) electrons are considered.
    - One ``[*:N]`` is attached per radical electron, each with its own map
      number (auto-assigned after the highest existing map, or starting at a
      user-supplied value).

    Example
    -------
    >>> rxn = '[CH2:1][OH:2]>>[CH:1].[OH:2]'
    >>> new_rxn = RadWC.transform(rxn)  # product radicals gain [*:N] neighbours
    """

    @staticmethod
    def transform(rxn_smiles: str, start_map: Optional[int] = None) -> str:
        """
        Add [*] wildcards (with atom-map index) to every radical in the
        product block of the input reaction SMILES.

        :param rxn_smiles: Reaction SMILES of the form ``R>>P`` or ``R>A>P``.
        :type rxn_smiles: str
        :param start_map: Optional; first atom-map index for wildcards. When
            omitted, numbering starts at the largest ``:N`` found in
            ``rxn_smiles`` plus one.
        :type start_map: int or None
        :returns: Reaction SMILES with the product block rewritten.
        :rtype: str
        :raises ValueError: If the reaction cannot be split or a product
            fragment cannot be parsed.
        """
        react_blk, agents_blk, prod_blk = RadWC._split_reaction(rxn_smiles)

        # Next free atom-map number for the wildcards.
        existing = [int(n) for n in re.findall(r":(\d+)", rxn_smiles)]
        next_map = start_map if start_map is not None else max(existing, default=0) + 1

        # Full sanitization except hydrogen adjustment.
        keep_ops = SanitizeFlags.SANITIZE_ALL & ~SanitizeFlags.SANITIZE_ADJUSTHS

        new_prod_frags = []
        for smi in prod_blk.split(".") if prod_blk else []:
            if not smi:
                continue
            mol = Chem.MolFromSmiles(smi, sanitize=False)
            if mol is None:
                raise ValueError(f"Cannot parse product SMILES fragment: {smi}")
            Chem.SanitizeMol(mol, sanitizeOps=keep_ops)

            rw = Chem.RWMol(mol)
            changed = False
            # Snapshot the atoms first: dummy atoms are appended while looping.
            for atom in list(rw.GetAtoms()):
                for _ in range(atom.GetNumRadicalElectrons()):
                    dummy = Chem.Atom(0)
                    dummy.SetAtomMapNum(next_map)
                    dummy.SetNoImplicit(True)
                    rw.AddAtom(dummy)
                    rw.AddBond(
                        atom.GetIdx(), rw.GetNumAtoms() - 1, Chem.BondType.SINGLE
                    )
                    next_map += 1
                    changed = True
            if changed:
                # NOTE(review): GetMol() appears to return a copy, so this call
                # validates the edited molecule but its result is not what gets
                # written below -- confirm whether in-place sanitization was meant.
                Chem.SanitizeMol(rw.GetMol(), sanitizeOps=keep_ops)
            new_prod_frags.append(
                Chem.MolToSmiles(rw.GetMol(), isomericSmiles=True, allHsExplicit=True)
            )

        prod_str = ".".join(new_prod_frags)
        if agents_blk is None:
            return f"{react_blk}>>{prod_str}"
        return f"{react_blk}>{agents_blk}>{prod_str}"

    @staticmethod
    def _split_reaction(rxn: str) -> Tuple[str, Optional[str], str]:
        """
        Split a reaction SMILES into (reactant, agent or None, product).

        ``R>>P`` and ``R>A>P`` are split on ``'>'``; an empty agent block is
        reported as ``None``. A string with a single ``'>'`` is still accepted
        for backward compatibility and read as reactant/product.

        :param rxn: Reaction SMILES string.
        :type rxn: str
        :returns: (reactant, agent, product) tuple (agent may be None).
        :rtype: Tuple[str, Optional[str], str]
        :raises ValueError: If the string does not contain 1 or 2 '>' symbols.
        """
        parts = rxn.split(">")
        if len(parts) == 2:
            return parts[0], None, parts[1]
        if len(parts) == 3:
            # "R>>P" yields an empty middle part: normalize it to None.
            return parts[0], parts[1] or None, parts[2]
        raise ValueError("Reaction SMILES must contain 1 or 2 '>' symbols")

    @staticmethod
    def describe():
        """
        Print a description and minimal example.
        """
        print(RadWC.__doc__)

    def __repr__(self):
        # An empty repr is indistinguishable from "nothing printed" in a REPL.
        return "<RadWC>"
@staticmethod
def add_unique_subgraph_with_wildcards(
    G: nx.Graph,
    H: nx.Graph,
    attributes_defaults: Optional[Dict[str, Any]] = None,
    rebalance: bool = False,
) -> Tuple[nx.Graph, nx.Graph]:
    """
    Copy the reactant-only part of G into H and cap its cut bonds with wildcards.

    Nodes of G whose ``atom_map`` does not occur in H are added to a copy of H
    together with the edges between them. Every bond from such a node to a node
    that *is* shared with H is replaced, in the product copy, by an attribute-less
    edge to a freshly created wildcard node. With ``rebalance=True`` one isolated
    wildcard node is added to a copy of G for every atom map that exists in the
    patched product but not in G.

    :param G: Reactant graph; every node needs an ``atom_map`` attribute.
    :type G: nx.Graph
    :param H: Product graph; every node needs an ``atom_map`` attribute.
    :type H: nx.Graph
    :param attributes_defaults: Attributes given to each wildcard node. The keys
        ``atom_map`` and ``typesGH`` are always set by this method and are ignored
        if present here.
    :type attributes_defaults: dict, optional
    :param rebalance: Whether to add matching wildcard nodes to the reactant copy.
    :type rebalance: bool
    :returns: Tuple (new_G, new_H); the inputs are not modified.
    :rtype: Tuple[nx.Graph, nx.Graph]
    :raises ValueError: If G or H is not a graph, is empty, or has a node
        without ``atom_map``.

    Example
    -------
    >>> r, p = WildCard.from_rsmi('CCO>>CC')
    >>> r2, p2 = WildCard.add_unique_subgraph_with_wildcards(r, p, rebalance=True)
    """
    from copy import deepcopy

    if not isinstance(G, nx.Graph) or not isinstance(H, nx.Graph):
        raise ValueError("G and H must be networkx.Graph instances")
    if G.number_of_nodes() == 0 or H.number_of_nodes() == 0:
        raise ValueError("Both G and H must have at least one node.")
    if not all("atom_map" in d for _, d in G.nodes(data=True)):
        raise ValueError(
            "All reactant nodes must have 'atom_map' attributes for unique subgraph logic."
        )
    if not all("atom_map" in d for _, d in H.nodes(data=True)):
        raise ValueError(
            "All product nodes must have 'atom_map' attributes for unique subgraph logic."
        )

    if attributes_defaults is None:
        attributes_defaults = {
            "element": "*",
            "aromatic": False,
            "hcount": 0,
            "charge": 0,
            "neighbors": [],
        }
    # 'atom_map' and 'typesGH' are passed explicitly for every wildcard; leaving
    # them in the defaults would raise "got multiple values for keyword argument".
    base_attrs = {
        key: value
        for key, value in attributes_defaults.items()
        if key not in ("atom_map", "typesGH")
    }

    def wildcard_attrs(atom_map: Any) -> Dict[str, Any]:
        # Deep-copy so mutable defaults (e.g. the 'neighbors' list) are not
        # shared between wildcard nodes.
        attrs = deepcopy(base_attrs)
        attrs["atom_map"] = atom_map
        attrs["typesGH"] = (("*", False, 0, 0, []), ("*", False, 0, 0, []))
        return attrs

    def next_free_id(graph: nx.Graph) -> int:
        # One past the largest integer node id (1 when there is none).
        return max((n for n in graph.nodes if isinstance(n, int)), default=0) + 1

    G_new = G.copy()
    H_new = H.copy()

    # ---------------------------
    # 1. PATCH PRODUCT SIDE
    # ---------------------------
    react_atom_maps = {d["atom_map"] for _, d in G.nodes(data=True)}
    prod_atom_maps = {d["atom_map"] for _, d in H.nodes(data=True)}
    unique_atom_maps = react_atom_maps - prod_atom_maps

    node_by_map = {d["atom_map"]: n for n, d in G.nodes(data=True)}
    unique_nodes = [node_by_map[a] for a in unique_atom_maps]

    # NOTE(review): nodes are inserted under their reactant node ids; this
    # presumably relies on node ids not clashing between G and H -- confirm.
    G_unique = G.subgraph(unique_nodes)
    H_new.add_nodes_from(G_unique.nodes(data=True))
    H_new.add_edges_from(G_unique.edges(data=True))

    next_id = next_free_id(H_new)
    for n in unique_nodes:
        for nbr in G.neighbors(n):
            if G.nodes[nbr]["atom_map"] in unique_atom_maps:
                continue
            # Bond leaves the unique fragment: cap it with a wildcard.
            H_new.add_node(next_id, **wildcard_attrs(next_id))
            H_new.add_edge(n, next_id)
            next_id += 1

    # ---------------------------
    # 2. REBALANCE REACTANT SIDE
    # ---------------------------
    if rebalance:
        patched_maps = {d["atom_map"] for _, d in H_new.nodes(data=True)}
        react_next_id = next_free_id(G_new)
        for missing_map in patched_maps - react_atom_maps:
            G_new.add_node(react_next_id, **wildcard_attrs(missing_map))
            react_next_id += 1

    return G_new, H_new
class GMLToGraph:
    """
    Parse a GML-like reaction rule into three NetworkX graphs: reactant (left),
    conserved context (context), and product (right). Node ids are kept as
    atom-map indices, original labels are preserved, and placeholder
    constraints from the ``constrainLabelAny`` block are attached to nodes
    and edges.

    Parameters
    ----------
    gml_text : str
        The GML-like reaction rule text, containing 'left', 'context', 'right',
        and optionally 'constrainLabelAny' sections.

    Attributes
    ----------
    graphs : Dict[str, nx.Graph]
        Mapping of section name ('left', 'context', 'right') to its graph.
    placeholder_constraints : Dict[str, List[str]]
        Mapping from placeholder label (e.g. '_X') to its allowed values.
    """

    # Bond label -> numeric bond order; unknown labels map to 0.
    _ORDER_MAP = {"-": 1, ":": 1.5, "=": 2, "#": 3}
    # Element, then an optional charge that is only recognised when a sign is present.
    _NODE_LABEL = re.compile(r"([A-Za-z*_]+?)(?:(\d+)?([+-]))?")
    # label "Name(arg1, arg2)" -> captures the text between the parentheses.
    _CONSTRAINT_LABEL = re.compile(r'label\s+"[^\(]+\(([^)]+)\)"')

    def __init__(self, gml_text: str):
        """
        Store the rule text and create the three empty section graphs.

        Raises
        ------
        ValueError
            If `gml_text` is empty.
        """
        if not gml_text:
            raise ValueError("gml_text must be a non-empty string")
        self.gml_text = gml_text
        self.graphs: Dict[str, nx.Graph] = {
            sec: nx.Graph() for sec in ("left", "context", "right")
        }
        self.placeholder_constraints: Dict[str, List[str]] = {}

    def _parse_element(self, line: str, sec: str) -> None:
        """
        Parse one whitespace-tokenised 'node' or 'edge' line into graph `sec`.

        Parameters
        ----------
        line : str
            E.g. 'node [ id 1 label "C" ]' or
            'edge [ source 1 target 2 label "-" ]'.
        sec : str
            Target section: 'left', 'context', or 'right'.

        Raises
        ------
        ValueError
            If `sec` is unknown or a required keyword is missing from `line`.
        """
        if sec not in self.graphs:
            raise ValueError(f"Unknown section: {sec}")
        tokens = line.split()

        def value_after(keyword: str) -> str:
            return tokens[tokens.index(keyword) + 1]

        if tokens[0] == "node":
            nid = int(value_after("id"))
            raw_label = value_after("label").strip('"')
            element, charge = raw_label, 0
            m = self._NODE_LABEL.fullmatch(raw_label)
            if m:
                element = m.group(1)
                if m.group(3):
                    # Digits count as a charge magnitude only next to a sign,
                    # so labels such as "_R1" keep their digits and stay neutral.
                    magnitude = int(m.group(2)) if m.group(2) else 1
                    charge = -magnitude if m.group(3) == "-" else magnitude
            self.graphs[sec].add_node(
                nid,
                element=element,
                charge=charge,
                atom_map=nid,
                hcount=0,
                label=raw_label,
            )
        elif tokens[0] == "edge":
            src = int(value_after("source"))
            tgt = int(value_after("target"))
            lbl = value_after("label").strip('"')
            self.graphs[sec].add_edge(
                src, tgt, order=self._ORDER_MAP.get(lbl, 0), label=lbl
            )

    def _synchronize(self) -> None:
        """
        Copy every context node and edge into both 'left' and 'right'.

        Existing nodes have their attributes updated from the context;
        existing edges are left untouched.
        """
        ctx = self.graphs["context"]
        for sec in ("left", "right"):
            target = self.graphs[sec]
            for n, data in ctx.nodes(data=True):
                if n in target:
                    target.nodes[n].update(data)
                else:
                    target.add_node(n, **data)
            for u, v, data in ctx.edges(data=True):
                if not target.has_edge(u, v):
                    target.add_edge(u, v, **data)

    def _parse_constraints(self, lines: List[str], idx: int) -> int:
        """
        Parse the body of a 'constrainLabelAny' block.

        A header with several comma-separated placeholders is read as a table
        (each row supplies one value per placeholder); a header with a single
        placeholder collects the raw argument text of every row.

        Parameters
        ----------
        lines : List[str]
            All lines of the GML text.
        idx : int
            Index of the first line inside the block's '['.

        Returns
        -------
        int
            Index of the block's closing ']' line, or ``len(lines)`` if the
            text ends first.
        """
        pattern = self._CONSTRAINT_LABEL
        while idx < len(lines):
            line = lines[idx].strip()
            if line.startswith("]"):
                return idx
            m = pattern.match(line)
            if not m:
                idx += 1
                continue
            placeholders = [p.strip() for p in m.group(1).split(",")]
            # Skip ahead to the 'labels [' line that opens the rows.
            idx += 1
            while idx < len(lines) and "labels [" not in lines[idx]:
                idx += 1
            # Collect rows up to and including the first line containing ']'.
            rows: List[str] = []
            while idx < len(lines):
                row_line = lines[idx].strip()
                rows.extend(pattern.findall(row_line))
                if "]" in row_line:
                    break
                idx += 1
            if len(placeholders) > 1:
                table: List[Tuple[str, ...]] = [
                    tuple(val.strip() for val in row.split(",")) for row in rows
                ]
                for j, ph in enumerate(placeholders):
                    self.placeholder_constraints[ph] = [row[j] for row in table]
            else:
                self.placeholder_constraints[placeholders[0]] = rows
            idx += 1
        return idx

    def _attach_constraints(self) -> None:
        """
        Attach parsed constraints to nodes ('constraint') and edges
        ('bond_constraint') whose label is a known placeholder, and store the
        whole mapping in the context graph's ``.graph`` metadata.
        """
        pc = self.placeholder_constraints
        for graph in self.graphs.values():
            for _, data in graph.nodes(data=True):
                if data.get("label") in pc:
                    data["constraint"] = pc[data["label"]]
            for _, _, data in graph.edges(data=True):
                if data.get("label") in pc:
                    data["bond_constraint"] = pc[data["label"]]
        self.graphs["context"].graph["placeholder_constraints"] = pc

    def transform(self) -> Tuple[nx.Graph, nx.Graph, nx.Graph]:
        """
        Parse the GML text and build the three graphs.

        Returns
        -------
        Tuple[nx.Graph, nx.Graph, nx.Graph]
            ``(left_graph, right_graph, context_graph)`` -- note the order.
        """
        lines = self.gml_text.splitlines()
        section = ""
        i = 0
        while i < len(lines):
            line = lines[i].strip()
            if line.startswith("constrainLabelAny"):
                # The opening '[' may sit on a later line than the keyword.
                while i < len(lines) and not lines[i].strip().endswith("["):
                    i += 1
                i = self._parse_constraints(lines, i + 1)
            elif line.startswith(("left", "right", "context")):
                section = line.split("[")[0].strip()
            elif line.startswith(("node", "edge")) and section:
                self._parse_element(line, section)
            i += 1
        self._synchronize()
        self._attach_constraints()
        return (self.graphs["left"], self.graphs["right"], self.graphs["context"])

    def __repr__(self) -> str:
        """
        Return a summary with the number of nodes in each graph.

        Returns
        -------
        str
            E.g. ``<GMLToGraph left=3 context=2 right=3>``.
        """
        counts = " ".join(
            f"{sec}={self.graphs[sec].number_of_nodes()}"
            for sec in ("left", "context", "right")
        )
        return f"<GMLToGraph {counts}>"

    def help(self) -> str:
        """
        Return a usage summary for the GMLToGraph parser.

        Returns
        -------
        str
            Multi-line help text explaining the API.
        """
        return (
            "GMLToGraph(gml_text) -> (left, right, context) graphs\n"
            "Supports Nuc-style and Rest-style constraint parsing."
        )
def compute(self) -> None:
    """
    Compute the conserved context and the changing nodes/edges, then collect
    placeholder constraints from the context graph.

    A node is conserved when it exists in both graphs with equal attributes
    (ignoring 'atom_map'); an edge is conserved when both endpoints are
    conserved and it exists in both graphs with equal attributes. Results are
    stored on ``context_nodes``, ``context_edges``, ``left_nodes``,
    ``right_nodes``, ``left_edges``, ``right_edges``, ``context`` and
    ``constraint_dict``.

    :returns: None
    :rtype: None
    """
    left, right = self.left, self.right

    shared_nodes = set(left.nodes) & set(right.nodes)
    self.context_nodes = {
        n for n in shared_nodes if self.same_node(left.nodes[n], right.nodes[n])
    }

    # Use has_edge() instead of intersecting edge tuples: an undirected edge
    # may be reported as (u, v) by one graph and (v, u) by the other, which a
    # plain tuple-set intersection would miss.
    self.context_edges = {
        tuple(sorted((u, v)))
        for u, v in left.edges
        if u in self.context_nodes
        and v in self.context_nodes
        and right.has_edge(u, v)
        and self.same_edge(left.edges[u, v], right.edges[u, v])
    }

    self.left_edges, left_extra = self.get_changing_edges_and_nodes(
        left, self.context_edges, self.context_nodes
    )
    self.right_edges, right_extra = self.get_changing_edges_and_nodes(
        right, self.context_edges, self.context_nodes
    )
    self.left_nodes = (set(left.nodes) - self.context_nodes) | left_extra
    self.right_nodes = (set(right.nodes) - self.context_nodes) | right_extra

    # Context subgraph carries the reactant-side attributes.
    self.context = nx.Graph()
    for n in sorted(self.context_nodes):
        self.context.add_node(n, **left.nodes[n])
    for u, v in self.context_edges:
        self.context.add_edge(u, v, **left.edges[u, v])

    self.constraint_dict.clear()
    for _, d in self.context.nodes(data=True):
        if "constraint" in d:
            self.constraint_dict[d["label"]] = d["constraint"]
    for _, _, d in self.context.edges(data=True):
        if "bond_constraint" in d:
            self.constraint_dict[d["label"]] = d["bond_constraint"]
+ :type context_edges: Set[Tuple[int, int]] + :param context_nodes: Nodes part of context. + :type context_nodes: Set[int] + :returns: A tuple of (changed_edges, changed_nodes). + :rtype: Tuple[List[Tuple[int, int, Dict[str, Any]]], Set[int]] + """ + changed_edges: List[Tuple[int, int, Dict[str, Any]]] = [] + changed_nodes: Set[int] = set() + for u, v in G.edges(): + key = tuple(sorted((u, v))) + if key not in context_edges: + changed_edges.append((u, v, G.edges[u, v])) + if u not in context_nodes: + changed_nodes.add(u) + if v not in context_nodes: + changed_nodes.add(v) + return changed_edges, changed_nodes + + @staticmethod + def graph_section( + name: str, + G: nx.Graph, + nodes: Set[int], + edges: List[Tuple[int, int, Dict[str, Any]]], + ) -> List[str]: + """ + Render a GML block for a subgraph section. + + :param name: Section name ('left', 'context', or 'right'). + :type name: str + :param G: Graph containing nodes/edges. + :type G: nx.Graph + :param nodes: Node IDs to include. + :type nodes: Set[int] + :param edges: Edges to include (u,v,attr). + :type edges: List[Tuple[int,int,Dict[str,Any]]] + :returns: Lines of GML representing the section. + :rtype: List[str] + """ + lines: List[str] = [f" {name} ["] + for n in sorted(nodes): + d = G.nodes[n] + lbl = d.get("label", d.get("element", str(n))) + lines.append(f' node [ id {n} label "{lbl}" ]') + order_map = {1: "-", 1.5: ":", 2: "=", 3: "#"} + for u, v, d in sorted(edges): + lbl = d.get("label", order_map.get(d.get("order", 1), "-")) + lines.append(f' edge [ source {u} target {v} label "{lbl}" ]') + lines.append(" ]") + return lines + + @staticmethod + def context_section(G: nx.Graph) -> List[str]: + """ + Render the conserved context GML block. + + :param G: Context graph. + :type G: nx.Graph + :returns: Lines of GML for context. 
+ :rtype: List[str] + """ + lines: List[str] = [" context []"] # placeholder, updated in full block + lines = [" context ["] + for n, d in sorted(G.nodes(data=True)): + lbl = d.get("label", d.get("element", str(n))) + lines.append(f' node [ id {n} label "{lbl}" ]') + order_map = {1: "-", 1.5: ":", 2: "=", 3: "#"} + for u, v, d in sorted(G.edges(data=True)): + lbl = d.get("label", order_map.get(d.get("order", 1), "-")) + lines.append(f' edge [ source {u} target {v} label "{lbl}" ]') + lines.append(" ]") + return lines + + def constraints_section(self) -> List[str]: + """ + Render placeholder constraints as one constrainLabelAny block. + + :returns: Lines of GML for constraints. + :rtype: List[str] + """ + lines: List[str] = [] + if not self.constraint_dict: + return lines + lines.append(" constrainLabelAny [") + for ph, children in self.constraint_dict.items(): + lines.append(f' label "Rest({ph})"') + if children: + labels_inner = " ".join(f'label "Rest({c})"' for c in children) + lines.append(f" labels [{labels_inner}]") + else: + lines.append(" labels []") + lines.append(" ]") + return lines + + def to_gml(self) -> str: + """ + Generate the full GML reaction rule string. + + :returns: Complete GML string for the reaction rule. + :rtype: str + """ + self.compute() + out: List[str] = ["rule [", f' ruleID "{self.rule_id}"'] + out += self.graph_section("left", self.left, self.left_nodes, self.left_edges) + out += self.context_section(self.context) + out += self.graph_section( + "right", self.right, self.right_nodes, self.right_edges + ) + out += self.constraints_section() + out.append("]") + return "\n".join(out) + + def __repr__(self) -> str: + """ + Return a summary of the rule converter. + + :returns: Brief description with node counts. + :rtype: str + """ + return ( + f"" + ) + + def help(self) -> str: + """ + Show usage instructions for GraphToGML. + + :returns: Multi-line help text. 
+ :rtype: str + """ + return ( + "GraphToGML(left, right, rule_id='374')\n" + " - left: nx.Graph for reactant state\n" + " - right: nx.Graph for product state\n" + " - rule_id: identifier for the GML rule\n" + "\n" + "Usage:\n" + " g2g = GraphToGML(G_left, G_right, rule_id='374')\n" + " print(g2g.to_gml())\n" + ) diff --git a/synkit/IO/combinatorial/graph_to_smarts.py b/synkit/IO/combinatorial/graph_to_smarts.py new file mode 100644 index 0000000..959ab28 --- /dev/null +++ b/synkit/IO/combinatorial/graph_to_smarts.py @@ -0,0 +1,189 @@ +import networkx as nx +from typing import List, Dict, Any, Set, Optional +from rdkit import Chem +from rdkit.Chem import AllChem + + +class GraphToSMARTS: + """ + Convert NetworkX graphs (with placeholder nodes/constraints) into SMARTS or reaction SMARTS strings. + + :param placeholder_labels: Set of labels recognized as placeholders (e.g., '_R', 'X', 'Y', 'Z'). + :type placeholder_labels: Optional[Set[str]] + :param validate: If True, validate generated SMARTS or reaction SMARTS using RDKit. + :type validate: bool + :raises: None + + :Example: + >>> G = nx.Graph() + >>> G.add_node(1, label='C', constraint=None) + >>> G.add_node(2, label='O', constraint=None) + >>> G.add_edge(1, 2, order=1) + >>> g2s = GraphToSMARTS() + >>> smarts = g2s.graph_to_smarts(G) + >>> isinstance(smarts, str) + True + """ + + def __init__( + self, placeholder_labels: Optional[Set[str]] = None, validate: bool = True + ) -> None: + """ + Initialize the GraphToSMARTS converter. + + :param placeholder_labels: Labels to treat as wildcard placeholders; defaults to {'_R','X','Y','Z'}. + :type placeholder_labels: Optional[Set[str]] + :param validate: Whether to validate SMARTS with RDKit if available. 
def graph_to_smarts(self, G: nx.Graph) -> str:
    """
    Convert a NetworkX graph into a SMARTS string representation.

    Each connected component is written as a depth-first tree rooted at the
    highest-degree non-placeholder node (or the smallest node id when every
    node is a placeholder). Cycles are closed with ring-closure labels, so
    cyclic graphs no longer recurse without end. Components are joined with
    '.'.

    :param G: NetworkX Graph with node attributes:
              - 'label': str atomic label or placeholder
              - 'constraint': Optional[List[str]] allowed element list for placeholders
              and edge attribute:
              - 'order': float bond order (1,1.5,2,3); unknown orders are written as '-'
    :type G: nx.Graph
    :returns: SMARTS string encoding the graph structure.
    :rtype: str
    :raises ValueError: If RDKit fails to parse the generated SMARTS when validate=True.

    :Example:
    >>> G = nx.Graph()
    >>> G.add_node(1, label='C', constraint=None)
    >>> G.add_node(2, label='O', constraint=None)
    >>> G.add_edge(1, 2, order=1)
    >>> GraphToSMARTS().graph_to_smarts(G)
    '[C:1](-[O:2])'
    """
    bond_sym: Dict[float, str] = {1: "-", 1.5: ":", 2: "=", 3: "#"}

    def bond(sub: nx.Graph, a: Any, b: Any) -> str:
        return bond_sym.get(sub[a][b].get("order", 1), "-")

    def bracket(node: Any, data: Dict[str, Any]) -> str:
        core = ",".join(data["constraint"]) if data.get("constraint") else data["label"]
        return f"[{core}:{node}]"

    def choose_root(sub: nx.Graph) -> Any:
        real_nodes = [
            n
            for n, d in sub.nodes(data=True)
            if d.get("label") not in self.placeholder_labels
        ]
        if real_nodes:
            return max(real_nodes, key=sub.degree)
        return min(sub.nodes)

    def render(root: Any, sub: nx.Graph) -> str:
        children: Dict[Any, List[Any]] = {}  # DFS tree: node -> child nodes
        closures: Dict[Any, List[str]] = {}  # node -> ring-closure tokens
        ring_edges: Set[Any] = set()
        seen: Set[Any] = set()

        def explore(node: Any, parent: Any) -> None:
            seen.add(node)
            children[node] = []
            for nbr in sub.neighbors(node):
                if nbr == parent or nbr == node:
                    continue
                if nbr not in seen:
                    children[node].append(nbr)
                    explore(nbr, node)
                    continue
                # Already visited: this edge closes a ring. Label it once and
                # put the same bond+label token on both atoms.
                edge = frozenset((node, nbr))
                if edge in ring_edges:
                    continue
                ring_edges.add(edge)
                number = len(ring_edges)
                label = str(number) if number < 10 else f"%{number}"
                token = bond(sub, node, nbr) + label
                closures.setdefault(nbr, []).append(token)
                closures.setdefault(node, []).append(token)

        def emit(node: Any) -> str:
            text = bracket(node, sub.nodes[node]) + "".join(closures.get(node, ()))
            for child in children[node]:
                text += f"({bond(sub, node, child)}{emit(child)})"
            return text

        explore(root, None)
        return emit(root)

    frags: List[str] = []
    for comp in nx.connected_components(G):
        subgraph = G.subgraph(comp)
        frags.append(render(choose_root(subgraph), subgraph))

    smarts: str = ".".join(frags)

    if self.validate and Chem.MolFromSmarts(smarts) is None:
        raise ValueError("RDKit could not parse generated SMARTS.")

    return smarts
class SMARTSExpander:
    """
    Enumerate all reaction SMARTS obtained by expanding atom-list placeholders
    like [C,N,O,P,S:9], ensuring that each atom-map uses the same element
    everywhere it appears (on both sides of a reaction).

    Only bracket atoms of the form ``[El(,El)*:map]`` are considered; anything
    else (e.g. ``[H+:6]``) is copied through unchanged.

    :param smarts: SMARTS string, possibly containing one or more atom-list placeholders.
    :type smarts: str
    :returns: Expanded SMARTS strings without atom-list placeholders.
    :rtype: List[str]
    :raises ValueError: If no valid expansions exist due to incompatible element lists.

    Example usage::

        >>> SMARTSExpander.expand('[C,N:1]-[O:2]>>[C,N:1]=[O:2]')
        ['[C:1]-[O:2]>>[C:1]=[O:2]', '[N:1]-[O:2]>>[N:1]=[O:2]']
    """

    _PAT = re.compile(r"\[([A-Z][a-z]?(?:,[A-Z][a-z]?)*)(:[0-9]+)\]")

    @staticmethod
    def _extract_map_to_elements(matches: List[re.Match]) -> Dict[str, List[str]]:
        """
        Build a mapping from atom-map to the intersection of allowed elements.

        :param matches: List of regex match objects for placeholders.
        :type matches: List[re.Match]
        :returns: Dictionary mapping ":map" to sorted list of shared elements.
        :rtype: Dict[str, List[str]]
        :raises ValueError: If the element lists of some atom-map do not overlap.
        """
        allowed: Dict[str, set] = {}
        for m in matches:
            elems = set(m.group(1).split(","))
            amap = m.group(2)
            allowed[amap] = allowed[amap] & elems if amap in allowed else elems
        for amap, elems in allowed.items():
            if not elems:
                raise ValueError(f"No overlapping elements for atom-map {amap}")
        return {amap: sorted(elems) for amap, elems in allowed.items()}

    @staticmethod
    def _build_template(
        smarts: str, matches: List[re.Match]
    ) -> Tuple[List[str], List[str]]:
        """
        Split `smarts` into alternating literal text and atom-map markers.

        The returned segment list always has the shape
        ``[text, marker, text, marker, ..., text]``: even indices are literal
        text (possibly empty), odd indices are ":map" markers.

        :param smarts: Original SMARTS string.
        :type smarts: str
        :param matches: Regex matches for placeholders, in order of appearance.
        :type matches: List[re.Match]
        :returns: Tuple of (segments, markers in order of appearance).
        :rtype: Tuple[List[str], List[str]]
        """
        segments: List[str] = []
        placeholder_order: List[str] = []
        last = 0
        for m in matches:
            segments.append(smarts[last:m.start()])
            amap = m.group(2)
            segments.append(amap)
            placeholder_order.append(amap)
            last = m.end()
        segments.append(smarts[last:])
        return segments, placeholder_order

    @classmethod
    def expand_iter(cls, smarts: str) -> Iterator[str]:
        """
        Yield expanded SMARTS strings lazily.

        :param smarts: SMARTS string with placeholders.
        :type smarts: str
        :yields: One expanded SMARTS string at a time.
        :rtype: Iterator[str]

        :raises ValueError: If no valid expansions due to incompatible lists
            (raised when the first item is requested).
        """
        matches = list(cls._PAT.finditer(smarts))
        if not matches:
            yield smarts
            return

        amap2els = cls._extract_map_to_elements(matches)
        segments, order = cls._build_template(smarts, matches)
        unique_maps = list(dict.fromkeys(order))

        pools = [amap2els[am] for am in unique_maps]
        for combo in itertools.product(*pools):
            chosen = dict(zip(unique_maps, combo))
            # Markers are identified by position (odd index), never by value:
            # literal text such as ":1" (aromatic ring closure) can look
            # exactly like a marker and must be copied through untouched.
            yield "".join(
                f"[{chosen[seg]}{seg}]" if i % 2 else seg
                for i, seg in enumerate(segments)
            )

    @classmethod
    def expand(cls, smarts: str) -> List[str]:
        """
        Return a list of all expanded SMARTS.

        :param smarts: SMARTS string with placeholders.
        :type smarts: str
        :returns: List of expanded SMARTS strings.
        :rtype: List[str]

        :raises ValueError: If no valid expansions exist.
        """
        return list(cls.expand_iter(smarts))
class SMARTSGeneralizer:
    """
    Generalizes a list of atom-mapped (reaction) SMARTS into one combinatorial SMARTS
    with element-list placeholders at mapped atom positions.
    Optionally validates output using RDKit.

    Atoms are aligned by their position in the string, so all inputs must list
    the same mapped atoms in the same order; the text between atoms is taken
    from the first input.

    :param sanity_check: If True, validate the output SMARTS with RDKit.
    :type sanity_check: bool

    Example
    -------
    >>> input_smarts = [
    ...     '[C:1]-[N:2]>>[N:1]-[C:2]',
    ...     '[N:1]-[N:2]>>[N:1]-[N:2]',
    ...     '[O:1]-[N:2]>>[N:1]-[N:2]'
    ... ]
    >>> gen = SMARTSGeneralizer(sanity_check=False)
    >>> print(gen.generalize(input_smarts))
    [C,N,O:1]-[N:2]>>[N:1]-[C,N:2]
    """

    _atom_pat = re.compile(r"\[([A-Z][a-z]?):(\d+)\]")

    def __init__(self, sanity_check: bool = True):
        """
        Initialize SMARTSGeneralizer.

        :param sanity_check: If True, validate the output SMARTS with RDKit.
        :type sanity_check: bool
        """
        self.sanity_check = sanity_check

    def _merge(self, smarts_list: List[str]) -> str:
        """
        Merge several SMARTS position by position into one element-list SMARTS.

        :param smarts_list: Two or more atom-mapped SMARTS strings.
        :type smarts_list: list[str]
        :return: Combined SMARTS built on the first input's text.
        :rtype: str
        :raises ValueError: If atom counts or map numbers differ between inputs.
        """
        per_input: List[List[re.Match]] = [
            list(self._atom_pat.finditer(s)) for s in smarts_list
        ]
        n_atoms = len(per_input[0])
        if any(len(found) != n_atoms for found in per_input):
            raise ValueError(
                "All input SMARTS must have same atom-mapped topology and order."
            )

        template = smarts_list[0]
        pieces: List[str] = []
        cursor = 0
        # zip(*) yields, for every atom position, that atom's match in each input.
        for column in zip(*per_input):
            anchor = column[0]
            mapnum = anchor.group(2)
            if any(m.group(2) != mapnum for m in column):
                raise ValueError(
                    "All input SMARTS must have same atom-mapped topology and order."
                )
            elements: Set[str] = {m.group(1) for m in column}
            # Splice by match span rather than comparing text, so literal
            # digits between atoms (ring closures) are never mistaken for a
            # map number.
            pieces.append(template[cursor:anchor.start()])
            pieces.append(f"[{','.join(sorted(elements))}:{mapnum}]")
            cursor = anchor.end()
        pieces.append(template[cursor:])
        return "".join(pieces)

    def generalize(self, smarts_list: List[str]) -> str:
        """
        Generalize a list of SMARTS/reaction SMARTS into one combinatorial SMARTS,
        with element-list placeholders per atom-map index and position.

        A single-element list is returned as is (after optional validation).

        :param smarts_list: List of atom-mapped SMARTS strings (same topology/order).
        :type smarts_list: list[str]
        :return: Generalized SMARTS with atom-list placeholders.
        :rtype: str
        :raises ValueError: If input list is empty, topology is inconsistent, or output is invalid.
        """
        if not smarts_list:
            raise ValueError("Input list is empty.")

        if len(smarts_list) == 1:
            combined = smarts_list[0]
        else:
            combined = self._merge(smarts_list)

        # RDKit validation
        if self.sanity_check:
            if ">>" in combined:
                rxn = rdChemReactions.ReactionFromSmarts(combined)
                if rxn is None or rxn.GetNumProductTemplates() == 0:
                    raise ValueError(f"Invalid reaction SMARTS generated: {combined}")
            else:
                mol = Chem.MolFromSmarts(combined)
                if mol is None:
                    raise ValueError(f"Invalid molecule SMARTS generated: {combined}")

        return combined

    def describe(self) -> None:
        """
        Print usage instructions and an example.

        :return: None
        """
        print(
            "SMARTSGeneralizer: Generalize a list of atom-mapped (reaction) SMARTS into a single "
            "SMARTS with element-list placeholders at each mapped atom position.\n"
            "Usage example:\n"
            " >>> gen = SMARTSGeneralizer(sanity_check=True)\n"
            " >>> smarts_list = ['[C:1]-[N:2]', '[O:1]-[N:2]', '[N:1]-[N:2]']\n"
            " >>> print(gen.generalize(smarts_list)) # [C,N,O:1]-[N:2]"
        )

    def __repr__(self) -> str:
        """
        String representation.

        :return: Description with current sanity_check setting.
        :rtype: str
        """
        return f"SMARTSGeneralizer(sanity_check={self.sanity_check})"
+ :rtype: int + :raises: Exception if RDKit property cache update fails. + + :Example: + >>> from rdkit import Chem + >>> atom = Chem.MolFromSmiles('C').GetAtomWithIdx(0) + >>> SMARTSToGraph._safe_total_hs(atom) + 3 + """ + try: + atom.UpdatePropertyCache(strict=False) + return int(atom.GetTotalNumHs(includeExplicit=True)) + except Exception: + return 0 + + def smarts_to_graph(self, smarts: str) -> nx.Graph: + """ + Parse a SMARTS string into a NetworkX graph representation, extracting wildcard constraints. + + :param smarts: SMARTS pattern to convert (e.g., '[C:1]C[O:2]'). + :type smarts: str + :returns: NetworkX Graph with: + - node attributes: element (str), charge (int), hcount (int), + label (str), constraint (Optional[List[str]]), atom_map (int) + - edge attributes: order (float), standard_order (float) + :rtype: nx.Graph + :raises ImportError: If RDKit is not available. + :raises ValueError: If the SMARTS string is invalid or atoms lack mapping numbers. + + :Example: + >>> stg = SMARTSToGraph() + >>> graph = stg.smarts_to_graph('[CH3:1]-[OH:2]') + >>> graph.nodes[1]['element'] + 'C' + >>> graph.nodes[2]['element'] + 'O' + """ + if Chem is None: + raise ImportError("RDKit is required for SMARTS parsing.") + + # Pre-scan SMARTS for wildcard constraint lists (e.g., [C,N:5]) + constraint_map: Dict[int, List[str]] = {} + for match in re.finditer(r"\[([^:\]]+?):(\d+)\]", smarts): + atom_expr, atom_idx = match.groups() + idx = int(atom_idx) + if "," in atom_expr: + constraint_map[idx] = [s.strip() for s in atom_expr.split(",")] + + mol = Chem.MolFromSmarts(smarts) + if mol is None: + raise ValueError(f"Invalid SMARTS string: {smarts!r}") + + G = nx.Graph() + idx_to_map: Dict[int, int] = {} + + # Add nodes with full atom data + for atom in mol.GetAtoms(): + amap = atom.GetAtomMapNum() + if amap == 0: + raise ValueError( + "All atoms in SMARTS must have a mapping number (atom map)." 
+ ) + idx_to_map[atom.GetIdx()] = amap + + # Determine element, label, and constraints + raw_label = atom.GetSymbol() + if amap in constraint_map: + constraint = constraint_map[amap] + label = next(iter(self.placeholder_labels)) + element = "*" + else: + constraint = None + label = raw_label + element = "*" if raw_label in self.placeholder_labels else raw_label + + charge = atom.GetFormalCharge() + hcount = self._safe_total_hs(atom) + + G.add_node( + amap, + element=element, + charge=charge, + hcount=hcount, + label=label, + constraint=constraint, + atom_map=amap, + ) + + # Add edges with bond order data + for bond in mol.GetBonds(): + u = idx_to_map[bond.GetBeginAtomIdx()] + v = idx_to_map[bond.GetEndAtomIdx()] + order = bond.GetBondTypeAsDouble() + G.add_edge(u, v, order=order, standard_order=order) + + return G + + def rxn_smarts_to_graphs(self, rxn: str) -> Tuple[nx.Graph, nx.Graph]: + """ + Split a reaction SMARTS into separate reactant and product graphs. + + :param rxn: Reaction SMARTS in the format 'reactants>>products'. + :type rxn: str + :returns: Tuple of (reactant_graph, product_graph). + :rtype: Tuple[nx.Graph, nx.Graph] + :raises ValueError: If the reaction SMARTS string does not contain '>>'. + + :Example: + >>> stg = SMARTSToGraph() + >>> react, prod = stg.rxn_smarts_to_graphs('[CH3:1]-[OH:2]>>[CH2:1]=[O:2]') + >>> react.nodes + [1, 2] + >>> prod.nodes + [1, 2] + """ + if ">>" not in rxn: + raise ValueError("Reaction SMARTS must contain '>>' separator.") + lhs, rhs = rxn.split(">>", 1) + return self.smarts_to_graph(lhs), self.smarts_to_graph(rhs) + + def __repr__(self) -> str: + """ + Return an unambiguous string representation of this converter. + + :returns: Representation of the instance showing placeholder labels. + :rtype: str + """ + return f"" + + def describe(self) -> str: + """ + Provide a usage summary for SMARTSToGraph. + + :returns: Multi-line string explaining available methods and usage. 
+ :rtype: str + + :Example: + >>> print(SMARTSToGraph().describe()) + SMARTSToGraph(placeholder_labels=None) + smarts_to_graph(smarts_str) -> Graph + rxn_smarts_to_graphs(rxn_smarts) -> (Graph_react, Graph_prod) + """ + return ( + "SMARTSToGraph(placeholder_labels=None)\n" + " smarts_to_graph(smarts_str) -> Graph\n" + " rxn_smarts_to_graphs(rxn_smarts) -> (Graph_react, Graph_prod)\n" + ) diff --git a/synkit/Rule/Apply/rule_matcher.py b/synkit/Rule/Apply/rule_matcher.py index 4e9b88f..b587fd9 100644 --- a/synkit/Rule/Apply/rule_matcher.py +++ b/synkit/Rule/Apply/rule_matcher.py @@ -17,10 +17,10 @@ >>> smarts, rule = matcher.get_result() """ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import networkx as nx -from synkit.IO import rsmi_to_graph +from synkit.IO import rsmi_to_graph, rsmi_to_its from synkit.Chem.Reaction.standardize import Standardize from synkit.Chem.Reaction.balance_check import BalanceReactionCheck from synkit.Synthesis.Reactor.syn_reactor import SynReactor @@ -54,13 +54,15 @@ class RuleMatcher: :vartype result: Tuple[str, nx.Graph] """ - def __init__(self, rsmi: str, rule: nx.Graph) -> None: + def __init__( + self, rsmi: str, rule: Union[str, nx.Graph], explicit_h: bool = True + ) -> None: """Initialize the matcher by standardizing the RSMI, building graphs, checking balance, and computing the match. :param rsmi: Reaction SMILES in 'reactant>>product' format. :type rsmi: str - :param rule: Transformation‑rule graph. + :param rule: Transformation-rule graph. :type rule: nx.Graph :raises ValueError: If no SMARTS reproduces the RSMI under the given rule. 
@@ -68,7 +70,10 @@ def __init__(self, rsmi: str, rule: nx.Graph) -> None: self.std = Standardize() self.rsmi = self.std.fit(rsmi) self.r_graph, self.p_graph = rsmi_to_graph(self.rsmi, drop_non_aam=False) + if isinstance(rule, str): + rule = rsmi_to_its(rule, core=True) self.rule = rule + self.explicit_h = explicit_h self.balanced = BalanceReactionCheck(n_jobs=1).rsmi_balance_check(self.rsmi) # Compute and store the match result @@ -124,7 +129,12 @@ def _match_reverse(self) -> Optional[Tuple[str, nx.Graph]]: return smarts, self.rule # Reactant‑side with inverted template - reactor = SynReactor(substrate=self.p_graph, template=self.rule, invert=True) + reactor = SynReactor( + substrate=self.p_graph, + template=self.rule, + invert=True, + explicit_h=self.explicit_h, + ) for smarts in reactor.smarts_list: std_r = self.std.fit(smarts) if self.all_in( diff --git a/synkit/Synthesis/Reactor/imba_engine.py b/synkit/Synthesis/Reactor/imba_engine.py new file mode 100644 index 0000000..310b1f6 --- /dev/null +++ b/synkit/Synthesis/Reactor/imba_engine.py @@ -0,0 +1,165 @@ +import networkx as nx +from typing import Union, Optional, List +from synkit.Graph.Canon.canon_graph import GraphCanonicaliser +from synkit.Synthesis.Reactor.syn_reactor import SynReactor, Strategy +from synkit.Graph.syn_graph import SynGraph +from synkit.Rule.syn_rule import SynRule +from synkit.Graph.Wildcard.radwc import RadWC +from synkit.Chem.Reaction.radical_wildcard import clean_wc + + +class ImbaEngine: + """ + Reactor for applying a SynKit reaction template to a substrate, with + options for inversion, canonicalisation, strategy, partial ITS, and + radical wildcard appending and fragment cleaning in products. + + :param substrate: Input substrate; SMILES string, networkx.Graph, or SynGraph. + :type substrate: Union[str, nx.Graph, SynGraph] + :param template: Reaction template; SMARTS (bracketed) string, networkx.Graph, or SynRule. 
+ :type template: Union[str, nx.Graph, SynRule] + :param add_wildcard: If True, apply radical wildcard transform to each product SMARTS. + :type add_wildcard: bool + :param clean_fragments: If True, remove wildcard fragments and optionally keep max fragment. + :type clean_fragments: bool + :param max_frag: If True, force maximal fragment selection when cleaning. + :type max_frag: bool + :param invert: If True, apply the template in reverse (product → reactant). + :type invert: bool + :param canonicaliser: Optional GraphCanonicaliser for preprocessing or postprocessing. + :type canonicaliser: Optional[GraphCanonicaliser] + :param strategy: Enumeration strategy (Strategy enum or string). + :type strategy: Union[Strategy, str] + :param partial: If True, perform partial ITS graph construction on results. + :type partial: bool + """ + + def __init__( + self, + substrate: Union[str, nx.Graph, SynGraph], + template: Union[str, nx.Graph, SynRule], + add_wildcard: bool = True, + clean_fragments: bool = False, + max_frag: bool = False, + invert: bool = False, + canonicaliser: Optional[GraphCanonicaliser] = None, + strategy: Union[Strategy, str] = Strategy.ALL, + partial: bool = False, + ) -> None: + # Assign parameters + self.substrate = substrate + self.template = template + self.add_wildcard = add_wildcard + self.clean_fragments = clean_fragments + self.max_frag = max_frag + self.invert = invert + self.canonicaliser = canonicaliser + self.strategy = strategy + self.partial = partial + self._results: List[str] = [] + # Auto-run fit on init + self.fit() + + def __repr__(self) -> str: + return ( + f"" + ) + + @staticmethod + def describe() -> None: + """ + Print class documentation and usage examples. + """ + print(ImbaEngine.__doc__) + + def fit(self) -> "ImbaEngine": + """ + Apply the reaction template to the substrate, producing product SMARTS. + Optionally clean wildcard fragments and add radical wildcards. + Results are stored internally and self is returned. 
+ + :returns: self + :rtype: ImbReactor + :raises ValueError: If substrate cannot be parsed or reaction fails. + """ + from synkit.IO import graph_to_smi + + # Determine reactant SMILES + if isinstance(self.substrate, (nx.Graph, SynGraph)): + react_smiles = graph_to_smi(self.substrate) + elif isinstance(self.substrate, str): + react_smiles = self.substrate + else: + raise ValueError(f"Unsupported substrate type: {type(self.substrate)}") + + reactor = SynReactor( + react_smiles, + template=self.template, + invert=self.invert, + strategy=self.strategy, + partial=self.partial, + implicit_temp=True, + explicit_h=False, + ) + raw_smarts: List[str] = reactor.smarts_list + + # Add radical wildcards if requested + if self.add_wildcard: + wc = [] + for s in raw_smarts: + try: + wc.append(RadWC.transform(s)) + except Exception as e: + print(e) + else: + wc = raw_smarts + # Clean fragments if requested + if self.clean_fragments: + self._results = [ + clean_wc(s, invert=False, max_frag=self.max_frag, wild_card=True) + for s in wc + ] + else: + self._results = wc + + return self + + @property + def smarts_list(self) -> List[str]: + """ + Product SMARTS results from the last fit() invocation. + + :returns: List of SMARTS strings. + :rtype: List[str] + """ + return self._results.copy() + + def __len__(self) -> int: + """ + Number of product SMARTS results. + """ + return len(self._results) + + def __getitem__(self, idx: int) -> str: + """ + Get the product SMARTS at index `idx`. + + :param idx: Index of desired SMARTS. + :type idx: int + :returns: SMARTS string at position `idx`. + :rtype: str + :raises IndexError: If idx is out of bounds. + """ + return self._results[idx] + + def to_list(self) -> List[str]: + """ + Return all product SMARTS as a list. + + :returns: List of SMARTS strings. 
+ :rtype: List[str] + """ + return self._results.copy() diff --git a/synkit/Synthesis/Reactor/syn_reactor.py b/synkit/Synthesis/Reactor/syn_reactor.py index 7ba82f8..cd9e0f9 100644 --- a/synkit/Synthesis/Reactor/syn_reactor.py +++ b/synkit/Synthesis/Reactor/syn_reactor.py @@ -80,14 +80,14 @@ class SynReactor: :vartype _graph: Optional[SynGraph] :ivar _rule: Cached SynRule for the template. :vartype _rule: Optional[SynRule] - :ivar _mappings: Cached list of subgraph‐mapping dicts. + :ivar _mappings: Cached list of subgraph-mapping dicts. :vartype _mappings: Optional[List[MappingDict]] :ivar _its: Cached list of ITS graphs. :vartype _its: Optional[List[nx.Graph]] :ivar _smarts: Cached list of SMARTS strings. :vartype _smarts: Optional[List[str]] :ivar _flag_pattern_has_explicit_H: Internal flag indicating - explicit‑H constraints. + explicit-H constraints. :vartype _flag_pattern_has_explicit_H: bool """ diff --git a/synkit/Vis/graph_visualizer.py b/synkit/Vis/graph_visualizer.py index f197c49..4030548 100644 --- a/synkit/Vis/graph_visualizer.py +++ b/synkit/Vis/graph_visualizer.py @@ -122,6 +122,9 @@ def plot_its( font_size: int = 12, og: bool = False, rule: bool = False, + title_font_size: str = 20, + title_font_weight: str = "bold", + title_font_style: str = "italic", ) -> None: # --- original implementation preserved verbatim ------------------ ax.clear() @@ -130,7 +133,12 @@ def plot_its( ax.axis("equal") ax.axis("off") if title: - ax.set_title(title) + ax.set_title( + title, + fontsize=title_font_size, + fontweight=title_font_weight, + fontstyle=title_font_style, + ) if use_edge_color: edge_colors = [ ( diff --git a/synkit/Vis/rule_vis.py b/synkit/Vis/rule_vis.py index 8a630f5..940f75c 100644 --- a/synkit/Vis/rule_vis.py +++ b/synkit/Vis/rule_vis.py @@ -40,7 +40,7 @@ def vis(self, input: Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]], **kwargs): if input.strip().startswith("graph [") or "rule [" in input: return self.mod_vis(input, **kwargs) else: - r, p = 
rsmi_to_graph(input, light_weight=True) + r, p = rsmi_to_graph(input) its = ITSConstruction().ITSGraph(r, p) gml_str = smart_to_gml(input, core=False, sanitize=False) return self.mod_vis(gml_str, **kwargs) @@ -72,7 +72,7 @@ def nx_vis( try: # 1) Parse input if isinstance(input, str): - r, p = rsmi_to_graph(input, light_weight=True, sanitize=sanitize) + r, p = rsmi_to_graph(input, sanitize=sanitize) its = ITSConstruction().ITSGraph(r, p) elif isinstance(input, tuple) and len(input) == 3: r, p, its = input diff --git a/synkit/examples.py b/synkit/examples.py new file mode 100644 index 0000000..f94dd07 --- /dev/null +++ b/synkit/examples.py @@ -0,0 +1,50 @@ +# synkit/examples.py + +from __future__ import annotations + +from functools import lru_cache +from importlib.resources import files, as_file + +from synkit.IO import load_database # adjust if the import path differs + + +def _sanitize_slug(slug: str) -> str: + if ".." in slug or slug.startswith(("/", "\\")): + raise ValueError(f"Invalid example name: {slug!r}") + return slug + + +def list_examples() -> list[str]: + """ + Returns all available example slugs (without extensions), e.g. ['paracetamol', ...] + """ + data_dir = files("synkit").joinpath("Data") + slugs: set[str] = set() + for p in data_dir.iterdir(): + if not p.is_file(): + continue + if p.name.endswith(".json.gz"): + slugs.add(p.name[: -len(".json.gz")]) + elif p.name.endswith(".json"): + slugs.add(p.name[: -len(".json")]) + return sorted(slugs) + + +@lru_cache(maxsize=32) +def load_example(slug: str): + """ + Load an example by slug, preferring compressed (.json.gz) over plain (.json). + Delegates actual parsing to synkit.IO.load_database. 
+ """ + slug = _sanitize_slug(slug) + data_dir = files("synkit").joinpath("Data") + candidates = [f"{slug}.json.gz", f"{slug}.json"] + for name in candidates: + resource = data_dir.joinpath(name) + if resource.is_file(): + # importlib.resources may keep resources inside archives; as_file gives a real path + with as_file(resource) as real_path: + return load_database(real_path) + raise FileNotFoundError( + f"Example '{slug}' not found (looked for {', '.join(candidates)})" + )