From cceb0ff5042b8803ed46a1fe2b55a541ea5140c6 Mon Sep 17 00:00:00 2001 From: Tieu Long Phan <125431507+TieuLongPhan@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:37:42 +0200 Subject: [PATCH 1/5] Refractor (#26) * update refractor * fix reactor * fix lint * prepare bechmark * refractor MODReactor and SynReactor * update crn * refractor cluster, change to matcher, fix code synreactor, now resnow comparable to modreactor * test 3 os * test * test * fix workflow * fix lint * fix win * fix win * fix win again * update smart * add synreactor implicit hydrogen * fix mcsmatcher * refractor visualization * fix conflict rdkit, upgrade to 2025.3.1 * fix lint * move aam_validator to Chem submodule * fix lint * prepare benchmark matcher * change backend rule to mod * prepare doc * add doc * update fih * update graph module doc * update doc * prepare release * . * fix doc * clean doc * fix docstring * fix tutorial * update fig * update explicit_hydrogen for its * prepare release * build doc * fix lint * fix bug in explicit hydrogen for ITS * fix build * fix * fix * fix * fix doc * update nauty canon, rule filters, change benchmark * prepare release * fix bug in nauty alg * update doc * add features for expanding its * add rule_matcher.py * add testcase rule matcher * add wildcard for smiles * add partial engine * update new features * update document * add data * update Chem features * format docstring and refractor Chem module * add auto-test pypi * create dependabot * test run yml * test * test docker * add docker * add docker * . 
* add readme * rename * release docker * remove redundant file --- .github/dependabot.yml | 10 + .github/workflows/docker-publish.yml | 46 + .github/workflows/verify-pypi-install.yml | 59 ++ .gitignore | 1 + Dockerfile | 45 + Makefile | 20 - README.md | 34 +- Test/Chem/Fingerprint/test_fp_calculator.py | 98 ++- .../Fingerprint/test_transformation_fp.py | 8 +- Test/Chem/Reaction/test_aam_utils.py | 52 -- Test/Chem/Reaction/test_cleanning.py | 4 +- .../test_rsmi_utils.py => test_utils.py} | 49 +- doc/api.rst | 16 + doc/conf.py | 9 +- doc/getting_started.rst | 34 + make.bat | 35 - pyproject.toml | 3 +- synkit/Chem/Cluster/__init__.py | 0 synkit/Chem/Cluster/butina.py | 139 +++ synkit/Chem/Fingerprint/fp_calculator.py | 231 +++-- synkit/Chem/Fingerprint/smiles_featurizer.py | 295 ++++--- synkit/Chem/Fingerprint/transformation_fp.py | 158 ++-- synkit/Chem/Molecule/standardize.py | 197 ++--- synkit/Chem/Reaction/__init__.py | 7 +- synkit/Chem/Reaction/aam_utils.py | 97 --- synkit/Chem/Reaction/aam_validator.py | 53 +- synkit/Chem/Reaction/balance_check.py | 190 ++-- synkit/Chem/Reaction/canon_rsmi.py | 26 +- synkit/Chem/Reaction/cleaning.py | 66 ++ synkit/Chem/Reaction/cleanning.py | 67 -- synkit/Chem/Reaction/deionize.py | 342 +++----- synkit/Chem/Reaction/fix_aam.py | 105 +-- synkit/Chem/Reaction/neutralize.py | 325 +++---- synkit/Chem/Reaction/radical_wildcard.py | 45 +- synkit/Chem/Reaction/rsmi_utils.py | 126 --- synkit/Chem/Reaction/standardize.py | 90 +- synkit/Chem/Reaction/tautomerize.py | 164 ++-- synkit/Chem/utils.py | 363 +++++--- synkit/Data/gen_partial_aam.py | 13 +- synkit/Graph/Canon/canon_algs.py | 21 +- synkit/Graph/Canon/canon_graph.py | 34 +- synkit/Graph/Canon/nauty.py | 34 +- synkit/Graph/Context/hier_context.py | 42 +- synkit/Graph/Context/radius_expand.py | 50 +- synkit/Graph/Feature/graph_descriptors.py | 46 +- synkit/Graph/Feature/graph_fps.py | 14 +- synkit/Graph/Feature/graph_signature.py | 39 +- synkit/Graph/Feature/hash_fps.py | 24 +- 
synkit/Graph/Feature/morgan_fps.py | 18 +- synkit/Graph/Feature/path_fps.py | 15 +- synkit/Graph/Feature/wl_hash.py | 22 +- synkit/Graph/Hyrogen/_misc.py | 48 +- synkit/Graph/Hyrogen/hcomplete.py | 30 +- synkit/Graph/Hyrogen/hextend.py | 17 +- synkit/Graph/ITS/its_builder.py | 20 +- synkit/Graph/ITS/its_construction.py | 33 +- synkit/Graph/ITS/its_decompose.py | 29 +- synkit/Graph/ITS/its_expand.py | 24 +- synkit/Graph/ITS/its_relabel.py | 37 +- synkit/Graph/ITS/normalize_aam.py | 33 +- synkit/Graph/MTG/group_comp.py | 3 +- synkit/Graph/MTG/groupoid.py | 7 +- synkit/Graph/MTG/mcs_matcher.py | 3 +- synkit/Graph/MTG/mtg.py | 3 +- synkit/Graph/Matcher/batch_cluster.py | 24 +- synkit/Graph/Matcher/graph_cluster.py | 21 +- synkit/Graph/Matcher/graph_morphism.py | 808 ++++++++++-------- synkit/Graph/Matcher/multi_turbo_iso.py | 22 +- synkit/Graph/Matcher/sing.py | 22 +- synkit/Graph/Matcher/subgraph_matcher.py | 36 +- synkit/Graph/Matcher/turbo_iso.py | 4 +- synkit/Graph/Wildcard/fuse_graph.py | 12 +- synkit/Graph/canon_graph.py | 34 +- synkit/Graph/syn_graph.py | 31 +- synkit/Graph/utils.py | 16 +- synkit/IO/chem_converter.py | 139 +-- synkit/IO/data_io.py | 109 +-- synkit/IO/data_process.py | 7 +- synkit/IO/debug.py | 22 +- synkit/IO/dg_to_gml.py | 5 +- synkit/IO/gml_to_nx.py | 39 +- synkit/IO/graph_to_mol.py | 61 +- synkit/IO/mol_to_graph.py | 79 +- synkit/IO/nx_to_gml.py | 70 +- synkit/IO/smiles_to_id.py | 11 +- synkit/Rule/Apply/reactor_rule.py | 21 +- synkit/Rule/Apply/retro_reactor.py | 20 +- synkit/Rule/Apply/rule_matcher.py | 40 +- synkit/Rule/Apply/rule_rbl.py | 6 +- synkit/Rule/Compose/compose_rule.py | 34 +- synkit/Rule/Compose/rule_compose.py | 28 +- synkit/Rule/Compose/rule_mapping.py | 29 +- synkit/Rule/Compose/seq_comp.py | 19 +- synkit/Rule/Compose/valence_constrain.py | 14 +- synkit/Rule/Modify/implict_rule.py | 5 +- synkit/Rule/Modify/longest_path.py | 13 +- synkit/Rule/Modify/molecule_rule.py | 23 +- synkit/Rule/Modify/prune_templates.py | 12 +- 
synkit/Rule/Modify/rule_utils.py | 36 +- synkit/Rule/Modify/strip_rule.py | 17 +- synkit/Rule/syn_rule.py | 8 +- synkit/Synthesis/CRN/crn.py | 16 +- synkit/Synthesis/CRN/dcrn.py | 32 +- synkit/Synthesis/CRN/mod_crn.py | 17 +- synkit/Synthesis/MSR/multi_steps.py | 16 +- synkit/Synthesis/MSR/path_finder.py | 40 +- synkit/Synthesis/Metrics/_base.py | 3 +- synkit/Synthesis/Metrics/_plot.py | 11 +- synkit/Synthesis/Metrics/_ranking.py | 41 +- synkit/Synthesis/Reactor/batch_reactor.py | 98 +-- synkit/Synthesis/Reactor/core_engine.py | 212 ----- synkit/Synthesis/Reactor/mod_aam.py | 12 +- synkit/Synthesis/Reactor/mod_reactor.py | 56 +- synkit/Synthesis/Reactor/partial_engine.py | 19 +- synkit/Synthesis/Reactor/rbl_engine.py | 41 +- synkit/Synthesis/Reactor/rule_filter.py | 56 +- synkit/Synthesis/Reactor/single_predictor.py | 24 +- synkit/Synthesis/Reactor/strategy.py | 14 +- synkit/Synthesis/Reactor/syn_reactor.py | 62 +- synkit/Synthesis/reactor_utils.py | 39 +- synkit/Utils/utils.py | 21 +- synkit/Vis/embedding.py | 25 +- synkit/Vis/graph_visualizer.py | 12 +- synkit/Vis/pdf_writer.py | 18 +- synkit/Vis/rule_vis.py | 22 +- synkit/Vis/rxn_vis.py | 12 +- 126 files changed, 3453 insertions(+), 3631 deletions(-) create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/docker-publish.yml create mode 100644 .github/workflows/verify-pypi-install.yml create mode 100644 Dockerfile delete mode 100644 Makefile delete mode 100644 Test/Chem/Reaction/test_aam_utils.py rename Test/Chem/{Reaction/test_rsmi_utils.py => test_utils.py} (57%) delete mode 100644 make.bat create mode 100644 synkit/Chem/Cluster/__init__.py create mode 100644 synkit/Chem/Cluster/butina.py delete mode 100644 synkit/Chem/Reaction/aam_utils.py create mode 100644 synkit/Chem/Reaction/cleaning.py delete mode 100644 synkit/Chem/Reaction/cleanning.py delete mode 100644 synkit/Chem/Reaction/rsmi_utils.py delete mode 100644 synkit/Synthesis/Reactor/core_engine.py diff --git 
a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..c0c0eb1 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +# .github/dependabot.yml +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" # location of requirements.txt or pyproject.toml + target-branch: "staging" # open PRs against staging instead of main + schedule: + interval: "weekly" # check for updates once a week + open-pull-requests-limit: 5 # max concurrent Dependabot PRs + rebase-strategy: "auto" # auto-rebase PRs when they fall out of date diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 0000000..363e899 --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,46 @@ +# .github/workflows/docker-publish.yml +name: Publish SynKit Docker Package + +on: + push: + # Fire on semver tags for real releases… + tags: + - 'v*.*.*' + # …and on any push to the staging branch for testing + branches: + - 'staging' + +jobs: + build-and-push: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + id-token: write + + steps: + - name: Check out repository + uses: actions/checkout@v3 + + - name: Set up QEMU (optional, for multi‑arch) + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Log in to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USER }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push Docker image + uses: docker/build-push-action@v4 + with: + context: .
+ file: Dockerfile + platforms: linux/amd64,linux/arm64 + push: true + tags: | + tieulongphan/synkit:${{ github.ref_name }} + tieulongphan/synkit:latest diff --git a/.github/workflows/verify-pypi-install.yml b/.github/workflows/verify-pypi-install.yml new file mode 100644 index 0000000..55e3e2e --- /dev/null +++ b/.github/workflows/verify-pypi-install.yml @@ -0,0 +1,59 @@ +# .github/workflows/verify-synkit-pypi-install.yml +name: Verify SynKit PyPI install + +on: + workflow_dispatch: + inputs: + branches: + type: string + required: true + default: refractor + + # Scheduled test every Monday at 03:00 UTC + schedule: + - cron: '0 3 * * 1' + +jobs: + verify: + runs-on: ubuntu-latest + + steps: + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Create & activate virtualenv, upgrade pip, install SynKit + run: | + python -m venv venv + source venv/bin/activate + python -m pip install --upgrade pip + pip install synkit[all] + + - name: Show installed SynKit version + run: | + source venv/bin/activate + python -c "import importlib.metadata as m; print('SynKit version:', m.version('synkit'))" + + - name: Write smoke-test script + run: | + cat << 'EOF' > test_synkit.py + from synkit.IO import rsmi_to_rsmarts + + template = ( + '[C:2]=[O:3].[C:4]([H:7])[H:8]' + '>>' + '[C:2]=[C:4].[O:3]([H:7])[H:8]' + ) + + smart = rsmi_to_rsmarts(template) + print("Reaction SMARTS:", smart) + EOF + + - name: Run smoke-test + run: | + source venv/bin/activate + python test_synkit.py + + - name: Success message + run: echo "✅ synkit[all] installed and smoke-test passed" diff --git a/.gitignore b/.gitignore index c85f0f8..07ad7b7 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ Data/Benchmark/* run.sh docs/* run_rdcanon.py +Data/Fragment/* diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..7bfb459 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,45 @@ +############################################ +# STAGE 1: Build your 
package wheel +############################################ +FROM python:3.11-slim AS builder + +# 1. Install system build tools (for any C extensions) +RUN apt-get update \ + && apt-get install -y --no-install-recommends build-essential \ + && rm -rf /var/lib/apt/lists/* + +# 2. Upgrade pip/setuptools/wheel and install PEP 517 tooling + Hatchling backend +RUN pip install --upgrade pip setuptools wheel \ + && pip install --no-cache-dir build hatchling + +# 3. Set working directory inside builder +WORKDIR /build + +# 4. Copy project metadata (including README so Hatchling can find it) +COPY pyproject.toml README.md ./ +# If you have a lockfile, uncomment: +# COPY poetry.lock ./ + +# 5. Copy your library source +COPY synkit/ ./synkit + +# 6. Build the wheel +RUN python -m build --wheel --no-isolation + +############################################ +# STAGE 2: Create the “release” image +############################################ +FROM python:3.11-slim + +# 7. Set a clean workdir +WORKDIR /opt/synkit + +# 8. Copy in the built wheel from the builder stage +COPY --from=builder /build/dist/*.whl ./ + +# 9. Install your package (and its dependencies), then remove the wheel +RUN pip install --no-cache-dir *.whl \ + && rm *.whl + +# 10. Sanity check: print the installed synkit version +CMD ["python", "-c", "import importlib.metadata as m; print(m.version('synkit'))"] diff --git a/Makefile b/Makefile deleted file mode 100644 index 269cadc..0000000 --- a/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". 
-help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/README.md b/README.md index d33bd0d..47d57fe 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,14 @@ Our tools are tailored to assist researchers and chemists in navigating complex For more details on each utility within the repository, please refer to the documentation provided in the respective folders. -## Step-by-Step Installation Guide +## Table of Contents +- [Installation](#installation) +- [Contribute to `SynKit`](#contribute) +- [Contributing](#contributing) +- [License](#license) +- [Acknowledgments](#acknowledgments) + +## Installation 1. **Python Installation:** Ensure that Python 3.11 or later is installed on your system. You can download it from [python.org](https://www.python.org/downloads/). @@ -41,7 +48,20 @@ For more details on each utility within the repository, please refer to the docu pip install synkit[all] ``` -## For contributors +4. **Install via Docker** + Pull the image: + + ```bash + docker pull tieulongphan/synkit:latest + # or a specific version: + docker pull tieulongphan/synkit:0.1.0 + ``` + Run a container (sanity check): + ``` + docker run --rm tieulongphan/synkit:latest + ``` + +## Contribute We're welcoming new contributors to build this project better. Please not hesitate to inquire me via [email][tieu@bioinf.uni-leipzig.de]. @@ -52,7 +72,7 @@ git checkout main git pull ``` -## Working on New Features +### Working on New Features 1. **Create a New Branch**: For every new feature or bug fix, create a new branch from the `main` branch. Name your branch meaningfully, related to the feature or fix you are working on. 
@@ -78,7 +98,7 @@ git pull Fix any issues or errors highlighted by these checks. -## Integrating Changes +### Integrating Changes 1. **Rebase onto Staging**: Once your feature is complete and tests pass, rebase your changes onto the `staging` branch to prepare for integration. @@ -98,16 +118,16 @@ git pull ``` 3. **Create a Pull Request**: - Open a pull request from your feature branch to the `stagging` branch. Ensure the pull request description clearly describes the changes and any additional context necessary for review. + Open a pull request from your feature branch to the `staging` branch. Ensure the pull request description clearly describes the changes and any additional context necessary for review. ## Contributing - [Tieu-Long Phan](https://tieulongphan.github.io/) - [Klaus Weinbauer](https://github.com/klausweinbauer) - [Phuoc-Chung Nguyen Van](https://github.com/phuocchung123) -## Deployment timeline +## Publication -We plan to update new version quarterly. +[**SynKit**: An Advanced Cheminformatics Python Library for Efficient Manipulation and Analysis of Chemical Reaction Data]() ## License diff --git a/Test/Chem/Fingerprint/test_fp_calculator.py b/Test/Chem/Fingerprint/test_fp_calculator.py index c98babc..47bcb3e 100644 --- a/Test/Chem/Fingerprint/test_fp_calculator.py +++ b/Test/Chem/Fingerprint/test_fp_calculator.py @@ -1,58 +1,68 @@ +import io import unittest +from contextlib import redirect_stdout + from synkit.Chem.Fingerprint.fp_calculator import FPCalculator class TestFPCalculator(unittest.TestCase): def setUp(self): - # Sample data setup - self.data = [ - { - "smiles": [ - ( - "C1CCCCC1.CCO.CS(=O)(=O)N1CCN(Cc2ccccc2)CC1.[OH-].[OH-].[Pd+2]" - + ">>CS(=O)(=O)N1CCNCC1" - ), - ( - "CCOC(C)=O.Cc1cc([N+](=O)[O-])ccc1NC(=O)c1ccccc1.Cl[Sn]Cl.O.O.O=C([O-])O.[Na+]" - + ">>Cc1cc(N)ccc1NC(=O)c1ccccc1" - ), - ( - "COc1ccc(-c2coc3ccc(-c4nnc(S)o4)cc23)cc1.COc1ccc(CCl)cc1F" - + ">>COc1ccc(-c2coc3ccc(-c4nnc(SCc5ccc(OC)c(F)c5)o4)cc23)cc1" - ), - ], - "ID": [1, 
2, 3], - } + # Sample single reaction dict + self.single = {"rsmi": "CCO>>CC=O"} + # List of dicts for parallel + self.batch = [ + {"rsmi": "CCO>>CC=O"}, + {"rsmi": "CC(Cl)C>>CCCl"}, ] - self.smiles_key = "smiles" - self.fp_type = "drfp" - self.n_jobs = 2 - self.verbose = 0 - - # Instantiate the FPCalculator - self.fp_calculator = FPCalculator( - smiles_key=self.smiles_key, - fp_type=self.fp_type, - n_jobs=self.n_jobs, - verbose=self.verbose, - ) + self.rsmi_key = "rsmi" + self.fp_type = "ecfp4" + self.calc = FPCalculator(n_jobs=2, verbose=0) + + def test_constructor_assigns_attributes(self): + self.assertEqual(self.calc.n_jobs, 2) + self.assertEqual(self.calc.verbose, 0) - def test_init_invalid_fp_type(self): + def test_validate_fp_type_accepts_supported(self): + # Should not raise + for ft in FPCalculator.VALID_FP_TYPES: + self.calc._validate_fp_type(ft) + + def test_validate_fp_type_rejects_unsupported(self): with self.assertRaises(ValueError): - FPCalculator(smiles_key=self.smiles_key, fp_type="invalid_type") + self.calc._validate_fp_type("invalid_fp") - def test_fit_missing_column(self): + def test_dict_process_missing_key_raises(self): with self.assertRaises(ValueError): - fp_calculator = FPCalculator( - smiles_key=self.smiles_key, - fp_type=self.fp_type, - ) - fp_calculator.dict_process({"not_smiles": ["C"]}, "smiles") - - def test_constructor_and_attribute_assignment(self): - self.assertEqual(self.fp_calculator.smiles_key, "smiles") - self.assertEqual(self.fp_calculator.fp_type, "drfp") - self.assertEqual(self.fp_calculator.n_jobs, 2) + FPCalculator.dict_process({}, self.rsmi_key, fp_type=self.fp_type) + + def test_dict_process_adds_fingerprint(self): + data = {"rsmi": "CCO>>CC=O"} + out = FPCalculator.dict_process(data, "rsmi", fp_type="ecfp4") + self.assertIn("ecfp4", out) + # Check it's a list/vector (not None) + self.assertIsNotNone(out["ecfp4"]) + + def test_parallel_process_returns_list_of_dicts(self): + results = 
self.calc.parallel_process(self.batch, "rsmi", fp_type="ecfp4") + self.assertIsInstance(results, list) + self.assertEqual(len(results), 2) + for d in results: + self.assertIn("ecfp4", d) + + def test_str_and_help_output(self): + s = str(self.calc) + self.assertIn("FPCalculator", s) + buf = io.StringIO() + with redirect_stdout(buf): + self.calc.help() + help_out = buf.getvalue() + + # The help text starts with this exact line + self.assertIn( + "FPCalculator supports the following fingerprint types:", help_out + ) + # And lists our parallel jobs config + self.assertIn(f"Configured for {self.calc.n_jobs} parallel jobs", help_out) if __name__ == "__main__": diff --git a/Test/Chem/Fingerprint/test_transformation_fp.py b/Test/Chem/Fingerprint/test_transformation_fp.py index aad0552..8332c6e 100644 --- a/Test/Chem/Fingerprint/test_transformation_fp.py +++ b/Test/Chem/Fingerprint/test_transformation_fp.py @@ -22,13 +22,13 @@ def test_fit(self): abs_val = True # Test with return_array=True - reaction_fp_array = TransformationFP.fit( + reaction_fp_array = TransformationFP().fit( reaction_smiles, symbols, fp_type, abs_val ) self.assertIsInstance(reaction_fp_array, np.ndarray) # Test with return_array=False - reaction_fp_bitvect = TransformationFP.fit( + reaction_fp_bitvect = TransformationFP().fit( reaction_smiles, symbols, fp_type, abs_val, return_array=False ) self.assertIsInstance(reaction_fp_bitvect, cDataStructs.ExplicitBitVect) @@ -40,7 +40,7 @@ def test_fit_invalid_smiles(self): fp_type = "maccs" abs_val = True with self.assertRaises(Exception): - _ = TransformationFP.fit(reaction_smiles, symbols, fp_type, abs_val) + _ = TransformationFP().fit(reaction_smiles, symbols, fp_type, abs_val) def test_fit_reaction_split(self): """Test handling of SMILES split by symbols and impact on results""" @@ -48,7 +48,7 @@ def test_fit_reaction_split(self): symbols = ">>" fp_type = "maccs" abs_val = False # without taking absolute values - reaction_fp = 
TransformationFP.fit(reaction_smiles, symbols, fp_type, abs_val) + reaction_fp = TransformationFP().fit(reaction_smiles, symbols, fp_type, abs_val) self.assertIsInstance(reaction_fp, np.ndarray) diff --git a/Test/Chem/Reaction/test_aam_utils.py b/Test/Chem/Reaction/test_aam_utils.py deleted file mode 100644 index 1baf9f0..0000000 --- a/Test/Chem/Reaction/test_aam_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -import unittest -from rdkit import Chem -from synkit.Chem.Reaction.aam_utils import enumerate_tautomers, mapping_success_rate - - -class TestChemUtils(unittest.TestCase): - def test_enumerate_tautomers_simple(self): - # A simple keto-enol tautomerism: acetylacetone (CC(=O)CC=O) -> same product - reaction = "CC(=O)CC=O>>O" - tautomers = enumerate_tautomers(reaction) - # Should return a list with at least the original reaction - self.assertIsInstance(tautomers, list) - self.assertIn(reaction, tautomers) - # Each entry should be a valid reaction SMILES - for rsmi in tautomers: - self.assertIsInstance(rsmi, str) - parts = rsmi.split(">>") - self.assertEqual(len(parts), 2) - # Reactant and product part parseable by RDKit - self.assertIsNotNone(Chem.MolFromSmiles(parts[0])) - self.assertIsNotNone(Chem.MolFromSmiles(parts[1])) - - def test_enumerate_tautomers_invalid(self): - # Invalid SMILES input - bad = "INVALID>>SMILES" - result = enumerate_tautomers(bad) - # Should return list with original - self.assertEqual(result, [bad]) - - def test_mapping_success_rate_normal(self): - data = ["C:1CC", "CCC", "O:3=O", ":5", "N"] - rate = mapping_success_rate(data) - # Entries with mapping: 'C:1CC', 'O:3=O', ':5' => 3/5 = 60.0% - self.assertEqual(rate, 60.0) - - def test_mapping_success_rate_empty(self): - with self.assertRaises(ValueError): - mapping_success_rate([]) - - def test_mapping_success_rate_all(self): - data = [":1C", ":2", "N:3"] - rate = mapping_success_rate(data) - self.assertEqual(rate, 100.0) - - def test_mapping_success_rate_none(self): - data = ["C", "O", "N"] - 
rate = mapping_success_rate(data) - self.assertEqual(rate, 0.0) - - -if __name__ == "__main__": - unittest.main() diff --git a/Test/Chem/Reaction/test_cleanning.py b/Test/Chem/Reaction/test_cleanning.py index 0f902ca..d065fe3 100644 --- a/Test/Chem/Reaction/test_cleanning.py +++ b/Test/Chem/Reaction/test_cleanning.py @@ -1,11 +1,11 @@ import unittest -from synkit.Chem.Reaction.cleanning import Cleanning +from synkit.Chem.Reaction.cleaning import Cleaning class TestCleaning(unittest.TestCase): def setUp(self): - self.cleaner = Cleanning() + self.cleaner = Cleaning() def test_remove_duplicates(self): input_smiles = ["CC>>CC", "CC>>CC"] diff --git a/Test/Chem/Reaction/test_rsmi_utils.py b/Test/Chem/test_utils.py similarity index 57% rename from Test/Chem/Reaction/test_rsmi_utils.py rename to Test/Chem/test_utils.py index 2ffd091..98fe9e8 100644 --- a/Test/Chem/Reaction/test_rsmi_utils.py +++ b/Test/Chem/test_utils.py @@ -1,5 +1,8 @@ import unittest -from synkit.Chem.Reaction.rsmi_utils import ( +from rdkit import Chem +from synkit.Chem.utils import ( + enumerate_tautomers, + mapping_success_rate, remove_common_reagents, reverse_reaction, remove_duplicates, @@ -8,7 +11,49 @@ ) -class TestChemicalReactions(unittest.TestCase): +class TestChemUtils(unittest.TestCase): + def test_enumerate_tautomers_simple(self): + # A simple keto-enol tautomerism: acetylacetone (CC(=O)CC=O) -> same product + reaction = "CC(=O)CC=O>>O" + tautomers = enumerate_tautomers(reaction) + # Should return a list with at least the original reaction + self.assertIsInstance(tautomers, list) + self.assertIn(reaction, tautomers) + # Each entry should be a valid reaction SMILES + for rsmi in tautomers: + self.assertIsInstance(rsmi, str) + parts = rsmi.split(">>") + self.assertEqual(len(parts), 2) + # Reactant and product part parseable by RDKit + self.assertIsNotNone(Chem.MolFromSmiles(parts[0])) + self.assertIsNotNone(Chem.MolFromSmiles(parts[1])) + + def test_enumerate_tautomers_invalid(self): + # 
Invalid SMILES input should raise ValueError + bad = "INVALID>>SMILES" + with self.assertRaises(ValueError) as cm: + enumerate_tautomers(bad) + self.assertIn("Invalid reactant or product SMILES", str(cm.exception)) + + def test_mapping_success_rate_normal(self): + data = ["C:1CC", "CCC", "O:3=O", ":5", "N"] + rate = mapping_success_rate(data) + # Entries with mapping: 'C:1CC', 'O:3=O', ':5' => 3/5 = 60.0% + self.assertEqual(rate, 60.0) + + def test_mapping_success_rate_empty(self): + with self.assertRaises(ValueError): + mapping_success_rate([]) + + def test_mapping_success_rate_all(self): + data = [":1C", ":2", "N:3"] + rate = mapping_success_rate(data) + self.assertEqual(rate, 100.0) + + def test_mapping_success_rate_none(self): + data = ["C", "O", "N"] + rate = mapping_success_rate(data) + self.assertEqual(rate, 0.0) def test_remove_common_reagents_no_common(self): reaction = "A.B.C>>D.E.F" diff --git a/doc/api.rst b/doc/api.rst index 7c0af53..1e38a3b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -21,6 +21,22 @@ The `Chem` module provides tools for handling input and output operations relate :undoc-members: :show-inheritance: +.. automodule:: synkit.Chem.Reaction.balance_check + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: synkit.Chem.Fingerprint.fp_calculator + :members: + :undoc-members: + :show-inheritance: + +.. 
automodule:: synkit.Chem.Cluster.butina + :members: + :undoc-members: + :show-inheritance: + + Synthesis Module ================ diff --git a/doc/conf.py b/doc/conf.py index 915c62a..509d687 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,5 +1,6 @@ import os import sys +import importlib.metadata as m from importlib.metadata import version as _get_version, PackageNotFoundError # -- Path setup -------------------------------------------------------------- @@ -15,12 +16,10 @@ release = _get_version("synkit") except PackageNotFoundError: try: - import synkit - - release = synkit.__version__ + release = m.version("synkit") except (ImportError, AttributeError): # Fallback default - release = "0.0.10" + release = "0.0.11" # Use only major.minor for short version version = ".".join(release.split(".")[:2]) @@ -46,4 +45,4 @@ # -- Options for HTML output ------------------------------------------------- html_theme = "sphinx_rtd_theme" -html_static_path = ["_static"] +html_static_path = ["_static"] \ No newline at end of file diff --git a/doc/getting_started.rst b/doc/getting_started.rst index 2ebef8e..d179319 100644 --- a/doc/getting_started.rst +++ b/doc/getting_started.rst @@ -1,5 +1,9 @@ .. _getting-started-synkit: +.. image:: https://img.shields.io/pypi/v/synkit.svg + :alt: PyPI version + :align: right + Getting Started =============== @@ -69,10 +73,40 @@ After installation, verify that **synkit** is available and check its version: python -c "import importlib.metadata as m; print(m.version('synkit'))" # Should print the installed synkit version +Docker Installation +------------------- + +Install **SynKit** using Docker. + +Pull the image: + +.. code-block:: bash + + docker pull tieulongphan/synkit:latest + +Run a quick version check: + +.. code-block:: bash + + docker run --rm tieulongphan/synkit:latest \ + python -c "import importlib.metadata as m; print(m.version('synkit'))" + + +Use as a base image in your own Dockerfile: + +.. 
code-block:: dockerfile + + FROM tieulongphan/synkit:latest + WORKDIR /app + COPY . . + CMD ["python", "your_script.py"] + + Further Resources ----------------- - Official documentation: `SynKit Docs `_ - Tutorials and examples: :doc:`Tutorials and Examples ` +- Support ------- diff --git a/make.bat b/make.bat deleted file mode 100644 index dafd057..0000000 --- a/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.https://www.sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "" goto help - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b74faf1..688cc3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synkit" -version = "0.0.10" +version = "0.0.11" authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] @@ -35,6 +35,7 @@ docs = [ "sphinx-rtd-theme", "sphinxcontrib-bibtex", ] + [project.urls] homepage = "https://github.com/TieuLongPhan/SynKit" source = "https://github.com/TieuLongPhan/SynKit" diff --git a/synkit/Chem/Cluster/__init__.py b/synkit/Chem/Cluster/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synkit/Chem/Cluster/butina.py b/synkit/Chem/Cluster/butina.py new file mode 100644 index 0000000..ef13fb1 --- /dev/null +++ 
b/synkit/Chem/Cluster/butina.py @@ -0,0 +1,139 @@ +from __future__ import annotations +from typing import List, Optional + +import numpy as np +from rdkit.DataStructs import cDataStructs, CreateFromBitString, BulkTanimotoSimilarity +from rdkit.ML.Cluster import Butina +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt + + +class ButinaCluster: + """Cluster chemical fingerprint vectors using the Butina algorithm from + RDKit, with integrated t-SNE visualization of clusters. + + Key features + ------------ + * **Butina clustering** – fast hierarchical clustering with a similarity cutoff. + * **t-SNE visualization** – 2D embedding of fingerprints, highlighting top‑k clusters. + * **NumPy support** – accepts 2D arrays of 0/1 fingerprint data. + * **Configurable** – user‑defined cutoff, perplexity, and top‑k highlight. + + Quick start + ----------- + >>> from synkit.Chem.Fingerprint.fingerprint_clusterer import ButinaCluster + >>> clusters = ButinaCluster.cluster(arr, cutoff=0.3) + >>> ButinaCluster.visualize(arr, clusters, k=5) + """ + + @staticmethod + def cluster(arr: np.ndarray, cutoff: float = 0.2) -> List[List[int]]: + """Perform Butina clustering on fingerprint bit-vectors. + + :param arr: 2D array of shape (n_samples, n_bits) with 0/1 + dtype. + :type arr: np.ndarray + :param cutoff: Distance cutoff (1 – similarity) to form + clusters. Defaults to 0.2. + :type cutoff: float + :returns: List of clusters, each a list of sample indices. 
+ :rtype: list of list of int + """ + # Convert rows to RDKit ExplicitBitVect + fps: List[cDataStructs.ExplicitBitVect] = [] + for row in arr: + bitstr = "".join(str(int(b)) for b in row.tolist()) + fps.append(CreateFromBitString(bitstr)) + + n = len(fps) + # Build flattened upper‐triangular distance list + distances: List[float] = [] + for i in range(n): + # fmt: off + sims = BulkTanimotoSimilarity(fps[i], fps[i + 1:]) + # fmt: on + distances.extend((1.0 - np.array(sims, dtype=float)).tolist()) + + # Cluster: ClusterData(distanceList, nPts, cutoff, isDistData) + clusters = Butina.ClusterData(distances, n, cutoff, True) + return clusters + + @staticmethod + def visualize( + arr: np.ndarray, + clusters: List[List[int]], + k: Optional[int] = None, + perplexity: float = 30.0, + random_state: int = 42, + ) -> None: + """Visualize clusters in 2D via t-SNE embedding. + + :param arr: 2D array of shape (n_samples, n_features) with fingerprint data. + :type arr: np.ndarray + :param clusters: Clusters as returned by `cluster()`. + :type clusters: list of list of int + :param k: If provided, highlight only the top‑k largest clusters; others shown as 'Other'. + :type k: int or None + :param perplexity: t-SNE perplexity parameter. Defaults to 30.0. + :type perplexity: float + :param random_state: Random seed for reproducibility. Defaults to 42. 
+ :type random_state: int + :returns: None + :rtype: NoneType + + :example: + >>> clusters = ButinaCluster.cluster(arr, cutoff=0.3) + >>> ButinaCluster.visualize(arr, clusters, k=5) + """ + n = arr.shape[0] + # assign labels: cluster idx or -1 for 'Other' + labels = np.full(n, -1, dtype=int) + # sort clusters by size + sorted_idx = sorted( + range(len(clusters)), key=lambda i: len(clusters[i]), reverse=True + ) + top = set(sorted_idx[:k]) if k is not None else set(sorted_idx) + for idx, cluster in enumerate(clusters): + for i in cluster: + labels[i] = idx if idx in top else -1 + + # compute t-SNE embedding + tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state) + emb = tsne.fit_transform(arr) + + # plot + plt.figure(figsize=(8, 6)) + unique = sorted(set(labels)) + for lab in unique: + mask = labels == lab + if lab == -1: + plt.scatter( + emb[mask, 0], emb[mask, 1], color="gray", alpha=0.3, label="Other" + ) + else: + plt.scatter( + emb[mask, 0], emb[mask, 1], alpha=0.7, label=f"Cluster {lab}" + ) + plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") + plt.title("t-SNE visualization of Butina clusters") + plt.xlabel("t-SNE dim 1") + plt.ylabel("t-SNE dim 2") + plt.tight_layout() + plt.show() + + def __str__(self) -> str: + """Short description of the clusterer. + + :returns: Class name. + :rtype: str + """ + return "" + + def help(self) -> None: + """Print usage summary for clustering and visualization. 
+ + :returns: None + :rtype: NoneType + """ + print("ButinaCluster.cluster(arr, cutoff=0.2)") + print("ButinaCluster.visualize(arr, clusters, k=None, perplexity=30.0)") diff --git a/synkit/Chem/Fingerprint/fp_calculator.py b/synkit/Chem/Fingerprint/fp_calculator.py index d164b7b..dd1d56b 100644 --- a/synkit/Chem/Fingerprint/fp_calculator.py +++ b/synkit/Chem/Fingerprint/fp_calculator.py @@ -1,158 +1,155 @@ +from __future__ import annotations +from typing import Any, Dict, List from joblib import Parallel, delayed -from typing import Dict, List + from synkit.IO.debug import configure_warnings_and_logs from synkit.Chem.Fingerprint.transformation_fp import TransformationFP -# Configure warnings and logging configure_warnings_and_logs(True, True) class FPCalculator: + """Calculate fingerprint vectors for chemical reactions represented by + SMILES strings. + + :cvar fps: Shared fingerprint engine instance. + :vartype fps: TransformationFP + :cvar VALID_FP_TYPES: Supported fingerprint type identifiers. + :vartype VALID_FP_TYPES: List[str] + :param n_jobs: Number of parallel jobs to use for batch processing. + :type n_jobs: int + :param verbose: Verbosity level for parallel execution. + :type verbose: int """ - Class to calculate fingerprint vectors for chemical compounds represented by - SMILES strings. This class provides methods to process SMILES strings - into various types of fingerprint vectors, - either individually or in batches, and supports parallel processing. - - Attributes: - - smiles_key (str): Key in the dictionary corresponding to the SMILES string. - - fp_type (str): Type of fingerprint to calculate; supports various cheminformatics fingerprint types. - - n_jobs (int): Number of parallel jobs to run for performance enhancement. - - verbose (int): Verbosity level of parallel computation. - """ - - # Class-level instance to be used in static methods. 
- fps = TransformationFP() - def __init__( - self, - smiles_key: str, - fp_type: str, - n_jobs: int = 1, - verbose: int = 0, - ): - """ - Initialize the FPCalculator with specific settings for SMILES string processing and fingerprint generation. - - Parameters: - - smiles_key (str): The key in a dictionary corresponding to the SMILES string. - - fp_type (str): The type of fingerprint to generate. - - n_jobs (int): Number of parallel jobs. - Default is 1. - - verbose (int): Verbosity level for parallel processing. - Default is 0. + fps: TransformationFP = TransformationFP() + VALID_FP_TYPES: List[str] = [ + "drfp", + "avalon", + "maccs", + "torsion", + "pharm2D", + "ecfp2", + "ecfp4", + "ecfp6", + "fcfp2", + "fcfp4", + "fcfp6", + "rdk5", + "rdk6", + "rdk7", + "ap", + ] + + def __init__(self, n_jobs: int = 1, verbose: int = 0) -> None: + """Initialize the FPCalculator. + + :param n_jobs: Number of parallel jobs to use for fingerprint + computation. + :type n_jobs: int + :param verbose: Verbosity level for the parallel processing. + :type verbose: int """ - self.smiles_key = smiles_key - self.fp_type = fp_type self.n_jobs = n_jobs self.verbose = verbose - self._validate_fp_type(fp_type) def _validate_fp_type(self, fp_type: str) -> None: - """ - Validate if the provided fingerprint type is supported. + """Ensure the requested fingerprint type is supported. - Parameters: - - fp_type (str): The type of fingerprint to be validated. - - Raises: - ValueError: If the fingerprint type is not supported. + :param fp_type: Fingerprint type identifier to validate. + :type fp_type: str + :raises ValueError: If `fp_type` is not in VALID_FP_TYPES. 
""" - valid_fps = [ - "drfp", - "avalon", - "maccs", - "torsion", - "pharm2D", - "ecfp2", - "ecfp4", - "ecfp6", - "fcfp2", - "fcfp4", - "fcfp6", - "rdk5", - "rdk6", - "rdk7", - ] - if fp_type not in valid_fps: + if fp_type not in self.VALID_FP_TYPES: + valid = ", ".join(self.VALID_FP_TYPES) raise ValueError( - f"Unsupported fingerprint type '{fp_type}'. Currently supported: {', '.join(valid_fps)}." + f"Unsupported fingerprint type '{fp_type}'. Supported types: {valid}." ) @staticmethod def dict_process( - data_dict: Dict, + data_dict: Dict[str, Any], rsmi_key: str, symbol: str = ">>", - fp_type: str = "ap", + fp_type: str = "ecfp4", absolute: bool = True, - ) -> Dict: - """ - Convert a reaction SMILES string to a fingerprint vector based on the - specified fingerprint type. - - Parameters: - - data_dict (Dict): A dictionary containing reaction SMILES. - - rsmi_key (str): The key in the dictionary for the reaction SMILES. - - symbol (str): The symbol used to separate reactants and products. - Default is '>>'. - - fp_type (str): The type of fingerprint to generate. - Default is 'ap'. - - absolute (bool): Whether to use absolute values. - Default is True. - - Returns: - - Dict: The updated dictionary with the fingerprint added - under the key `fp_type`. - - Raises: - - ValueError: If an unsupported fingerprint type is specified or - the reaction SMILES key does not exist. + ) -> Dict[str, Any]: + """Compute a fingerprint for a single reaction SMILES entry and add it + to the dict. + + :param data_dict: Dictionary containing reaction data. + :type data_dict: dict + :param rsmi_key: Key in `data_dict` for the reaction SMILES string. + :type rsmi_key: str + :param symbol: Delimiter between reactant and product in the SMILES. + :type symbol: str + :param fp_type: Fingerprint type to compute. + :type fp_type: str + :param absolute: Whether to take absolute values of the fingerprint difference. 
+ :type absolute: bool + :returns: The input dictionary with a new key equal to `fp_type` (e.g. `ecfp4`) holding the fingerprint vector. + :rtype: dict + :raises ValueError: If `rsmi_key` is missing in `data_dict`. """ if rsmi_key not in data_dict: - raise ValueError(f"Key '{rsmi_key}' does not exist in the dictionary.") - data_dict[fp_type] = FPCalculator.fps.fit( + raise ValueError(f"Key '{rsmi_key}' not found in data dictionary.") + # compute and insert fingerprint + vec = FPCalculator.fps.fit( data_dict[rsmi_key], symbols=symbol, fp_type=fp_type, abs=absolute ) + data_dict[f"{fp_type}"] = vec return data_dict def parallel_process( self, - data_dicts: List[Dict], + data_dicts: List[Dict[str, Any]], rsmi_key: str, symbol: str = ">>", - fp_type: str = "ap", + fp_type: str = "ecfp4", absolute: bool = True, - ) -> List[Dict]: - """ - Convert a list of SMILES strings to fingerprint vectors in parallel - based on the specified fingerprint type. This method processes - multiple dictionaries containing SMILES strings simultaneously - using multiple workers. - - Parameters: - - data_dicts (List[Dict]): A list of dictionaries, each containing reaction data. - - rsmi_key (str): The key to access the reaction SMILES in each dictionary. - - symbol (str): The symbol used to separate reactants and products. - Default is '>>'. - - fp_type (str): The type of fingerprint to generate. - Default is 'ap'. - - absolute (bool): Whether to use absolute values. - Default is True. - - Returns: - - List[Dict]: A list of dictionaries with updated fingerprint data, - where each dictionary includes a fingerprint vector. - - Raises: - - ValueError: If an unsupported fingerprint type is specified or the - reaction SMILES key does not exist in any dictionary. + ) -> List[Dict[str, Any]]: + """Compute fingerprints for a batch of reaction dictionaries in + parallel. + + :param data_dicts: List of dictionaries, each containing a reaction SMILES. 
+ :type data_dicts: list of dict + :param rsmi_key: Key in each dict for the reaction SMILES string. + :type rsmi_key: str + :param symbol: Delimiter between reactant and product in the SMILES. + :type symbol: str + :param fp_type: Fingerprint type to compute. + :type fp_type: str + :param absolute: Whether to take absolute values of the fingerprint difference. + :type absolute: bool + :returns: A list of dictionaries, each augmented with an entry keyed by `fp_type`. + :rtype: list of dict + :raises ValueError: If `fp_type` is unsupported or any dict is missing `rsmi_key`. """ + # Validate fingerprint type once + self._validate_fp_type(fp_type) + # Process in parallel results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)( - delayed(FPCalculator.safe_dict_process)( - data_dict, rsmi_key, symbol, fp_type, absolute - ) - for data_dict in data_dicts + delayed(self.dict_process)(dd, rsmi_key, symbol, fp_type, absolute) + for dd in data_dicts ) return results + + def __str__(self) -> str: + """Short string summarizing the calculator configuration. + + :returns: A summary of n_jobs and verbosity. + :rtype: str + """ + return f"" + + def help(self) -> None: + """Print details about supported fingerprint types and usage. + + :returns: None + :rtype: NoneType + """ + print("FPCalculator supports the following fingerprint types:") + for t in self.VALID_FP_TYPES: + print(" -", t) + print(f"Configured for {self.n_jobs} parallel jobs, verbose={self.verbose}") diff --git a/synkit/Chem/Fingerprint/smiles_featurizer.py b/synkit/Chem/Fingerprint/smiles_featurizer.py index 2c450a5..93fe024 100644 --- a/synkit/Chem/Fingerprint/smiles_featurizer.py +++ b/synkit/Chem/Fingerprint/smiles_featurizer.py @@ -1,3 +1,24 @@ +"""smiles_featurizer.py +======================= +Utility for converting SMILES strings into various cheminformatics fingerprints, +with optional NumPy‐array conversion. 
+ +Key features +------------ +* **Multi‐fingerprint support** – MACCS, Avalon, ECFP/FCFP, RDKit, AtomPair, Torsion, Pharm2D +* **SMILES validation** – raises on invalid input +* **Array conversion** – output as NumPy arrays for ML pipelines +* **Extensible** – add new methods or override via subclassing + +Quick start +----------- +>>> from synkit.Chem.Fingerprint.smiles_featurizer import SmilesFeaturizer +>>> arr = SmilesFeaturizer.featurize_smiles("CCO", "ecfp4", convert_to_array=True) +""" + +from __future__ import annotations +from typing import Any + import numpy as np from rdkit import Chem, DataStructs from rdkit.Chem import AllChem, MACCSkeys @@ -7,71 +28,86 @@ class SmilesFeaturizer: - def __init__(self): - """ - Initializes the SmilesFeaturizer class without any specific parameters for fingerprint generation. + """Convert SMILES strings into chemical fingerprint vectors. + + :cvar None: This class only provides static/​class methods and holds no state. + + Supported fingerprint methods: + - MACCS keys + - Avalon + - ECFP/FCFP (Morgan) + - RDKit topological + - AtomPair + - Torsion + - 2D Pharmacophore + + Use `featurize_smiles` for one‑line access. + """ + + def __init__(self) -> None: + """Initialize SmilesFeaturizer. + + This class has no instance state; all methods are static or + class‑level. """ pass @staticmethod def smiles_to_mol(smiles: str) -> Chem.Mol: - """ - Converts a SMILES string to an RDKit Mol object. - - Parameters: - - smiles (str): The SMILES string to be converted. + """Convert a SMILES string to an RDKit Mol object. - Returns: - - Chem.Mol: The corresponding RDKit Mol object. + :param smiles: The SMILES string to convert. + :type smiles: str + :returns: RDKit Mol object corresponding to the SMILES. + :rtype: Chem.Mol + :raises ValueError: If the SMILES string is invalid. 
""" mol = Chem.MolFromSmiles(smiles) if mol is None: - raise ValueError("Invalid SMILES string provided.") + raise ValueError(f"Invalid SMILES string: {smiles!r}") return mol @staticmethod - def get_maccs_keys(mol: Chem.Mol): - """ - Generates MACCS keys fingerprint from an RDKit Mol object. + def get_maccs_keys(mol: Chem.Mol) -> Any: + """Generate the MACCS keys fingerprint for a molecule. - Parameters: - - mol (Chem.Mol): The Mol object to be featurized. - - Returns: - - RDKit ExplicitBitVect: The MACCS keys fingerprint of the Mol object. + :param mol: RDKit Mol object. + :type mol: Chem.Mol + :returns: MACCS keys fingerprint bit vector. + :rtype: ExplicitBitVect """ return MACCSkeys.GenMACCSKeys(mol) @staticmethod - def get_avalon_fp(mol: Chem.Mol, nBits: int = 1024): - """ - Generates Avalon fingerprint from an RDKit Mol object. + def get_avalon_fp(mol: Chem.Mol, nBits: int = 1024) -> Any: + """Generate the Avalon fingerprint for a molecule. - Parameters: - - mol (Chem.Mol): The Mol object to be featurized. - - nBits (int): The number of bits in the generated fingerprint. - - Returns: - - RDKit ExplicitBitVect: The Avalon fingerprint of the Mol object. + :param mol: RDKit Mol object. + :type mol: Chem.Mol + :param nBits: Length of the fingerprint vector. + :type nBits: int + :returns: Avalon fingerprint bit vector. + :rtype: ExplicitBitVect """ return fpAvalon.GetAvalonFP(mol, nBits) @staticmethod def get_ecfp( mol: Chem.Mol, radius: int, nBits: int = 2048, useFeatures: bool = False - ): - """ - Generates Extended-Connectivity Fingerprints (ECFP) or - Feature-Class Fingerprints (FCFP) from an RDKit Mol object. - - Parameters: - - mol (Chem.Mol): The Mol object to be featurized. - - radius (int): The radius of the fingerprint. - - nBits (int): The number of bits in the generated fingerprint. - - useFeatures (bool): Whether to use atom features instead of atom identities. + ) -> Any: + """Generate a Morgan fingerprint (ECFP or FCFP) for a molecule. 
- Returns: - - RDKit ExplicitBitVect: The ECFP or FCFP fingerprint of the Mol object. + :param mol: RDKit Mol object. + :type mol: Chem.Mol + :param radius: Radius for the Morgan algorithm. + :type radius: int + :param nBits: Length of the fingerprint vector. + :type nBits: int + :param useFeatures: If True, generate a Feature‑Class + fingerprint (FCFP). + :type useFeatures: bool + :returns: Morgan fingerprint bit vector. + :rtype: ExplicitBitVect """ return AllChem.GetMorganFingerprintAsBitVect( mol, radius, nBits=nBits, useFeatures=useFeatures @@ -80,106 +116,143 @@ def get_ecfp( @staticmethod def get_rdk_fp( mol: Chem.Mol, maxPath: int, fpSize: int = 2048, nBitsPerHash: int = 2 - ): - """ - Generates RDKit fingerprint from an RDKit Mol object. - - Parameters: - - mol (Chem.Mol): The Mol object to be featurized. - - maxPath (int): The maximum path length (in bonds) to be included. - - fpSize (int): The size of the fingerprint. - - nBitsPerHash (int): The number of bits per hash. + ) -> Any: + """Generate an RDKit topological fingerprint for a molecule. - Returns: - - RDKit ExplicitBitVect: The RDKit fingerprint of the Mol object. + :param mol: RDKit Mol object. + :type mol: Chem.Mol + :param maxPath: Maximum path length (bonds) to include. + :type maxPath: int + :param fpSize: Length of the fingerprint vector. + :type fpSize: int + :param nBitsPerHash: Bits per hash for path hashing. + :type nBitsPerHash: int + :returns: RDKit topological fingerprint bit vector. + :rtype: ExplicitBitVect """ return Chem.RDKFingerprint( mol, maxPath=maxPath, fpSize=fpSize, nBitsPerHash=nBitsPerHash ) @staticmethod - def mol_to_ap(mol: Chem.Mol) -> np.ndarray: - """ - Generates an Atom Pair fingerprint as a NumPy array from an RDKit Mol object. - - Parameters: - - mol (Chem.Mol): The Mol object to be featurized. + def mol_to_ap(mol: Chem.Mol) -> Any: + """Generate an Atom Pair fingerprint for a molecule. - Returns: - - RDKit ExplicitBitVect: The RDKit fingerprint of the Mol object. 
+ :param mol: RDKit Mol object. + :type mol: Chem.Mol + :returns: Atom Pair fingerprint as an integer vector. + :rtype: ExplicitBitVect """ return Pairs.GetAtomPairFingerprint(mol) @staticmethod - def mol_to_torsion(mol: Chem.Mol) -> np.ndarray: - """ - Generates a Topological Torsion fingerprint as a NumPy array from an RDKit Mol object. + def mol_to_torsion(mol: Chem.Mol) -> Any: + """Generate a Topological Torsion fingerprint for a molecule. - Parameters: - - mol (Chem.Mol): The Mol object to be featurized. - - Returns: - - RDKit ExplicitBitVect: The RDKit fingerprint of the Mol object. + :param mol: RDKit Mol object. + :type mol: Chem.Mol + :returns: Torsion fingerprint as an integer vector. + :rtype: ExplicitBitVect """ return Torsions.GetTopologicalTorsionFingerprintAsIntVect(mol) @staticmethod - def mol_to_pharm2d(mol: Chem.Mol) -> np.ndarray: - """ - Generates a 2D Pharmacophore fingerprint as a NumPy array from an RDKit Mol object. + def mol_to_pharm2d(mol: Chem.Mol) -> Any: + """Generate a 2D Pharmacophore fingerprint for a molecule. - Parameters: - - mol (Chem.Mol): The Mol object to be featurized. - - Returns: - - RDKit ExplicitBitVect: The RDKit fingerprint of the Mol object. + :param mol: RDKit Mol object. + :type mol: Chem.Mol + :returns: 2D pharmacophore fingerprint bit vector. + :rtype: ExplicitBitVect """ return Generate.Gen2DFingerprint(mol, Gobbi_Pharm2D.factory) @classmethod def featurize_smiles( - cls, smiles: str, fingerprint_type: str, convert_to_array: bool = True, **kwargs - ) -> np.ndarray: - """ - Featurizes a SMILES string into the specified type of fingerprint, optionally converting it to a NumPy array. - - Parameters: - - smiles (str): The SMILES string to be featurized. - - fingerprint_type (str): The type of fingerprint to generate. - - convert_to_array (bool): Whether to convert the fingerprint to a NumPy array. Defaults to True. - - **kwargs: Additional keyword arguments for the fingerprint function. 
+ cls, + smiles: str, + fingerprint_type: str, + convert_to_array: bool = True, + **kwargs: Any, + ) -> Any: + """Featurize a SMILES string into a chosen fingerprint, optionally + converting to a NumPy array. - Returns: - - np.ndarray or RDKit ExplicitBitVect: The requested type of fingerprint for the SMILES string, - either as a NumPy array or as an RDKit bit vector, depending on `convert_to_array`. + :param smiles: The SMILES string to featurize. + :type smiles: str + :param fingerprint_type: One of 'maccs', 'avalon', 'ecfp#', 'fcfp#', + 'rdk#', 'ap', 'torsion', 'pharm2d'. + :type fingerprint_type: str + :param convert_to_array: If True, convert the result to a NumPy array. + :type convert_to_array: bool + :param kwargs: Additional parameters passed to the chosen method: + - `nBits` for Avalon/ECFP/FCFP + - `radius` for ECFP/FCFP + - `maxPath`, `fpSize`, `nBitsPerHash` for RDKit FP + :type kwargs: dict + :returns: Fingerprint as a NumPy array (if `convert_to_array`) or RDKit bit vector. + :rtype: np.ndarray or ExplicitBitVect + :raises ValueError: If `fingerprint_type` is unsupported. 
""" mol = cls.smiles_to_mol(smiles) - if fingerprint_type == "maccs": + + ft = fingerprint_type.lower() + if ft == "maccs": fp = cls.get_maccs_keys(mol) - elif fingerprint_type == "avalon": - fp = cls.get_avalon_fp(mol, **kwargs) - elif fingerprint_type.startswith("ecfp") or fingerprint_type.startswith("fcfp"): - radius = int(fingerprint_type[4]) - useFeatures = fingerprint_type.startswith("fcfp") - nBits = kwargs.get("nBits", 2048) - fp = cls.get_ecfp(mol, radius, nBits=nBits, useFeatures=useFeatures) - elif fingerprint_type.startswith("rdk"): - maxPath = int(fingerprint_type[3]) - fp = cls.get_rdk_fp(mol, maxPath, **kwargs) - elif fingerprint_type == "ap": + elif ft == "avalon": + fp = cls.get_avalon_fp(mol, nBits=kwargs.get("nBits", 1024)) + elif ft.startswith("ecfp") or ft.startswith("fcfp"): + radius = int(ft[4]) + use_features = ft.startswith("fcfp") + fp = cls.get_ecfp( + mol, + radius, + nBits=kwargs.get("nBits", 2048), + useFeatures=use_features, + ) + elif ft.startswith("rdk"): + max_path = int(ft[3]) + fp = cls.get_rdk_fp( + mol, + maxPath=max_path, + fpSize=kwargs.get("fpSize", 2048), + nBitsPerHash=kwargs.get("nBitsPerHash", 2), + ) + elif ft == "ap": fp = cls.mol_to_ap(mol) - elif fingerprint_type == "torsion": + elif ft == "torsion": fp = cls.mol_to_torsion(mol) - elif fingerprint_type == "pharm2d": + elif ft == "pharm2d": fp = cls.mol_to_pharm2d(mol) else: - raise ValueError(f"Unsupported fingerprint type: {fingerprint_type}") + raise ValueError(f"Unsupported fingerprint type: {fingerprint_type!r}") + if convert_to_array: - if fingerprint_type == "pharm2d": - return np.frombuffer(fp.ToBitString().encode(), "u1") - ord("0") - else: - ar = np.zeros((1,), dtype=np.int8) - DataStructs.ConvertToNumpyArray(fp, ar) - return ar - else: - return fp + if ft == "pharm2d": + bitstr = fp.ToBitString() + return np.array([int(b) for b in bitstr], dtype=np.int8) + arr = np.zeros((fp.GetNumBits(),), dtype=np.int8) + DataStructs.ConvertToNumpyArray(fp, arr) + return 
arr + + return fp + + def __str__(self) -> str: + """Short description of the featurizer. + + :returns: Class name. + :rtype: str + """ + return "" + + def help(self) -> None: + """Print supported fingerprint types and usage summary. + + :returns: None + :rtype: NoneType + """ + print("SmilesFeaturizer supports the following fingerprint types:") + print(" - maccs, avalon, ecfp#, fcfp#, rdk#, ap, torsion, pharm2d") + print( + "Usage: SmilesFeaturizer.featurize_smiles(smiles, fingerprint_type, **kwargs)" + ) diff --git a/synkit/Chem/Fingerprint/transformation_fp.py b/synkit/Chem/Fingerprint/transformation_fp.py index b05eded..991b0f7 100644 --- a/synkit/Chem/Fingerprint/transformation_fp.py +++ b/synkit/Chem/Fingerprint/transformation_fp.py @@ -1,39 +1,53 @@ +"""transformation_fp.py +======================= +Compute reaction‐level fingerprints by combining molecular fingerprints +of reactants and products, with optional absolute mode and bit‐vector conversion. + +Quick start +----------- +>>> from synkit.Chem.Fingerprint.transformation_fp import TransformationFP +>>> arr = TransformationFP().fit('CCO>>CC=O', symbols='>>', fp_type='ecfp4', abs=True) +>>> bv = TransformationFP().fit('CCO>>CC=O', symbols='>>', fp_type='ecfp4', abs=True, return_array=False) +""" + +from __future__ import annotations +from typing import Any, Union + import numpy as np -from typing import Union, Any from rdkit.DataStructs import cDataStructs + from synkit.Chem.Fingerprint.smiles_featurizer import SmilesFeaturizer class TransformationFP: - """ - A class for handling the transformation of chemical reactions into reaction fingerprints - based on SMILES strings. + """Calculate reaction fingerprints by featurizing individual molecules and + combining them via vector subtraction. + + :cvar None: Stateless utility class. """ def __init__(self) -> None: - """ - Initializes the TransformationFP object. Currently, this constructor does not - perform any operations. + """Initialize TransformationFP. 
+ + This class has no instance state; all methods are static or + class‐level. """ pass @staticmethod def convert_arr2vec(arr: np.ndarray) -> cDataStructs.ExplicitBitVect: - """ - Converts a numpy array to a RDKit ExplicitBitVect. + """Convert a NumPy array of bits into an RDKit ExplicitBitVect. - Parameters: - - arr (np.ndarray): The input array. - - Returns: - - cDataStructs.ExplicitBitVect: The converted bit vector. + :param arr: Array of 0/1 values representing a fingerprint. + :type arr: np.ndarray + :returns: RDKit bit vector constructed from the bit string. + :rtype: cDataStructs.ExplicitBitVect """ - arr_tostring = "".join(arr.astype(str)) - EBitVect = cDataStructs.CreateFromBitString(arr_tostring) - return EBitVect + bitstr = "".join(str(int(x)) for x in arr.flatten()) + return cDataStructs.CreateFromBitString(bitstr) - @staticmethod def fit( + self, reaction_smiles: str, symbols: str, fp_type: str, @@ -41,39 +55,81 @@ def fit( return_array: bool = True, **kwargs: Any, ) -> Union[np.ndarray, cDataStructs.ExplicitBitVect]: + """Generate a reaction fingerprint by subtracting reactant from product + fingerprints. + + :param reaction_smiles: Reaction SMILES, reactant and product separated by `symbols`. + :type reaction_smiles: str + :param symbols: Delimiter between reactants and products in the SMILES string. + :type symbols: str + :param fp_type: Fingerprint type to use for individual molecules (e.g., 'ecfp4'). + :type fp_type: str + :param abs: If True, take absolute value of the difference vector. + :type abs: bool + :param return_array: If True, return a NumPy array; otherwise convert to an RDKit bit vector. + :type return_array: bool + :param kwargs: Additional keyword arguments passed to `SmilesFeaturizer.featurize_smiles`. + :type kwargs: Any + :returns: Reaction fingerprint as a NumPy array or RDKit bit vector. + :rtype: Union[np.ndarray, cDataStructs.ExplicitBitVect] + :raises ValueError: If `reaction_smiles` is not correctly formatted. 
""" - Generates a reaction fingerprint for a given reaction represented by a SMILES string. - - Parameters: - - reaction_smiles (str): The SMILES string of the reaction, separated by `symbols`. - - symbols (str): The symbol used to separate reactants and products in the SMILES string. - - fp_type (str): The type of fingerprint to generate (e.g., 'maccs', 'ecfp'). - - abs (bool): Whether to take the absolute value of the reaction fingerprint difference. - - return_array (bool): Whether to return the reaction fingerprint as a numpy array or as a bit vector. - - Returns: - - Union[np.ndarray, cDataStructs.ExplicitBitVect]: The reaction fingerprint either as an array - or a bit vector, depending on the value of `return_array`. - """ - react, prod = reaction_smiles.split(symbols) - react_fps = None - for s in react.split("."): - if react_fps is None: - react_fps = SmilesFeaturizer.featurize_smiles(s, fp_type, **kwargs) - else: - react_fps += SmilesFeaturizer.featurize_smiles(s, fp_type, **kwargs) - - prod_fps = None - for s in prod.split("."): - if prod_fps is None: - prod_fps = SmilesFeaturizer.featurize_smiles(s, fp_type, **kwargs) - else: - prod_fps += SmilesFeaturizer.featurize_smiles(s, fp_type, **kwargs) - - reaction_fp = np.subtract(prod_fps, react_fps) + if symbols not in reaction_smiles: + raise ValueError(f"Reaction SMILES must contain separator '{symbols}'") + react_part, prod_part = reaction_smiles.split(symbols) + + def sum_fps(parts: list[str]) -> np.ndarray: + total = None + for smi in parts: + vec = SmilesFeaturizer.featurize_smiles(smi, fp_type, **kwargs) + if total is None: + total = vec.copy() if isinstance(vec, np.ndarray) else vec + else: + total = total + vec # type: ignore + return total # type: ignore + + react_vec = sum_fps(react_part.split(".")) + prod_vec = sum_fps(prod_part.split(".")) + + diff = prod_vec - react_vec # type: ignore if abs: - reaction_fp = np.abs(reaction_fp) + diff = np.abs(diff) + if return_array: - return reaction_fp - 
else: - return TransformationFP.convert_arr2vec(reaction_fp) + return diff # type: ignore + return TransformationFP.convert_arr2vec(diff) # type: ignore + + def help(self) -> None: + """Print usage summary for the TransformationFP class. + + :returns: None + :rtype: NoneType + """ + print("TransformationFP: compute reaction fingerprints via vector subtraction.") + print( + " fit(reaction_smiles, symbols, fp_type, abs, return_array=True, **kwargs)" + ) + print(" reaction_smiles: 'R1.R2>>P1.P2' SMILES string") + print(" symbols: separator between reactants and products (e.g. '>>')") + print( + " fp_type: one of 'maccs', 'avalon', 'ecfp#', 'fcfp#', 'rdk#', 'ap', 'torsion', 'pharm2d'" + ) + print(" abs: take absolute difference (True/False)") + print(" return_array: return NumPy array (True) or RDKit bit vector (False)") + print(" convert_arr2vec(arr: np.ndarray) -> ExplicitBitVect") + print("Example:") + print(" tfp = TransformationFP()") + print(" arr = tfp.fit('CCO>>CC=O', '>>', 'ecfp4', abs=True)") + print( + " bv = tfp.fit('CCO>>CC=O', '>>', 'ecfp4', abs=True, return_array=False)" + ) + + def __str__(self) -> str: + """Short description of the transformer. + + :returns: Class name. + :rtype: str + """ + return "" + + __repr__ = __str__ diff --git a/synkit/Chem/Molecule/standardize.py b/synkit/Chem/Molecule/standardize.py index b8d9f88..7a8b6ae 100644 --- a/synkit/Chem/Molecule/standardize.py +++ b/synkit/Chem/Molecule/standardize.py @@ -5,148 +5,125 @@ from typing import Optional -def sanitize_and_canonicalize_smiles(smiles: str) -> str | None: - """ - Sanitize and canonicalize a SMILES string using RDKit. - - Parameters - ---------- - smiles : str - Input SMILES string. +def sanitize_and_canonicalize_smiles(smiles: str) -> Optional[str]: + """Sanitize and canonicalize a SMILES string. - Returns - ------- - str or None - Canonical SMILES if valid and sanitizable, else None. + :param smiles: Input SMILES string. 
+ :type smiles: str + :returns: Canonical SMILES if valid, otherwise None. + :rtype: Optional[str] """ try: mol = Chem.MolFromSmiles(smiles, sanitize=True) if mol is None: return None - Chem.SanitizeMol(mol) # additional safety + Chem.SanitizeMol(mol) return Chem.MolToSmiles(mol, canonical=True) except Exception: return None def normalize_molecule(mol: Chem.Mol) -> Chem.Mol: - """ - Normalize a molecule using RDKit's Normalizer. + """Normalize a molecule using RDKit's Normalizer. - Parameters: - - mol (Chem.Mol): RDKit Mol object to be normalized. - - Returns: - - Chem.Mol: Normalized RDKit Mol object. + :param mol: RDKit Mol object to normalize. + :type mol: Chem.Mol + :returns: Normalized RDKit Mol object. + :rtype: Chem.Mol """ normalizer = rdMolStandardize.Normalizer() return normalizer.normalize(mol) def canonicalize_tautomer(mol: Chem.Mol) -> Chem.Mol: - """ - Canonicalize the tautomer of a molecule using RDKit's TautomerCanonicalizer. - - Parameters: - - mol (Chem.Mol): RDKit Mol object. + """Canonicalize the tautomeric form of a molecule. - Returns: - - Chem.Mol: Mol object with canonicalized tautomer. + :param mol: RDKit Mol object to canonicalize. + :type mol: Chem.Mol + :returns: Mol object with a canonical tautomer. + :rtype: Chem.Mol """ - tautomer_canonicalizer = rdMolStandardize.TautomerEnumerator() - return tautomer_canonicalizer.Canonicalize(mol) + tautomer_enumerator = rdMolStandardize.TautomerEnumerator() + return tautomer_enumerator.Canonicalize(mol) def salts_remover(mol: Chem.Mol) -> Chem.Mol: - """ - Remove salt fragments from a molecule using RDKit's SaltRemover. - - Parameters: - - mol (Chem.Mol): RDKit Mol object. + """Remove salt fragments from a molecule. - Returns: - - Chem.Mol: Mol object with salts removed. + :param mol: RDKit Mol object to process. + :type mol: Chem.Mol + :returns: Mol object with salts removed. 
+ :rtype: Chem.Mol """ remover = SaltRemover() return remover.StripMol(mol) def uncharge_molecule(mol: Chem.Mol) -> Chem.Mol: - """ - Neutralize a molecule by removing counter-ions using RDKit's Uncharger. + """Neutralize a molecule by removing charges. - Parameters: - - mol (Chem.Mol): RDKit Mol object. - - Returns: - - Chem.Mol: Neutralized Mol object. + :param mol: RDKit Mol object to neutralize. + :type mol: Chem.Mol + :returns: Neutralized Mol object. + :rtype: Chem.Mol """ uncharger = rdMolStandardize.Uncharger() return uncharger.uncharge(mol) -def fragments_remover(mol: Chem.Mol) -> Chem.Mol: - """ - Remove small fragments from a molecule, keeping only the largest one. - - Parameters: - - mol (Chem.Mol): RDKit Mol object. +def fragments_remover(mol: Chem.Mol) -> Optional[Chem.Mol]: + """Keep only the largest fragment of a molecule. - Returns: - - Chem.Mol: Mol object with small fragments removed. + :param mol: RDKit Mol object to fragment. + :type mol: Chem.Mol + :returns: Mol object of the largest fragment, or None if input is + empty. + :rtype: Optional[Chem.Mol] """ frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=True) return max(frags, default=None, key=lambda m: m.GetNumAtoms()) def remove_explicit_hydrogens(mol: Chem.Mol) -> Chem.Mol: - """ - Remove explicit hydrogens from a molecule to leave only the heavy atoms. - - Parameters: - - mol (Chem.Mol): RDKit Mol object. + """Remove all explicit hydrogens from a molecule. - Returns: - - Chem.Mol: Mol object with explicit hydrogens removed. + :param mol: RDKit Mol object to process. + :type mol: Chem.Mol + :returns: Mol object without explicit hydrogens. + :rtype: Chem.Mol """ return Chem.RemoveHs(mol) def remove_radicals_and_add_hydrogens( - mol: Chem.Mol, removeH=True + mol: Chem.Mol, removeH: bool = True ) -> Optional[Chem.Mol]: - """ - Remove radicals from a molecule by setting radical electrons to zero and adding hydrogens where needed. 
+ """Replace radical electrons by hydrogens and optionally remove explicit H. - Parameters: - - mol (Chem.Mol): RDKit Mol object. - - Returns: - - Chem.Mol: Mol object with radicals removed and necessary hydrogens added. + :param mol: RDKit Mol object with possible radicals. + :type mol: Chem.Mol + :param removeH: If True, remove explicit hydrogens after addition. + :type removeH: bool + :returns: Mol object with radicals neutralized and hydrogens + adjusted. + :rtype: Optional[Chem.Mol] """ - # mol = Chem.RemoveHs(mol) # Remove explicit hydrogens first for atom in mol.GetAtoms(): - if atom.GetNumRadicalElectrons() > 0: - atom.SetNumExplicitHs( - atom.GetNumExplicitHs() + atom.GetNumRadicalElectrons() - ) - atom.SetNumRadicalElectrons(0) - mol = rdmolops.AddHs(mol) # Add hydrogens back - if removeH: - return remove_explicit_hydrogens(mol) - else: - return mol + rad = atom.GetNumRadicalElectrons() + if rad > 0: + atom.SetNumExplicitHs(atom.GetNumExplicitHs() + rad) + atom.SetNumRadicalElectrons(0) + mol = rdmolops.AddHs(mol) + return remove_explicit_hydrogens(mol) if removeH else mol def remove_isotopes(mol: Chem.Mol) -> Chem.Mol: - """ - Remove isotopic information from a molecule. - - Parameters: - - mol (Chem.Mol): RDKit Mol object. + """Remove all isotope labels from a molecule. - Returns: - - Chem.Mol: Mol object with isotopes removed. + :param mol: RDKit Mol object to process. + :type mol: Chem.Mol + :returns: Mol object with isotopes cleared. + :rtype: Chem.Mol """ for atom in mol.GetAtoms(): atom.SetIsotope(0) @@ -154,41 +131,37 @@ def remove_isotopes(mol: Chem.Mol) -> Chem.Mol: def clear_stereochemistry(mol: Chem.Mol) -> Chem.Mol: - """ - Clear all stereochemical information from a molecule. - - Parameters: - - mol (Chem.Mol): RDKit Mol object. + """Remove stereochemical annotations from a molecule. - Returns: - - Chem.Mol: Mol object with stereochemistry cleared. + :param mol: RDKit Mol object to process. 
+ :type mol: Chem.Mol + :returns: Mol object with stereochemistry removed. + :rtype: Chem.Mol """ Chem.RemoveStereochemistry(mol) return mol -def fix_radical_rsmi(rsmi: str, removeH=True) -> str: - """ - Takes a reaction SMILES string with potential radicals and returns a new reaction SMILES string - where all radicals have been replaced by adding hydrogen atoms. +def fix_radical_rsmi(rsmi: str, removeH: bool = True) -> str: + """Fix radicals in a reaction SMILES by converting them to hydrogens. - Parameters: - - rsmi (str): A reaction SMILES string containing reactants and products. - - Returns: - - str: A reaction SMILES string with radicals replaced by hydrogen atoms. + :param rsmi: Reaction SMILES string, format 'reactant>>product'. + :type rsmi: str + :param removeH: If True, remove explicit hydrogens after addition. + :type removeH: bool + :returns: Corrected reaction SMILES with radicals replaced. + :rtype: str """ - r, p = rsmi.split(">>") - r_mol = Chem.MolFromSmiles(r, sanitize=False) - p_mol = Chem.MolFromSmiles(p, sanitize=False) + react_smiles, prod_smiles = rsmi.split(">>") + r_mol = Chem.MolFromSmiles(react_smiles, sanitize=False) + p_mol = Chem.MolFromSmiles(prod_smiles, sanitize=False) Chem.SanitizeMol(r_mol) Chem.SanitizeMol(p_mol) - if r_mol is not None and p_mol is not None: - r_mol = remove_radicals_and_add_hydrogens(r_mol, removeH) - p_mol = remove_radicals_and_add_hydrogens(p_mol, removeH) - - r_smiles = Chem.MolToSmiles(r_mol) if r_mol else r - p_smiles = Chem.MolToSmiles(p_mol) if p_mol else p - return f"{r_smiles}>>{p_smiles}" - else: - return f"{r}>>{p}" # + + if r_mol and p_mol: + r_fixed = remove_radicals_and_add_hydrogens(r_mol, removeH) + p_fixed = remove_radicals_and_add_hydrogens(p_mol, removeH) + r_out = Chem.MolToSmiles(r_fixed) if r_fixed else react_smiles + p_out = Chem.MolToSmiles(p_fixed) if p_fixed else prod_smiles + return f"{r_out}>>{p_out}" + return rsmi diff --git a/synkit/Chem/Reaction/__init__.py 
b/synkit/Chem/Reaction/__init__.py index e168770..db24749 100644 --- a/synkit/Chem/Reaction/__init__.py +++ b/synkit/Chem/Reaction/__init__.py @@ -1,4 +1,9 @@ from .aam_validator import AAMValidator from .standardize import Standardize from .canon_rsmi import CanonRSMI -from .rsmi_utils import * + +__all__ = [ + "AAMValidator", + "Standardize", + "CanonRSMI", +] diff --git a/synkit/Chem/Reaction/aam_utils.py b/synkit/Chem/Reaction/aam_utils.py deleted file mode 100644 index 2866f1c..0000000 --- a/synkit/Chem/Reaction/aam_utils.py +++ /dev/null @@ -1,97 +0,0 @@ -import re -from rdkit import Chem -from rdkit.Chem.MolStandardize import rdMolStandardize - -from typing import Optional, List - - -def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: - """ - Enumerates possible tautomers for reactants while canonicalizing the products in a - reaction SMILES string. This function first splits the reaction SMILES string into - reactants and products. It then generates all possible tautomers for the reactants and - canonicalizes the product molecule. The function returns a list of reaction SMILES - strings for each tautomer of the reactants combined with the canonical product. - - Parameters: - - reaction_smiles (str): A SMILES string of the reaction formatted as - 'reactants>>products'. - - Returns: - - List[str] | None: A list of SMILES strings for the reaction, with each string - representing a different - - tautomer of the reactants combined with the canonicalized products. Returns None if - an error occurs or if invalid SMILES strings are provided. - - Raises: - - ValueError: If the provided SMILES strings cannot be converted to molecule objects, - indicating invalid input. 
- """ - try: - # Split the input reaction SMILES string into reactants and products - reactants_smiles, products_smiles = reaction_smiles.split(">>") - - # Convert SMILES strings to molecule objects - reactants_mol = Chem.MolFromSmiles(reactants_smiles) - products_mol = Chem.MolFromSmiles(products_smiles) - - if reactants_mol is None or products_mol is None: - raise ValueError( - "Invalid SMILES string provided for reactants or products." - ) - - # Initialize tautomer enumerator - - enumerator = rdMolStandardize.TautomerEnumerator() - - # Enumerate tautomers for the reactants and canonicalize the products - try: - reactants_can = enumerator.Enumerate(reactants_mol) - except Exception as e: - print(f"An error occurred: {e}") - reactants_can = [reactants_mol] - products_can = products_mol - - # Convert molecule objects back to SMILES strings - reactants_can_smiles = [Chem.MolToSmiles(i) for i in reactants_can] - products_can_smiles = Chem.MolToSmiles(products_can) - - # Combine each reactant tautomer with the canonical product in SMILES format - rsmi_list = [i + ">>" + products_can_smiles for i in reactants_can_smiles] - if len(rsmi_list) == 0: - return [reaction_smiles] - else: - # rsmi_list.remove(reaction_smiles) - rsmi_list.insert(0, reaction_smiles) - return rsmi_list - - except Exception as e: - print(f"An error occurred: {e}") - return [reaction_smiles] - - -def mapping_success_rate(list_mapping_data): - """ - Calculate the success rate of entries containing atom mappings in a list of data - strings. - - Parameters: - - list_mapping_in_data (list of str): List containing strings to be searched for atom - mappings. - - Returns: - - float: The success rate of finding atom mappings in the list as a percentage. - - Raises: - - ValueError: If the input list is empty. 
- """ - atom_map_pattern = re.compile(r":\d+") - if not list_mapping_data: - raise ValueError("The input list is empty, cannot calculate success rate.") - - success = sum( - 1 for entry in list_mapping_data if re.search(atom_map_pattern, entry) - ) - rate = 100 * (success / len(list_mapping_data)) - - return round(rate, 2) diff --git a/synkit/Chem/Reaction/aam_validator.py b/synkit/Chem/Reaction/aam_validator.py index f48cd53..6662cee 100644 --- a/synkit/Chem/Reaction/aam_validator.py +++ b/synkit/Chem/Reaction/aam_validator.py @@ -9,30 +9,43 @@ from synkit.IO.chem_converter import rsmi_to_graph from synkit.Graph.ITS.its_decompose import get_rc from synkit.Graph.ITS.its_construction import ITSConstruction -from .aam_utils import enumerate_tautomers, mapping_success_rate +from synkit.Chem.utils import enumerate_tautomers, mapping_success_rate class AAMValidator: - """ - A utility class for validating atom‐atom mappings (AAM) in reaction SMILES. + """A utility class for validating atom‐atom mappings (AAM) in reaction + SMILES. Provides methods to compare mapped SMILES against ground truth by using reaction‐center (RC) or ITS‐graph isomorphism checks, including tautomer enumeration support and batch validation over tabular data. + + Quick start + ----------- + >>> from synkit.Chem.Reaction import AAMValidator + >>> validator = AAMValidator() + >>> rsmi_1 = ( + '[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][OH:6]' + '>>' + '[CH3:1][C:2](=[O:3])[O:6][CH3:5].[OH2:4]') + >>> rsmi_2 = ( + '[CH3:5][C:1](=[O:2])[OH:3].[CH3:6][OH:4]' + '>>' + '[CH3:5][C:1](=[O:2])[O:4][CH3:6].[OH2:3]') + >>> is_eq = validator.smiles_check(rsmi_1, rsmi_2, check_method='ITS') + >>> print(is_eq) + >>> True """ def __init__(self) -> None: - """ - Initialize the AAMValidator. - """ + """Initialize the AAMValidator.""" pass @staticmethod def check_equivariant_graph( its_graphs: List[nx.Graph], ) -> Tuple[List[Tuple[int, int]], int]: - """ - Identify all pairs of isomorphic ITS graphs. 
+ """Identify all pairs of isomorphic ITS graphs. :param its_graphs: A list of ITS graphs to compare. :type its_graphs: list of networkx.Graph @@ -66,20 +79,21 @@ def smiles_check( check_method: str = "RC", ignore_aromaticity: bool = False, ) -> bool: - """ - Validate a single mapped SMILES string against ground truth. + """Validate a single mapped SMILES string against ground truth. :param mapped_smile: The mapped SMILES to validate. :type mapped_smile: str :param ground_truth: The reference SMILES string. :type ground_truth: str - :param check_method: Which method to use: - `"RC"` for reaction‐center graph or - `"ITS"` for full ITS‐graph isomorphism. + :param check_method: Which method to use: `"RC"` for + reaction‐center graph or `"ITS"` for full ITS‐graph + isomorphism. :type check_method: str - :param ignore_aromaticity: If True, ignore aromaticity differences in ITS construction. + :param ignore_aromaticity: If True, ignore aromaticity + differences in ITS construction. :type ignore_aromaticity: bool - :returns: True if exactly one isomorphic match is found; False otherwise. + :returns: True if exactly one isomorphic match is found; False + otherwise. :rtype: bool """ its_graphs, rc_graphs = [], [] @@ -105,8 +119,7 @@ def smiles_check_tautomer( check_method: str = "RC", ignore_aromaticity: bool = False, ) -> Optional[bool]: - """ - Validate against all tautomers of a ground truth SMILES. + """Validate against all tautomers of a ground truth SMILES. :param mapped_smile: The mapped SMILES to test. :type mapped_smile: str @@ -142,8 +155,7 @@ def check_pair( ignore_aromaticity: bool = False, ignore_tautomers: bool = True, ) -> bool: - """ - Validate a single record (dict) entry for equivalence. + """Validate a single record (dict) entry for equivalence. :param mapping: A record containing both mapped and ground‐truth SMILES. 
:type mapping: dict of str→str @@ -186,8 +198,7 @@ def validate_smiles( verbose: int = 0, ignore_tautomers: bool = True, ) -> List[Dict[str, Union[str, float, List[bool]]]]: - """ - Batch-validate mapped SMILES in tabular or list-of-dicts form. + """Batch-validate mapped SMILES in tabular or list-of-dicts form. :param data: A pandas DataFrame or list of dicts, each row containing at least `ground_truth_col` and each entry in `mapped_cols`. diff --git a/synkit/Chem/Reaction/balance_check.py b/synkit/Chem/Reaction/balance_check.py index 76e0f1c..f21754b 100644 --- a/synkit/Chem/Reaction/balance_check.py +++ b/synkit/Chem/Reaction/balance_check.py @@ -1,48 +1,38 @@ from rdkit import Chem from rdkit.Chem.rdMolDescriptors import CalcMolFormula - from joblib import Parallel, delayed -from typing import List, Dict, Union, Tuple +from typing import List, Dict, Union, Tuple, Any class BalanceReactionCheck: - """ - A class to check the balance of chemical reactions given in SMILES format. - It supports parallel execution and maintains the input format in the output. + """Check elemental balance of chemical reactions in SMILES format. + + Supports checking single reactions, reaction dictionaries, or lists + in parallel. + + :ivar n_jobs: Number of parallel jobs for batch checking. + :ivar verbose: Verbosity level for joblib. """ - def __init__( - self, - n_jobs: int = 4, - verbose: int = 0, - ): + def __init__(self, n_jobs: int = 4, verbose: int = 0) -> None: """ - Initializes the class with given input data, the column name - for reactions in the input, number of jobs for - parallel processing, and verbosity level. - - Parameters: - - input_data (Union[str, List[Union[str, Dict[str, str]]]]): A single SMILES - string, a list of SMILES strings, or a list of dictionaries with 'reactions' keys. - - rsmi_column (str): The key/column name for reaction SMILES strings - in the input data. - - n_jobs (int): The number of parallel jobs to run for balance checking. 
- - verbose (int): The verbosity level of joblib parallel execution. + :param n_jobs: Number of parallel jobs for batch balance checks. Defaults to 4. + :type n_jobs: int + :param verbose: Verbosity level passed to joblib. Defaults to 0. + :type verbose: int """ - self.n_jobs = n_jobs self.verbose = verbose @staticmethod def get_combined_molecular_formula(smiles: str) -> str: - """ - Computes the molecular formula for a molecule represented by a SMILES string. + """Compute the molecular formula of a SMILES. - Parameters: - - smiles (str): The SMILES string of the molecule. - - Returns: - - str: The molecular formula, or an empty string if the molecule is invalid. + :param smiles: SMILES string of the molecule. + :type smiles: str + :returns: Elemental formula (e.g., "C6H6") or empty string if + invalid. + :rtype: str """ mol = Chem.MolFromSmiles(smiles) if not mol: @@ -54,109 +44,95 @@ def parse_input( input_data: Union[str, List[Union[str, Dict[str, str]]]], rsmi_column: str = "reactions", ) -> List[Dict[str, str]]: + """Normalize input into a list of reaction‐dicts. + + :param input_data: A single SMILES, list of SMILES, or list of dicts containing `rsmi_column`. + :type input_data: str or List[Union[str, Dict[str, str]]] + :param rsmi_column: Key in dicts for the reaction SMILES. Defaults to "reactions". + :type rsmi_column: str + :returns: List of dicts with a single key `rsmi_column` mapping to each reaction. + :rtype: List[Dict[str, str]] + :raises ValueError: If `input_data` is neither str nor list. """ - Parses the input data into a standardized list containing - dictionaries for each reaction. - - Parameters: - - input_data (Union[str, List[Union[str, Dict[str, str]]]]): - The input data to be processed. - - Returns: - - List[Dict[str, str]]: A list of dictionaries with reaction SMILES strings. 
- """ - standardized_input = [] + standardized: List[Dict[str, str]] = [] if isinstance(input_data, str): - standardized_input.append({rsmi_column: input_data}) + standardized.append({rsmi_column: input_data}) elif isinstance(input_data, list): for item in input_data: if isinstance(item, str): - standardized_input.append({rsmi_column: item}) + standardized.append({rsmi_column: item}) elif isinstance(item, dict) and rsmi_column in item: - standardized_input.append(item) + standardized.append(item) else: - raise ValueError("Unsupported input type") - return standardized_input + raise ValueError("Unsupported input type for balance checking") + return standardized @staticmethod - def parse_reaction(reaction_smiles: str) -> Tuple[List[str], List[str]]: - """ - Splits a reaction SMILES string into reactants and products. - - Parameters: - - reaction_smiles (str): A SMILES string representing a chemical reaction. - - Returns: - - Tuple[List[str], List[str]]: Lists of SMILES strings for reactants and products. + def parse_reaction(reaction_smiles: str) -> Tuple[str, str]: + """Split a reaction SMILES into reactant and product SMILES strings. + + :param reaction_smiles: Reaction SMILES in 'reactants>>products' + format. + :type reaction_smiles: str + :returns: Tuple of (reactants, products) SMILES. + :rtype: Tuple[str, str] """ - reactants_smiles, products_smiles = reaction_smiles.split(">>") - return reactants_smiles, products_smiles + return tuple(reaction_smiles.split(">>")) @staticmethod def rsmi_balance_check(reaction_smiles: str) -> bool: + """Determine if a reaction SMILES is elementally balanced. + + :param reaction_smiles: Reaction SMILES in 'reactants>>products' + format. + :type reaction_smiles: str + :returns: True if reactant and product formulas match, else + False. + :rtype: bool """ - Checks if a reaction SMILES string is balanced. - - Parameters: - - reaction_smiles (str): A SMILES string representing a chemical reaction. 
- - Returns: - - bool: True if the reaction is balanced, False otherwise. - """ - reactants_smiles, products_smiles = BalanceReactionCheck.parse_reaction( - reaction_smiles - ) - reactants_forumula = BalanceReactionCheck.get_combined_molecular_formula( - reactants_smiles - ) - products_forumula = BalanceReactionCheck.get_combined_molecular_formula( - products_smiles - ) - return reactants_forumula == products_forumula + react, prod = BalanceReactionCheck.parse_reaction(reaction_smiles) + react_formula = BalanceReactionCheck.get_combined_molecular_formula(react) + prod_formula = BalanceReactionCheck.get_combined_molecular_formula(prod) + return react_formula == prod_formula @staticmethod def dict_balance_check( reaction_dict: Dict[str, str], rsmi_column: str - ) -> Dict[str, Union[bool, str]]: - """ - Checks if a single reaction (in SMILES format) is balanced, maintaining - the input format. - - Parameters: - - reaction_dict (Dict[str, str]): A dictionary containing the - reaction SMILES string. - - Returns: - - Dict[str, Union[bool, str]]: A dictionary indicating if the reaction is - balanced, along with the original reaction data. + ) -> Dict[str, Any]: + """Check balance for a single reaction dict, preserving original keys. + + :param reaction_dict: Dict containing at least a `rsmi_column` key. + :type reaction_dict: Dict[str, str] + :param rsmi_column: Key for reaction SMILES in `reaction_dict`. + :type rsmi_column: str + :returns: Original dict augmented with `"balanced": bool`. 
+ :rtype: Dict[str, Any] """ - reaction_smiles = reaction_dict[rsmi_column] - balance = BalanceReactionCheck.rsmi_balance_check(reaction_smiles) - return {"balanced": balance, **reaction_dict} + rsmi = reaction_dict[rsmi_column] + balanced = BalanceReactionCheck.rsmi_balance_check(rsmi) + return {"balanced": balanced, **reaction_dict} def dicts_balance_check( self, input_data: Union[str, List[Union[str, Dict[str, str]]]], rsmi_column: str = "reactions", - ) -> Tuple[List[Dict[str, Union[bool, str]]], List[Dict[str, Union[bool, str]]]]: - """ - Checks the balance of all reactions in the input data. - - Returns: - - Tuple[List[Dict[str, Union[bool, str]]], List[Dict[str, Union[bool, str]]]]: - Two lists containing dictionaries of balanced and unbalanced reactions, - respectively. + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """Batch‐check balance for multiple reactions, in parallel. + + :param input_data: Single reaction SMILES, list of SMILES, or + list of dicts. + :type input_data: Union[str, List[Union[str, Dict[str, str]]]] + :param rsmi_column: Key for reaction SMILES in each dict. + Defaults to "reactions". + :type rsmi_column: str + :returns: Tuple (balanced_list, unbalanced_list) of dicts each + including `"balanced"`. 
+ :rtype: Tuple[List[Dict[str, Any]], List[Dict[str, Any]]] """ - reactions = self.parse_input(input_data, rsmi_column) results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)( - delayed(self.dict_balance_check)(reaction, rsmi_column) - for reaction in reactions + delayed(self.dict_balance_check)(rd, rsmi_column) for rd in reactions ) - - balanced_reactions = [reaction for reaction in results if reaction["balanced"]] - unbalanced_reactions = [ - reaction for reaction in results if not reaction["balanced"] - ] - - return balanced_reactions, unbalanced_reactions + balanced = [r for r in results if r["balanced"]] + unbalanced = [r for r in results if not r["balanced"]] + return balanced, unbalanced diff --git a/synkit/Chem/Reaction/canon_rsmi.py b/synkit/Chem/Reaction/canon_rsmi.py index 27a9417..0abb0c8 100644 --- a/synkit/Chem/Reaction/canon_rsmi.py +++ b/synkit/Chem/Reaction/canon_rsmi.py @@ -9,9 +9,9 @@ class CanonRSMI: - """ - A **pure-Python / pure-NetworkX** utility for canonicalizing reaction SMILES - by expanding atom-maps and deterministically reindexing reaction graphs. + """A **pure-Python / pure-NetworkX** utility for canonicalizing reaction + SMILES by expanding atom-maps and deterministically reindexing reaction + graphs. Workflow -------- @@ -61,16 +61,15 @@ def __init__( @staticmethod def _mol_from_smiles(smi: str) -> Chem.Mol: - """ - RDKit MolFromSmiles with explicit sanitize step. - """ + """RDKit MolFromSmiles with explicit sanitize step.""" mol = _rdkit_MolFromSmiles(smi, sanitize=False) Chem.SanitizeMol(mol) return mol def expand_aam(self, rsmi: str) -> str: - """ - Assign new atom-map IDs to unmapped reactant atoms in 'reactants>>products' SMILES. + """Assign new atom-map IDs to unmapped reactant atoms in + 'reactants>>products' SMILES. + New IDs start at max(existing maps)+1. 
""" try: @@ -110,9 +109,7 @@ def sync_atom_map_with_index(G: nx.Graph) -> None: def get_aam_pairwise_indices( G: nx.Graph, H: nx.Graph, aam_key: str = "atom_map" ) -> List[Tuple[int, int]]: - """ - Return sorted list of (G_node, H_node) for shared atom-map IDs. - """ + """Return sorted list of (G_node, H_node) for shared atom-map IDs.""" gmap = { data[aam_key]: n for n, data in G.nodes(data=True) @@ -223,7 +220,8 @@ def canonical_product_graph(self) -> Optional[nx.Graph]: @property def canonical_hash(self) -> Optional[str]: - """Reaction-level hash combining reactant and product canonical hashes.""" + """Reaction-level hash combining reactant and product canonical + hashes.""" if not self._canon_reactant_graph or not self._canon_product_graph: return None h_reac = self._canon.canonical_signature(self._canon_reactant_graph) @@ -236,9 +234,7 @@ def mapping_pairs(self) -> Optional[List[Tuple[int, int]]]: return self._mapping_pairs def help(self) -> None: # pragma: no cover - """ - Pretty-print the class doc and public methods with signatures. - """ + """Pretty-print the class doc and public methods with signatures.""" print(inspect.getdoc(self.__class__)) for meth in ( "expand_aam", diff --git a/synkit/Chem/Reaction/cleaning.py b/synkit/Chem/Reaction/cleaning.py new file mode 100644 index 0000000..6bdd370 --- /dev/null +++ b/synkit/Chem/Reaction/cleaning.py @@ -0,0 +1,66 @@ +from typing import List +from synkit.Chem.Reaction.standardize import Standardize +from synkit.Chem.Reaction.balance_check import BalanceReactionCheck + + +class Cleaning: + """Utilities for cleaning and filtering reaction SMILES lists. + + Methods + ------- + remove_duplicates(smiles_list) + Remove duplicate SMILES while preserving input order. + clean_smiles(smiles_list) + Standardize, balance‑check, and deduplicate a list of reaction SMILES. + """ + + def __init__(self) -> None: + """Initialize the Cleaning helper. + + No instance attributes are used. 
+ """ + pass + + @staticmethod + def remove_duplicates(smiles_list: List[str]) -> List[str]: + """Remove duplicate SMILES strings, preserving first occurrences. + + :param smiles_list: List of reaction SMILES strings. + :type smiles_list: List[str] + :returns: List of unique SMILES in original order. + :rtype: List[str] + """ + seen = set() + return [smi for smi in smiles_list if not (smi in seen or seen.add(smi))] + + @staticmethod + def clean_smiles(smiles_list: List[str]) -> List[str]: + """Standardize, balance‑check, and deduplicate reaction SMILES. + + Steps: + 1. Standardize each SMILES via `Standardize.standardize_rsmi`. + 2. Keep only those that pass `BalanceReactionCheck.rsmi_balance_check`. + 3. Remove duplicates preserving order. + + :param smiles_list: List of reaction SMILES strings to clean. + :type smiles_list: List[str] + :returns: Cleaned list of standardized, balanced, unique SMILES. + :rtype: List[str] + """ + standardizer = Standardize() + balance_checker = BalanceReactionCheck() + + standardized: List[str] = [] + for smi in smiles_list: + try: + std = standardizer.standardize_rsmi(smi, stereo=True) + if std: + standardized.append(std) + except Exception: + continue + + balanced = [ + smi for smi in standardized if balance_checker.rsmi_balance_check(smi) + ] + + return Cleaning.remove_duplicates(balanced) diff --git a/synkit/Chem/Reaction/cleanning.py b/synkit/Chem/Reaction/cleanning.py deleted file mode 100644 index 1f4b6bb..0000000 --- a/synkit/Chem/Reaction/cleanning.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import List -from synkit.Chem.Reaction.standardize import Standardize -from synkit.Chem.Reaction.balance_check import BalanceReactionCheck - - -class Cleanning: - def __init__(self) -> None: - pass - - @staticmethod - def remove_duplicates(smiles_list: List[str]) -> List[str]: - """ - Removes duplicate SMILES strings from a list, maintaining the order of - first occurrences. Uses a set to track seen SMILES for efficiency. 
- - Parameters: - - smiles_list (List[str]): A list of SMILES strings representing - chemical reactions. - - Returns: - - List[str]: A list with unique SMILES strings, preserving the original order. - """ - seen = set() - unique_smiles = [ - smiles for smiles in smiles_list if not (smiles in seen or seen.add(smiles)) - ] - return unique_smiles - - @staticmethod - def clean_smiles(smiles_list: List[str]) -> List[str]: - """ - Cleans a list of SMILES strings by standardizing them, checking their chemical - balance, and removing duplicates. Each SMILES is first checked for validity and - then standardized. Only balanced reactions are kept. - - Parameters: - - smiles_list (List[str]): A list of SMILES strings representing chemical reactions. - - Returns: - - List[str]: A list of cleaned and standardized SMILES strings. - """ - # Standardize and check balance in separate list comprehensions - standardizer = Standardize() - balance_checker = BalanceReactionCheck() - - standardized_smiles = [] - for smiles in smiles_list: - try: - r = standardizer.standardize_rsmi(smiles, True) - standardized_smiles.append(r) - except Exception as e: - print(e) - pass - # standardized_smiles = [ - # standardizer.standardize_rsmi(smiles, True) - # for smiles in smiles_list - # if smiles - # ] - balanced_smiles = [ - smiles - for smiles in standardized_smiles - if balance_checker.rsmi_balance_check(smiles) - ] - - # Remove duplicates from the balanced SMILES list - clean_smiles = Cleanning.remove_duplicates(balanced_smiles) - return clean_smiles diff --git a/synkit/Chem/Reaction/deionize.py b/synkit/Chem/Reaction/deionize.py index 02cb3dd..98c02fc 100644 --- a/synkit/Chem/Reaction/deionize.py +++ b/synkit/Chem/Reaction/deionize.py @@ -1,8 +1,7 @@ import random -from itertools import permutations -from itertools import combinations +from itertools import combinations, permutations from joblib import Parallel, delayed -from typing import List, Tuple, Callable, Dict +from typing import List, 
Tuple, Callable, Dict, Any from rdkit import Chem from rdkit.Chem.MolStandardize import rdMolStandardize @@ -11,279 +10,190 @@ class Deionize: - """ - A class to deionize reactions. + """Neutralize ionic species and mixtures of ions in reactions. + + Provides methods to group ions into neutral combinations, uncharge + individual anions/cations, and apply these corrections to SMILES + strings or entire reaction dictionaries. """ @staticmethod def random_pair_ions( charges: List[int], smiles: List[str] ) -> Tuple[List[List[str]], List[List[int]]]: - """ - Generates non-overlapping groups of ions (2, 3, or 4) based on - their charges and corresponding SMILES representations, - aiming to maximize the total number of ions used by preferring - multiple smaller groups over fewer larger groups. - - Parameters: - - charges (List[int]): A list of integer charges of the ions. - - smiles (List[str]): A list of SMILES strings representing the ions. - - Returns: - - Tuple[List[List[str]], List[List[int]]]: A tuple containing two lists: - - The first list contains the groups of SMILES strings. - - The second list contains the groups of charges. + """Identify non‑overlapping groups of ions whose charges sum to zero. + + :param charges: List of integer formal charges for each ion. + :type charges: List[int] + :param smiles: Corresponding SMILES strings for each ion. + :type smiles: List[str] + :returns: A tuple of two lists: + - groups of SMILES strings forming neutral sets, + - groups of their corresponding charges. 
+ :rtype: Tuple[List[List[str]], List[List[int]]] """ - def find_groups(indices, size): - """Finds and removes groups of a specific size that sum to zero charge.""" + def find_groups(indices: List[int], size: int) -> Tuple[int, ...]: for group in combinations(indices, size): if sum(charges[i] for i in group) == 0: return group - return [] + return () - # Prepare initial variables indices = list(range(len(charges))) - random.shuffle(indices) # Shuffle indices to ensure variety - used_indices = set() - grouped_smiles = [] - grouped_charges = [] + random.shuffle(indices) + used = set() + grouped_smiles: List[List[str]] = [] + grouped_charges: List[List[int]] = [] - for group_size in range( - 2, 5 - ): # Start with pairs, then triples, and finally quads + for group_size in (2, 3, 4): while True: - group = find_groups( - [i for i in indices if i not in used_indices], group_size - ) + available = [i for i in indices if i not in used] + group = find_groups(available, group_size) if not group: - break # No more groups of this size can be formed + break grouped_smiles.append([smiles[i] for i in group]) grouped_charges.append([charges[i] for i in group]) - used_indices.update(group) + used.update(group) return grouped_smiles, grouped_charges @staticmethod def uncharge_anion(smiles: str, charges: int = -1) -> str: - """ - Removes charge from an anionic species represented by a SMILES string. - - This function uses RDKit's standardization tools to neutralize - the charges in the molecule. It returns - the SMILES representation of the uncharged molecule. - - Parameters:: - - smiles (str): A SMILES string representing the anionic species. - - Returns: - - str: The SMILES string of the uncharged molecule. - - Note: - - The function assumes valid SMILES input. + """Neutralize an anionic SMILES string. + + :param smiles: SMILES of the anion to neutralize. + :type smiles: str + :param charges: Formal charge of the ion (negative integer). + Defaults to -1. 
+ :type charges: int + :returns: SMILES of the uncharged molecule. + :rtype: str """ if smiles == "[N-]=[N+]=[N-]": return "[N-]=[N+]=[N]" if charges == -1: - # Convert the SMILES string to an RDKit molecule object mol = Chem.MolFromSmiles(smiles) - - # Initialize the uncharger uncharger = rdMolStandardize.Uncharger() - - # Apply the uncharger to the molecule - uncharged_mol = uncharger.uncharge(mol) - - # Convert the uncharged molecule back to a SMILES string - return Chem.MolToSmiles(uncharged_mol) - - elif charges < -1: - new_smiles = ( - smiles.replace(f"{charges}", "").replace("[", "").replace("]", "") - ) - return new_smiles + uncharged = uncharger.uncharge(mol) + return Chem.MolToSmiles(uncharged) + # for multi‐charged anions + return smiles.replace(f"{charges}", "").replace("[", "").replace("]", "") @staticmethod def uncharge_cation(smiles: str, charges: int = 1) -> str: + """Neutralize a cationic SMILES string. + + :param smiles: SMILES of the cation to neutralize. + :type smiles: str + :param charges: Formal charge of the ion (positive integer). + Defaults to 1. + :type charges: int + :returns: SMILES of the uncharged molecule. + :rtype: str """ - Removes charge from a cationic species represented by a SMILES string. - - This function uses RDKit's standardization tools to neutralize - the charges in the molecule. It returns the - SMILES representation of the uncharged molecule. - - Parameters:: - - smiles (str): A SMILES string representing the cationic species. - - Returns: - - str: The SMILES string of the uncharged molecule. - - Note: - - The function assumes valid SMILES input. 
- """ - if charges == 1: - new_smiles = smiles.replace("+", "") - elif charges > 1: - # For multiple positive charges, directly modify the SMILES string - new_smiles = smiles.replace(f"+{charges}", "") - return new_smiles + return smiles.replace("+", "") + return smiles.replace(f"+{charges}", "") @staticmethod def uncharge_smiles(charge_smiles: str) -> str: - """ - Processes a SMILES string containing ionic and non-ionic parts, - neutralizes the charges, and returns a modified SMILES string. - - The function splits the input SMILES string into individual components, - identifies ionic and non-ionic parts, - and attempts to neutralize charged ions. - It then creates permutations of the modified ions and combines them into - a single SMILES string, ensuring the molecular structure is valid. + """Neutralize all ionic components in a dot‑separated SMILES string. - Parameters:: - - charge_smiles (str): A SMILES string that may contain ionic and non-ionic parts. + Splits into components, identifies ionic species, groups + them into neutral sets via `random_pair_ions`, then + applies `uncharge_anion` or `uncharge_cation` and recombines. - Returns: - - str: A modified SMILES string with neutralized charges. - - Note: - - This function depends on RDKit for molecular operations. - - The function assumes a valid SMILES input. - - The 'uncharge_anion' and 'random_pair_ions' functions - must be defined and accessible. + :param charge_smiles: SMILES string with ionic and non‑ionic parts. + :type charge_smiles: str + :returns: SMILES string with charges neutralized. 
+ :rtype: str """ - - smiles = charge_smiles.split(".") - charges = [Chem.rdmolops.GetFormalCharge(Chem.MolFromSmiles(i)) for i in smiles] - - if all(charge == 0 for charge in charges): + parts = charge_smiles.split(".") + charges = [Chem.rdmolops.GetFormalCharge(Chem.MolFromSmiles(p)) for p in parts] + if all(c == 0 for c in charges): return charge_smiles - valid_smiles, non_ionic_smiles = [], [] - original_ionic_parts, original_ion_charges = [], [] - - # Splitting the SMILES into ionic and non-ionic parts - for smile, charge in zip(smiles, charges): - if charge == 0: - non_ionic_smiles.append(smile) + non_ionic, ionic_parts, ionic_charges = [], [], [] + for p, c in zip(parts, charges): + if c == 0: + non_ionic.append(p) else: - original_ionic_parts.append(smile) - original_ion_charges.append(charge) - - valid_smiles.extend(non_ionic_smiles) - paired_smiles, paired_charges = Deionize.random_pair_ions( - original_ion_charges, original_ionic_parts - ) - # Processing each pair of ionic parts - for i_smile, i_charge in zip(paired_smiles, paired_charges): - modified_ions = [] - for ion, charge in zip(i_smile, i_charge): - if int(charge) > 0: - new_ion = Deionize.uncharge_cation(ion, charge) - modified_ions.append(new_ion) - elif int(charge) < 0: - new_ion = Deionize.uncharge_anion(ion, charge) - modified_ions.append(new_ion) - # Creating permutations of the modified ions - check_merge = False - for perm in permutations(modified_ions): - combined_ionic = "".join(perm) - if Chem.MolFromSmiles(combined_ionic): - coordinate_pattern = ["->", "<-"] - if all( - pattern not in Chem.CanonSmiles(combined_ionic) - for pattern in coordinate_pattern - ): - valid_smiles.append(Chem.CanonSmiles(combined_ionic)) - check_merge = True - break - if check_merge is False: - valid_smiles.extend(i_smile) - return ".".join(valid_smiles) + ionic_parts.append(p) + ionic_charges.append(c) + + valid = non_ionic.copy() + groups, group_chs = Deionize.random_pair_ions(ionic_charges, ionic_parts) + 
for smiles_group, charge_group in zip(groups, group_chs): + candidates = [] + for smi, ch in zip(smiles_group, charge_group): + if ch > 0: + candidates.append(Deionize.uncharge_cation(smi, ch)) + else: + candidates.append(Deionize.uncharge_anion(smi, ch)) + # try permutations for valid SMILES + for perm in permutations(candidates): + combo = "".join(perm) + if Chem.MolFromSmiles(combo): + valid.append(Chem.CanonSmiles(combo)) + break + else: + valid.extend(smiles_group) + return ".".join(valid) @staticmethod def ammonia_hydroxide_standardize(reaction_smiles: str) -> str: - """ - Replaces occurrences of ammonium hydroxide (NH4+ and OH-) in a - reaction SMILES string with a simplified representation (N.O or O.N). - - Parameters:: - reaction_smiles (str): The reaction SMILES string to be standardized. + """Simplify ammonium hydroxide pairs in a reaction SMILES. - Returns: - str: The standardized reaction SMILES string with - ammonium hydroxide represented as 'N.O' or 'O.N'. + :param reaction_smiles: Reaction SMILES string. + :type reaction_smiles: str + :returns: Reaction SMILES with '[NH4+].[OH-]' replaced by 'N.O' + or 'O.N'. + :rtype: str """ - # Simplify the representation of ammonium hydroxide in the reaction SMILES - new_smiles = reaction_smiles.replace("[NH4+].[OH-]", "N.O").replace( + return reaction_smiles.replace("[NH4+].[OH-]", "N.O").replace( "[OH-].[NH4+]", "O.N" ) - return new_smiles @classmethod def apply_uncharge_smiles_to_reactions( cls, - reactions: List[Dict[str, str]], + reactions: List[Dict[str, Any]], uncharge_smiles_func: Callable[[str], str], n_jobs: int = 4, - ) -> List[Dict[str, str]]: - """ - Applies a given uncharge SMILES function to the reactants - and products of a list of chemical reactions, - parallelizing the process for improved performance. - Each reaction is expected to be a dictionary - with at least 'reactants' and 'products' keys. 
- The function adds three new keys to each reaction - dictionary: 'uncharged_reactants', 'uncharged_products', - and 'uncharged_reactions', containing - the uncharged SMILES strings of reactants, products, - and the overall reaction, respectively. - - Parameters:: - - reactions (List[Dict[str, str]]): A list of dictionaries, where each dictionary - represents a chemical reaction with 'reactants' and 'products' keys. - - uncharge_smiles_func (Callable[[str], str]): A function that takes a SMILES - string as input and returns a modified SMILES string with neutralized charges. - - Returns: - - List[Dict[str, str]]: The input list of reaction dictionaries, modified in-place - to include 'uncharged_reactants', 'uncharged_products', and 'uncharged_reactions' - keys. + ) -> List[Dict[str, Any]]: + """Apply a neutralization function to each reaction’s + reactants/products in parallel. + + Adds keys 'new_reactants', 'new_products', and 'standardized_reactions' + based on uncharged SMILES and verifies formula balance. + + :param reactions: List of reaction dicts with 'reactants' and 'products' keys. + :type reactions: List[Dict[str, Any]] + :param uncharge_smiles_func: Function to neutralize a SMILES string. + :type uncharge_smiles_func: Callable[[str], str] + :param n_jobs: Number of parallel jobs to run. Defaults to 4. 
+ :type n_jobs: int + :returns: List of updated reaction dicts with: + - 'success': bool indicating formula match + - 'new_reactants' / 'new_products' + - 'standardized_reactions' + :rtype: List[Dict[str, Any]] """ - # Define a helper function for processing a single reaction - def process_reaction(reaction): - fix_reactants = cls.ammonia_hydroxide_standardize(reaction["reactants"]) - fix_products = cls.ammonia_hydroxide_standardize(reaction["products"]) - - uncharged_reactants = uncharge_smiles_func(fix_reactants) - uncharged_products = uncharge_smiles_func(fix_products) - uncharged_reactants_formula = ( - BalanceReactionCheck().get_combined_molecular_formula( - uncharged_reactants - ) - ) - uncharged_products_formula = ( - BalanceReactionCheck().get_combined_molecular_formula( - uncharged_products - ) - ) - if uncharged_reactants_formula != uncharged_products_formula: - reaction["success"] = False - reaction["new_reactants"] = fix_reactants - reaction["new_products"] = fix_products - else: - reaction["success"] = True - reaction["new_reactants"] = uncharged_reactants - reaction["new_products"] = uncharged_products + def process(reaction: Dict[str, Any]) -> Dict[str, Any]: + # pre‐standardize ammonia hydroxide + r_fix = cls.ammonia_hydroxide_standardize(reaction["reactants"]) + p_fix = cls.ammonia_hydroxide_standardize(reaction["products"]) + ur = uncharge_smiles_func(r_fix) + up = uncharge_smiles_func(p_fix) + r_formula = BalanceReactionCheck().get_combined_molecular_formula(ur) + p_formula = BalanceReactionCheck().get_combined_molecular_formula(up) + reaction["success"] = r_formula == p_formula + reaction["new_reactants"] = ur if reaction["success"] else r_fix + reaction["new_products"] = up if reaction["success"] else p_fix reaction["standardized_reactions"] = ( f"{reaction['new_reactants']}>>{reaction['new_products']}" ) return reaction - # Use joblib to parallelize the processing of reactions - reactions = Parallel(n_jobs=n_jobs)( - 
delayed(process_reaction)(reaction) for reaction in reactions - ) - return reactions + return Parallel(n_jobs=n_jobs)(delayed(process)(rxn) for rxn in reactions) diff --git a/synkit/Chem/Reaction/fix_aam.py b/synkit/Chem/Reaction/fix_aam.py index 995eab7..8f2b2eb 100644 --- a/synkit/Chem/Reaction/fix_aam.py +++ b/synkit/Chem/Reaction/fix_aam.py @@ -3,89 +3,62 @@ class FixAAM: - """ - A class containing methods for manipulating atom mapping numbers (AAM) in molecules. - It includes functionality for incrementing atom map numbers in a molecule, adjusting - atom mappings in SMILES strings, and fixing atom mappings in reaction SMILES (RSMI) strings. - - Methods: - increment_atom_mapping(mol: Chem.Mol) -> Chem.Mol: - Increments the atom map number for each atom in the molecule by 1. - - fix_aam_smiles(smiles: str) -> str: - Takes a SMILES string, increments all atom mapping numbers by 1, and returns the updated SMILES. + """Utilities for incrementing and correcting atom‐atom mapping (AAM) + numbers in molecules and reaction SMILES. - fix_aam_rsmi(rsmi: str) -> str: - Adjusts atom mapping numbers in both reactant and product parts of a reaction SMILES (RSMI). + Provides methods to: + - Increment AAM on all atoms of an RDKit Mol. + - Adjust AAM numbers in a standalone SMILES string. + - Apply the same adjustment to both sides of a reaction SMILES (RSMI). """ @staticmethod def increment_atom_mapping(mol: Chem.Mol) -> Chem.Mol: - """ - Increments the atom mapping number of each atom in the molecule by 1. - - This method iterates through each atom in the molecule and increments its atom map - number (if it has one). + """Increment the atom‐map number of each atom in an RDKit Mol by 1. - Parameters: - mol (Chem.Mol): The RDKit molecule object that represents the molecule with atom mapping. - - Returns: - Chem.Mol: The updated RDKit molecule with incremented atom mapping numbers for all atoms. + :param mol: RDKit molecule with existing atom‐map annotations. 
+ :type mol: Chem.Mol + :returns: The same Mol object with each atom’s map number + increased by one. + :rtype: Chem.Mol """ - # Iterate through all atoms in the molecule for atom in mol.GetAtoms(): - # Get the current atom map (if it exists) - atom_map = atom.GetAtomMapNum() - - atom.SetAtomMapNum(atom_map + 1) + atom.SetAtomMapNum(atom.GetAtomMapNum() + 1) return mol @staticmethod def fix_aam_smiles(smiles: str) -> str: + """Parse a SMILES string, increment all atom map numbers, and return + updated SMILES. + + :param smiles: SMILES string containing atom‐map annotations. + :type smiles: str + :returns: SMILES string with every atom‐map number increased by + one. + :rtype: str + :raises ValueError: If the input SMILES cannot be parsed into an + RDKit Mol. """ - Takes a SMILES string, increments all atom mapping numbers by 1, and returns the updated SMILES. - - This method converts the SMILES string into an RDKit molecule, increments the atom - mapping numbers, and returns the updated SMILES string. - - Parameters: - smiles (str): A SMILES string containing atom mapping numbers. - - Returns: - str: A new SMILES string with incremented atom mapping numbers for all atoms. - - Raises: - ValueError: If the input SMILES string is invalid and cannot be parsed into a molecule. - """ - # Create the molecule from the SMILES string mol: Optional[Chem.Mol] = Chem.MolFromSmiles(smiles, sanitize=False) if mol is None: - raise ValueError("Invalid SMILES string.") + raise ValueError(f"Invalid SMILES string: {smiles!r}") Chem.SanitizeMol(mol) - # Increment atom mapping numbers - updated_mol = FixAAM.increment_atom_mapping(mol) - - # Return the SMILES string with updated atom mappings - return Chem.MolToSmiles(updated_mol) + FixAAM.increment_atom_mapping(mol) + return Chem.MolToSmiles(mol) @staticmethod def fix_aam_rsmi(rsmi: str) -> str: + """Apply atom‐map increment to both reactant and product sides of a + reaction SMILES. 
+ + :param rsmi: Reaction SMILES in 'reactants>>products' format + with atom‐map tags. + :type rsmi: str + :returns: New reaction SMILES string where each atom‐map number + in both halves is increased by one. + :rtype: str """ - Adjusts atom mapping numbers in both reactant and product parts of a reaction SMILES (RSMI). - - This method splits the reaction SMILES (RSMI) into its reactant and product components, - increments the atom mappings for both parts, and returns the updated reaction SMILES string. - - Parameters: - rsmi (str): A reaction SMILES string with atom mapping numbers. - - Returns: - str: A new reaction SMILES string with incremented atom mapping numbers in both reactant - and product parts. - """ - # Split the reaction SMILES into reactants and products - r, p = rsmi.split(">>") - - # Update both reactant and product SMILES strings - return f"{FixAAM.fix_aam_smiles(r)}>>{FixAAM.fix_aam_smiles(p)}" + react, prod = rsmi.split(">>") + new_react = FixAAM.fix_aam_smiles(react) + new_prod = FixAAM.fix_aam_smiles(prod) + return f"{new_react}>>{new_prod}" diff --git a/synkit/Chem/Reaction/neutralize.py b/synkit/Chem/Reaction/neutralize.py index eed307b..9745fdf 100644 --- a/synkit/Chem/Reaction/neutralize.py +++ b/synkit/Chem/Reaction/neutralize.py @@ -1,23 +1,25 @@ from rdkit import Chem from joblib import Parallel, delayed -from typing import Dict, Any, List, Union, Tuple +from typing import Dict, Any, List, Union, Tuple, Optional class Neutralize: - """ - A class for neutralizing unbalanced charges in a reaction. + """Neutralize unbalanced charges in chemical reactions by adding + counter‑ions. + + Provides utilities to calculate formal charges, parse reaction + SMILES, and adjust reactants/products with [Na+] or [Cl‑] to restore + neutrality. """ @staticmethod def calculate_charge(smiles: str) -> int: - """ - Calculates the formal charge of a given molecule represented by a SMILES string. + """Calculate the formal charge of a molecule. 
- Parameters: - - smiles (str): A SMILES string representing a molecule. - - Returns: - - int: The formal charge of the molecule. + :param smiles: SMILES string of the molecule. + :type smiles: str + :returns: Formal charge of the molecule (0 if invalid SMILES). + :rtype: int """ mol = Chem.MolFromSmiles(smiles) if mol is None: @@ -25,20 +27,15 @@ def calculate_charge(smiles: str) -> int: return Chem.rdmolops.GetFormalCharge(mol) @staticmethod - def parse_reaction(reaction_smiles: str) -> Tuple[str, str]: - """ - Parses a reaction SMILES string into reactants and products. - - Parameters: - - reaction_smiles (str): A reaction SMILES string of the form - "reactants>>products". - - Returns: - - Tuple[str, str]: A tuple containing the reactants and - products SMILES strings, respectively. - - This function uses a while loop and exception handling to - manage parsing errors and ensure the input is correctly formatted. + def parse_reaction(reaction_smiles: str) -> Tuple[Optional[str], Optional[str]]: + """Split a reaction SMILES into reactants and products. + + :param reaction_smiles: Reaction SMILES in 'reactants>>products' + format. + :type reaction_smiles: str + :returns: Tuple of (reactants, products) SMILES, or (None, None) + if parse fails. + :rtype: Tuple[Optional[str], Optional[str]] """ try: reactants, products = reaction_smiles.split(">>") @@ -48,20 +45,24 @@ def parse_reaction(reaction_smiles: str) -> Tuple[str, str]: @staticmethod def calculate_charge_dict( - reaction: Dict[str, str], reaction_column: str + reaction: Dict[str, Any], reaction_column: str ) -> Dict[str, Union[str, int]]: + """Compute and store the total formal charge of the products in a + reaction dict. + + :param reaction: Dictionary containing at least `reaction_column` with a reaction SMILES. + :type reaction: Dict[str, Any] + :param reaction_column: Key under which the reaction SMILES is stored. 
+ :type reaction_column: str + :returns: The same dictionary updated with: + - 'reactants': reactant SMILES or None + - 'products': product SMILES or None + - 'total_charge_in_products': integer sum of product charges or None + :rtype: Dict[str, Union[str, int]] """ - Calculates and adds the total charge of products in a single reaction. - - Parameters: - - reaction (Dict[str, str]): A dictionary representing a reaction with keys - 'R-id' and 'new_reaction'. - - Returns: - - Dict[str, Union[str, int]]: The same reaction dictionary, with an added key - 'total_charge_in_products' indicating the sum of formal charges in its products. - """ - reactants, products = Neutralize.parse_reaction(reaction[reaction_column]) + reactants, products = Neutralize.parse_reaction( + reaction.get(reaction_column, "") + ) if reactants is None or products is None: reaction.update( {"reactants": None, "products": None, "total_charge_in_products": None} @@ -69,188 +70,126 @@ def calculate_charge_dict( else: reaction["reactants"] = reactants reaction["products"] = products - products = products.split(".") - total_charge = sum( - Neutralize.calculate_charge(product) for product in products - ) - reaction["total_charge_in_products"] = total_charge + total = sum(Neutralize.calculate_charge(p) for p in products.split(".")) + reaction["total_charge_in_products"] = total return reaction @staticmethod def fix_negative_charge( - reaction_dict: Dict[str, any], + reaction_dict: Dict[str, Any], charges_column: str = "total_charge_in_products", id_column: str = "R-id", reaction_column: str = "reactions", - ) -> Dict[str, any]: - """ - Adjusts a reaction dictionary to compensate for a negative charge - in the products by adding [Na+] ions. - - This function calculates the number of sodium ions ([Na+]) needed to neutralize - negative charges in the reaction products. It then adds the appropriate number of - sodium ions to both the reactants and products. 
- - Parameters:: - - reaction_dict (Dict[str, any]): A dictionary representing a chemical reaction. - Must include keys for 'total_charge_in_products', 'reactants', 'products', 'R-id', - and 'label'. - - Returns: - - Dict[str, any]: A new reaction dictionary with adjusted reactants and products - to neutralize the negative charge. The 'total_charge_in_products' is set to 0, - assuming the charge has been neutralized. + ) -> Dict[str, Any]: + """Add [Na+] ions to neutralize negative product charge. + + :param reaction_dict: Dictionary with 'reactants', 'products', and charge info. + :type reaction_dict: Dict[str, Any] + :param charges_column: Key for product total charge. Defaults to 'total_charge_in_products'. + :type charges_column: str + :param id_column: Key for reaction identifier. Defaults to 'R-id'. + :type id_column: str + :param reaction_column: Key for reaction SMILES to update. Defaults to 'reactions'. + :type reaction_column: str + :returns: New dictionary with: + - updated `reaction_column` including added [Na+] ions + - 'reactants' and 'products' with ions appended + - charge column set to 0 + :rtype: Dict[str, Any] """ - - num_na_to_add = abs(reaction_dict[charges_column]) - sodium_ion = "[Na+]" - - # Generate the string to add, with the correct number of sodium ions - sodium_addition = ( - "." 
+ ".".join([sodium_ion] * num_na_to_add) if num_na_to_add > 0 else "" - ) - - # Add the sodium ions to reactants and products - new_reactants = reaction_dict["reactants"] + sodium_addition - new_products = reaction_dict["products"] + sodium_addition - - # Generate the new reaction string - new_reactions = new_reactants + ">>" + new_products - - # Create the new reaction dictionary - new_reaction_dict = { - id_column: reaction_dict["R-id"], - reaction_column: new_reactions, - "reactants": new_reactants, - "products": new_products, - charges_column: 0, # Assuming the charge is neutralized + num_to_add = abs(reaction_dict.get(charges_column, 0)) + sodium = "[Na+]" + addition = ("." + ".".join([sodium] * num_to_add)) if num_to_add else "" + new_react = reaction_dict["reactants"] + addition + new_prod = reaction_dict["products"] + addition + new_reaction = f"{new_react}>>{new_prod}" + + return { + id_column: reaction_dict.get("R-id"), + reaction_column: new_reaction, + "reactants": new_react, + "products": new_prod, + charges_column: 0, } - return new_reaction_dict - @staticmethod def fix_positive_charge( - reaction_dict: Dict[str, any], + reaction_dict: Dict[str, Any], charges_column: str = "total_charge_in_products", id_column: str = "R-id", reaction_column: str = "reactions", - ) -> Dict[str, any]: - """ - Adjusts a reaction dictionary to compensate for a positive charge - in the products by adding [Cl-] ions. The function - takes into account the total positive charge indicated - in the reaction dictionary and adds an equivalent number of - chloride ions ([Cl-]) to both reactants and products to neutralize the charge. - - Parameters:: - - reaction_dict (Dict[str, any]): A dictionary representing a chemical reaction. - This dictionary must include keys for reactants, products, and a specified charge - column (default is 'total_charge_in_products') which contains the total charge of - the products. 
- - charges_column (str, optional): The key in `reaction_dict` that contains the - total charge of the products. Defaults to 'total_charge_in_products'. - - Returns: - - Dict[str, any]: A modified reaction dictionary with added [Cl-] ions to - neutralize the positive charge. The 'total_charge_in_products' is updated to 0, - indicating that the reaction's charge has been neutralized. The dictionary - includes updated 'reactants', 'products', and a new reaction string. + ) -> Dict[str, Any]: + """Add [Cl‑] ions to neutralize positive product charge. + + :param reaction_dict: Dictionary with 'reactants', 'products', and charge info. + :type reaction_dict: Dict[str, Any] + :param charges_column: Key for product total charge. Defaults to 'total_charge_in_products'. + :type charges_column: str + :param id_column: Key for reaction identifier. Defaults to 'R-id'. + :type id_column: str + :param reaction_column: Key for reaction SMILES to update. Defaults to 'reactions'. + :type reaction_column: str + :returns: New dictionary with: + - updated `reaction_column` including added [Cl‑] ions + - 'reactants' and 'products' with ions appended + - charge column set to 0 + :rtype: Dict[str, Any] """ - - num_cl_to_add = abs(reaction_dict[charges_column]) - chloride_ion = "[Cl-]" - - # Generate the string to add, with the correct number of chloride ions - chloride_addition = ( - "." 
+ ".".join([chloride_ion] * num_cl_to_add) if num_cl_to_add > 0 else "" - ) - - # Add the chloride ions to reactants and products - new_reactants = reaction_dict["reactants"] + chloride_addition - new_products = reaction_dict["products"] + chloride_addition - - # Generate the new reaction string - new_reactions = new_reactants + ">>" + new_products - - # Create and return the new reaction dictionary with the neutralized charge - new_reaction_dict = { - "R-id": reaction_dict[id_column], - reaction_column: new_reactions, - "reactants": new_reactants, - "products": new_products, + num_to_add = abs(reaction_dict.get(charges_column, 0)) + chloride = "[Cl-]" + addition = ("." + ".".join([chloride] * num_to_add)) if num_to_add else "" + new_react = reaction_dict["reactants"] + addition + new_prod = reaction_dict["products"] + addition + new_reaction = f"{new_react}>>{new_prod}" + + return { + id_column: reaction_dict.get("R-id"), + reaction_column: new_reaction, + "reactants": new_react, + "products": new_prod, charges_column: 0, } - return new_reaction_dict - @staticmethod def fix_unbalanced_charged( - reaction_dict: Dict[str, any], - reaction_column: str, - ) -> Dict[str, any]: + reaction_dict: Dict[str, Any], reaction_column: str + ) -> Dict[str, Any]: + """Detect and neutralize unbalanced product charge by adding + counter‑ions. + + :param reaction_dict: Dictionary with raw reaction SMILES under `reaction_column`. + :type reaction_dict: Dict[str, Any] + :param reaction_column: Key for reaction SMILES in the input dict. + :type reaction_column: str + :returns: Dictionary with balanced charges and updated SMILES. + :rtype: Dict[str, Any] """ - Adjusts a reaction dictionary to compensate for an unbalanced charge in the - products by adding either [Cl-] ions for a positive charge or [Na+] ions for a - negative charge. The function determines the direction of the charge imbalance - using the specified charges column and applies the appropriate correction. 
- - Parameters:: - - reaction_dict (Dict[str, any]): A dictionary representing a chemical reaction. - This dictionary must include keys for reactants, products, and a specified charge - column which contains the total charge of the products. - - charges_column (str, optional): The key in `reaction_dict` that contains the - total charge of the products. Defaults to 'total_charge_in_products'. - - Returns: - - Dict[str, any]: A modified reaction dictionary with added ions to neutralize the - charge imbalance. The returned dictionary will have its charge neutralized and - include updated 'reactants', 'products', and a new reaction string. The specific - ions added ([Cl-] for positive charges or [Na+] for negative charges) depend on - the initial charge imbalance. - """ - reaction_dict = Neutralize.calculate_charge_dict(reaction_dict, reaction_column) - if reaction_dict["total_charge_in_products"] > 0: - return Neutralize.fix_positive_charge( - reaction_dict, "total_charge_in_products" - ) - elif reaction_dict["total_charge_in_products"] < 0: - return Neutralize.fix_negative_charge( - reaction_dict, "total_charge_in_products" - ) - else: - return reaction_dict + rd = Neutralize.calculate_charge_dict(reaction_dict, reaction_column) + total = rd.get("total_charge_in_products", 0) + if total > 0: + return Neutralize.fix_positive_charge(rd) + if total < 0: + return Neutralize.fix_negative_charge(rd) + return rd @classmethod def parallel_fix_unbalanced_charge( - cls, - reaction_dicts: List[Dict[str, Any]], - reaction_column: str, - n_jobs: int = 4, + cls, reaction_dicts: List[Dict[str, Any]], reaction_column: str, n_jobs: int = 4 ) -> List[Dict[str, Any]]: + """Neutralize charges in multiple reaction dictionaries in parallel. + + :param reaction_dicts: List of reaction dictionaries to process. + :type reaction_dicts: List[Dict[str, Any]] + :param reaction_column: Key for reaction SMILES in each dict. 
+ :type reaction_column: str + :param n_jobs: Number of parallel jobs (use -1 for all cores). + Defaults to 4. + :type n_jobs: int + :returns: List of dictionaries with balanced charges and updated + SMILES. + :rtype: List[Dict[str, Any]] """ - Processes a list of reaction dictionaries in parallel to compensate - for unbalanced charges in the products, adding either [Cl-] ions - for positive charges or [Na+] ions for negative charges. - - Parameters:: - - reaction_dicts (List[Dict[str, Any]]): A list of dictionaries, each representing - a chemical reaction that may have an unbalanced charge. - - charges_column (str): The key in each reaction dictionary that contains the - total charge of the products. Defaults to 'total_charge_in_products'. - - n_jobs (int): The number of CPU cores to use for parallel processing. - -1 means using all available cores. - - Returns: - - List[Dict[str, Any]]: A list of modified reaction dictionaries - with charges neutralized, reflecting the addition of necessary ions. - - Note: - - This function requires the joblib library for parallel execution. - Ensure joblib is installed and available for import. - """ - # Use joblib.Parallel and joblib.delayed to parallelize the charge fixing - fixed_reactions = Parallel(n_jobs=n_jobs)( - delayed(cls.fix_unbalanced_charged)(reaction_dict, reaction_column) - for reaction_dict in reaction_dicts + return Parallel(n_jobs=n_jobs)( + delayed(cls.fix_unbalanced_charged)(d, reaction_column) + for d in reaction_dicts ) - return fixed_reactions diff --git a/synkit/Chem/Reaction/radical_wildcard.py b/synkit/Chem/Reaction/radical_wildcard.py index f431b44..166d981 100644 --- a/synkit/Chem/Reaction/radical_wildcard.py +++ b/synkit/Chem/Reaction/radical_wildcard.py @@ -5,9 +5,9 @@ class RadicalWildcardAdder: - """ - A utility for adding wildcard dummy atoms ([*]) to radical centers in reaction SMILES, - with unique incremental atom-map indices and correct propagation into products. 
+ """A utility for adding wildcard dummy atoms ([*]) to radical centers in + reaction SMILES, with unique incremental atom-map indices and correct + propagation into products. Each reactive radical atom in the reactant block is identified by its unpaired electron count, assigned one or more wildcard map indices, and recorded. The same wildcard(s) are then appended @@ -27,37 +27,36 @@ class RadicalWildcardAdder: """ def __init__(self, start_map: Optional[int] = None) -> None: - """ - Initialize the adder with an optional starting map index. + """Initialize the adder with an optional starting map index. - :param start_map: Starting atom-map index for wildcards or None to auto-pick. + :param start_map: Starting atom-map index for wildcards or None + to auto-pick. :type start_map: Optional[int] """ self.start_map = start_map def __repr__(self) -> str: - """ - Official representation. - """ + """Official representation.""" return f"" def __str__(self) -> str: - """ - User-friendly description. - """ + """User-friendly description.""" m = self.start_map if self.start_map is not None else "auto" return f"RadicalWildcardAdder(start_map={m})" def transform(self, rxn_smiles: str) -> str: - """ - Append wildcard dummy atoms to each radical center in the reactant block - and propagate the same wildcards to the matching atoms in the product block. + """Append wildcard dummy atoms to each radical center in the reactant + block and propagate the same wildcards to the matching atoms in the + product block. - :param rxn_smiles: Reaction SMILES string, two-component or three-component. + :param rxn_smiles: Reaction SMILES string, two-component or + three-component. :type rxn_smiles: str - :returns: Modified reaction SMILES with consistent wildcard attachments. + :returns: Modified reaction SMILES with consistent wildcard + attachments. :rtype: str - :raises ValueError: If the SMILES is not valid or fragments fail to parse. 
+ :raises ValueError: If the SMILES is not valid or fragments fail + to parse. """ # Split into reactants > agents? > products react_blk, agents_blk, prod_blk = self._split_reaction(rxn_smiles) @@ -148,14 +147,16 @@ def _process(frags: List[str], propagate: bool) -> List[str]: @staticmethod def _split_reaction(rxn: str) -> Tuple[str, Optional[str], str]: - """ - Split a reaction SMILES into reactants, agents (optional), and products. + """Split a reaction SMILES into reactants, agents (optional), and + products. :param rxn: The reaction SMILES string. :type rxn: str - :returns: Tuple of (reactants_block, agents_block or None, products_block). + :returns: Tuple of (reactants_block, agents_block or None, + products_block). :rtype: Tuple[str, Optional[str], str] - :raises ValueError: If the SMILES does not contain 2 or 3 '>' symbols. + :raises ValueError: If the SMILES does not contain 2 or 3 '>' + symbols. """ parts = rxn.split(">") if len(parts) == 2: diff --git a/synkit/Chem/Reaction/rsmi_utils.py b/synkit/Chem/Reaction/rsmi_utils.py deleted file mode 100644 index e567195..0000000 --- a/synkit/Chem/Reaction/rsmi_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -from rdkit import Chem -from rdkit.Chem import rdChemReactions -from typing import List, Tuple, Optional - - -def remove_explicit_H_from_rsmi(rsmi: str) -> str: - """ - Remove explicit [H:...] atoms from a reaction SMILES with atom-atom mapping. - Keeps atom mapping intact for non-hydrogen atoms and returns a simplified reaction SMILES. - - Args: - rsmi (str): Atom-mapped reaction SMILES with explicit hydrogens. - - Returns: - str: Reaction SMILES with implicit hydrogens and AAM preserved. 
- """ - rxn = rdChemReactions.ReactionFromSmarts(rsmi, useSmiles=True) - - def cleaned_smiles(mols): - return ".".join( - Chem.MolToSmiles(Chem.RemoveHs(mol), isomericSmiles=True) for mol in mols - ) - - reactant_smiles = cleaned_smiles(rxn.GetReactants()) - product_smiles = cleaned_smiles(rxn.GetProducts()) - - return f"{reactant_smiles}>>{product_smiles}" - - -def remove_common_reagents(reaction_smiles: str) -> Tuple[Optional[str], Optional[str]]: - """ - Removes reagents that appear on both sides of a reaction SMILES string. - - Parameters: - - reaction_smiles (str): The reaction in SMILES format. - - Returns: - - Tuple[Optional[str], Optional[str]]: A tuple containing the cleaned reaction SMILES - and a list of common reagents removed. If no common reagents are found, the reaction - is returned unchanged and the second element of the tuple is `None`. - """ - reactants, products = reaction_smiles.split(">>") - reactant_list = reactants.split(".") - product_list = products.split(".") - common_reagents = set(reactant_list) & set(product_list) - - filtered_reactants = [r for r in reactant_list if r not in common_reagents] - filtered_products = [p for p in product_list if p not in common_reagents] - cleaned_reaction_smiles = ( - ".".join(filtered_reactants) + ">>" + ".".join(filtered_products) - ) - - return cleaned_reaction_smiles - - -def remove_duplicates(input_list: List[str]) -> List[str]: - """ - Removes duplicate strings from a list, retaining only the first occurrence of each string. - - Parameters: - - input_list (List[str]): A list of strings potentially containing duplicates. - - Returns: - - List[str]: A list of strings with duplicates removed. - """ - seen = set() - result = [] - for item in input_list: - if item not in seen: - result.append(item) - seen.add(item) - return result - - -def reverse_reaction(rsmi: str) -> str: - """ - Reverses the direction of a reaction SMILES string. - - Parameters: - - rsmi (str): The reaction SMILES string to reverse. 
- - Returns: - - str: The reversed reaction SMILES string. - """ - reactants, products = rsmi.split(">>") - return f"{products}>>{reactants}" - - -def merge_reaction(rsmi_1: str, rsmi_2: str) -> str: - """ - Merges two reaction SMILES strings into a single reaction. - - Parameters: - - rsmi_1 (str): The first reaction SMILES string. - - rsmi_2 (str): The second reaction SMILES string. - - Returns: - - str: A new reaction SMILES string combining both input reactions. - """ - try: - r1, p1 = rsmi_1.split(">>") - r2, p2 = rsmi_2.split(">>") - except ValueError: - return None # Returns None if there's a problem with splitting (e.g., no '>>') - - # Check if any part of the reaction is empty, which could be problematic for a meaningful merge. - if any(len(part.strip()) == 0 for part in (r1, p1, r2, p2)): - return None - - return f"{r1}.{r2}>>{p1}.{p2}" - - -def find_longest_fragment(input_list: List[str]) -> str: - """ - Finds the longest string in a list of strings. - - Parameters: - - input_list (List[str]): A list of strings from which the longest string is to be found. - - Returns: - - str: The longest string found in the input list. - """ - if len(input_list) == 0: - return None - longest_fragment = max(input_list, key=len) - return longest_fragment diff --git a/synkit/Chem/Reaction/standardize.py b/synkit/Chem/Reaction/standardize.py index def39ba..917a47f 100644 --- a/synkit/Chem/Reaction/standardize.py +++ b/synkit/Chem/Reaction/standardize.py @@ -1,31 +1,37 @@ -from rdkit import Chem from typing import List, Optional, Tuple +from rdkit import Chem class Standardize: - """ - A collection of utilities to normalize and filter reaction and molecule SMILES. + """Utilities to normalize and filter reaction and molecule SMILES. + + This class provides methods to remove atom‑mapping, filter invalid molecules, + canonicalize reaction SMILES, and a full pipeline via `fit`. + + :ivar None: Stateless helper class. 
""" def __init__(self) -> None: - """ - Initialize the Standardize helper. + """Initialize the Standardize helper. + + No instance attributes are set. """ pass @staticmethod def remove_atom_mapping(reaction_smiles: str, symbol: str = ">>") -> str: - """ - Remove atom-map numbers from both sides of a reaction SMILES. + """Remove atom‑map numbers from a reaction SMILES string. - :param reaction_smiles: A reaction SMILES string with atom mappings. + :param reaction_smiles: Reaction SMILES with atom maps, e.g. + 'C[CH3:1]>>C'. :type reaction_smiles: str - :param symbol: The separator between reactants and products. Defaults to ">>". + :param symbol: Separator between reactants and products. + Defaults to '>>'. :type symbol: str - :returns: The reaction SMILES with all atom-map annotations stripped. + :returns: Reaction SMILES without atom‑mapping annotations. :rtype: str - :raises ValueError: If the input is not in "reactants>>products" format - or contains invalid SMILES. + :raises ValueError: If the input format is invalid or contains + invalid SMILES. """ parts = reaction_smiles.split(symbol) if len(parts) != 2: @@ -46,15 +52,15 @@ def clean_smiles(smi: str) -> str: @staticmethod def filter_valid_molecules(smiles_list: List[str]) -> List[Chem.Mol]: - """ - Convert a list of SMILES to RDKit Mol objects, keeping only valid molecules. + """Filter and sanitize a list of SMILES, returning only valid Mol + objects. - :param smiles_list: A list of SMILES strings. - :type smiles_list: list of str - :returns: A list of sanitized RDKit Mol objects. - :rtype: list of rdkit.Chem.Mol + :param smiles_list: List of SMILES strings to validate. + :type smiles_list: List[str] + :returns: List of sanitized RDKit Mol objects. 
+ :rtype: List[rdkit.Chem.Mol] """ - valid = [] + valid: List[Chem.Mol] = [] for smi in smiles_list: mol = Chem.MolFromSmiles(smi, sanitize=False) if mol: @@ -62,21 +68,21 @@ def filter_valid_molecules(smiles_list: List[str]) -> List[Chem.Mol]: Chem.SanitizeMol(mol) valid.append(mol) except Exception: - pass + continue return valid @staticmethod def standardize_rsmi(rsmi: str, stereo: bool = False) -> Optional[str]: """ - Normalize a reaction SMILES by validating, sorting, and optional stereochemistry. + Normalize a reaction SMILES: validate molecules, sort fragments, optionally keep stereo. - :param rsmi: The reaction SMILES to standardize. + :param rsmi: Reaction SMILES in 'reactants>>products' format. :type rsmi: str - :param stereo: If True, include stereochemical information. Defaults to False. + :param stereo: If True, include stereochemistry in the output. Defaults to False. :type stereo: bool - :returns: The standardized reaction SMILES or None if no valid molecules remain. - :rtype: str or None - :raises ValueError: If the input is not in "reactants>>products" format. + :returns: Standardized reaction SMILES or None if no valid molecules remain. + :rtype: Optional[str] + :raises ValueError: If the input format is invalid. """ try: react_str, prod_str = rsmi.split(">>") @@ -104,16 +110,16 @@ def fit( self, rsmi: str, remove_aam: bool = True, ignore_stereo: bool = True ) -> Optional[str]: """ - Full standardization pipeline: remove atom-maps, normalize SMILES, fix H notation. + Full standardization pipeline: strip atom‑mapping, normalize SMILES, fix hydrogen notation. - :param rsmi: The reaction SMILES to process. + :param rsmi: Reaction SMILES to process. :type rsmi: str - :param remove_aam: If True, strip atom-mapping numbers. Defaults to True. + :param remove_aam: If True, remove atom‑mapping annotations. Defaults to True. :type remove_aam: bool :param ignore_stereo: If True, drop stereochemistry. Defaults to True. 
:type ignore_stereo: bool - :returns: The processed reaction SMILES or None if standardization fails. - :rtype: str or None + :returns: The standardized reaction SMILES, or None if standardization fails. + :rtype: Optional[str] """ if remove_aam: rsmi = self.remove_atom_mapping(rsmi) @@ -122,27 +128,27 @@ def fit( if std is None: return None - # Explicitly format double hydrogens + # Format any double‑hydrogen notation return std.replace("[HH]", "[H][H]") @staticmethod def categorize_reactions( reactions: List[str], target_reaction: str ) -> Tuple[List[str], List[str]]: - """ - Partition a list of reaction SMILES into those matching a target and those not. + """Partition reactions into those matching a target and those not. - :param reactions: List of reaction SMILES strings to categorize. - :type reactions: list of str - :param target_reaction: The benchmark reaction SMILES for matching. + :param reactions: List of reaction SMILES to categorize. + :type reactions: List[str] + :param target_reaction: Benchmark reaction SMILES for comparison. :type target_reaction: str - :returns: A pair `(matches, non_matches)`: - - `matches`: reactions equal to the standardized target. - - `non_matches`: all others. 
- :rtype: tuple (list of str, list of str) + :returns: Tuple of (matches, non_matches): + - matches: reactions equal to standardized target + - non_matches: all others + :rtype: Tuple[List[str], List[str]] """ tgt = Standardize.standardize_rsmi(target_reaction, stereo=False) - matches, non_matches = [], [] + matches: List[str] = [] + non_matches: List[str] = [] for rxn in reactions: if rxn == tgt: matches.append(rxn) diff --git a/synkit/Chem/Reaction/tautomerize.py b/synkit/Chem/Reaction/tautomerize.py index 2ac3e51..2da5b2d 100644 --- a/synkit/Chem/Reaction/tautomerize.py +++ b/synkit/Chem/Reaction/tautomerize.py @@ -5,24 +5,22 @@ class Tautomerize: - """ - A class to standardize molecules by converting specific functional groups to their - more common forms using RDKit for molecule manipulation. - """ + """Standardize molecules by converting enol and hemiketal tautomers into + their more stable carbonyl forms, and apply these corrections to individual + SMILES or collections of reaction data.""" @staticmethod def standardize_enol(smiles: str, atom_indices: Optional[List[int]] = None) -> str: - """ - Converts an enol form to a carbonyl form based on specified atom indices. - - Parameters: - - smiles (str): The SMILES string. - - atom_indices (List[int], optional): List containing indices of two carbons and - one oxygen involved in the enol formation. Defaults to [0, 1, 2]. - - Returns: - - str: The SMILES string of the molecule after conversion. - Returns an error message if indices are invalid. + """Convert an enol tautomer into its corresponding carbonyl form. + + :param smiles: SMILES string of the enol-containing molecule. + :type smiles: str + :param atom_indices: List of three atom indices [C1, C2, O] + defining the enol. If None, defaults to [0, 1, 2]. + :type atom_indices: List[int] or None + :returns: SMILES of the molecule after enol→carbonyl conversion, + or an error message if the input is invalid or indices fail. 
+ :rtype: str """ if atom_indices is None: atom_indices = [0, 1, 2] @@ -33,39 +31,40 @@ def standardize_enol(smiles: str, atom_indices: Optional[List[int]] = None) -> s emol = Chem.EditableMol(mol) try: - c1_idx, c2_idx = ( + c_idxs = [ i for i in atom_indices if mol.GetAtomWithIdx(i).GetSymbol() == "C" - ) + ] + c1_idx, c2_idx = c_idxs[:2] o_idx = next( i for i in atom_indices if mol.GetAtomWithIdx(i).GetSymbol() == "O" ) except Exception as e: - return f"Error processing indices: {str(e)}" + return f"Error processing indices: {e}" try: emol.RemoveBond(c1_idx, c2_idx) emol.RemoveBond(c2_idx, o_idx) - emol.AddBond(c1_idx, c2_idx, order=Chem.rdchem.BondType.SINGLE) - emol.AddBond(c2_idx, o_idx, order=Chem.rdchem.BondType.DOUBLE) + emol.AddBond(c1_idx, c2_idx, Chem.rdchem.BondType.SINGLE) + emol.AddBond(c2_idx, o_idx, Chem.rdchem.BondType.DOUBLE) new_mol = emol.GetMol() Chem.SanitizeMol(new_mol) return Chem.MolToSmiles(new_mol) except Exception as e: - return f"Error in modifying molecule: {str(e)}" + return f"Error in modifying molecule: {e}" @staticmethod def standardize_hemiketal(smiles: str, atom_indices: List[int]) -> str: - """ - Converts a hemiketal form to a carbonyl form based on specified atom indices. - - Parameters: - - smiles (str): SMILES representation of the original molecule. - - atom_indices (List[int]): Indices of the carbon and two oxygen atoms - involved in the transformation. - - Returns: - - str: SMILES string of the modified molecule if successful, - otherwise returns an error message. + """Convert a hemiketal tautomer into its corresponding carbonyl form. + + :param smiles: SMILES string of the hemiketal-containing + molecule. + :type smiles: str + :param atom_indices: List of atom indices [C, O1, O2] defining + the hemiketal. + :type atom_indices: List[int] + :returns: SMILES of the molecule after hemiketal→carbonyl + conversion, or an error message if the input is invalid. 
+ :rtype: str """ mol = Chem.MolFromSmiles(smiles) if mol is None: @@ -76,67 +75,65 @@ def standardize_hemiketal(smiles: str, atom_indices: List[int]) -> str: c_idx = next( i for i in atom_indices if mol.GetAtomWithIdx(i).GetSymbol() == "C" ) - o1_idx, o2_idx = ( + o_idxs = [ i for i in atom_indices if mol.GetAtomWithIdx(i).GetSymbol() == "O" - ) + ] + o1_idx = o_idxs[0] + except Exception as e: + return f"Error processing indices: {e}" + + try: emol.RemoveBond(c_idx, o1_idx) - emol.RemoveBond(c_idx, o2_idx) - emol.AddBond(c_idx, o1_idx, order=Chem.rdchem.BondType.DOUBLE) + if len(o_idxs) > 1: + emol.RemoveBond(c_idx, o_idxs[1]) + emol.AddBond(c_idx, o1_idx, Chem.rdchem.BondType.DOUBLE) new_mol = emol.GetMol() Chem.SanitizeMol(new_mol) return Chem.MolToSmiles(new_mol) except Exception as e: - return f"Error in modifying molecule: {str(e)}" + return f"Error in modifying molecule: {e}" @staticmethod def fix_smiles(smiles: str) -> str: - """ - Performs the standardization process by identifying and converting all relevant - functional groups to their target forms based on predefined rules and updates the - SMILES string accordingly. + """Iteratively apply enol and hemiketal standardizations until no + further changes, then return the canonical SMILES. - Parameters: - - smiles (str): SMILES string of the original molecule. - - Returns: - - str: Canonical SMILES string of the standardized molecule. + :param smiles: SMILES string to standardize. + :type smiles: str + :returns: Canonical SMILES of the standardized molecule. 
+ :rtype: str """ query = FGQuery() fg = query.get(smiles) for item in fg: - if "hemiketal" in item: - atom_indices = item[1] - smiles = Tautomerize.standardize_hemiketal(smiles, atom_indices) + label, indices = item + if label == "hemiketal": + smiles = Tautomerize.standardize_hemiketal(smiles, indices) fg = query.get(smiles) - elif "enol" in item: - atom_indices = item[1] - smiles = Tautomerize.standardize_enol(smiles, atom_indices) + elif label == "enol": + smiles = Tautomerize.standardize_enol(smiles, indices) fg = query.get(smiles) return Chem.CanonSmiles(smiles) @staticmethod def fix_dict(data: Dict[str, str], reaction_column: str) -> Dict[str, str]: - """ - Updates a dictionary containing reaction data by - standardizing the SMILES strings of reactants and products. - - Parameters: - - data (Dict[str, str]): Dictionary containing the reaction data. - - reaction_column (str): The key in the dictionary where the reaction SMILES - string is stored. - - Returns: - - Dict[str, str]: The updated dictionary with standardized SMILES strings. + """Standardize the reactant and product SMILES in a reaction + dictionary. + + :param data: Dictionary containing a reaction SMILES under `reaction_column`. + :type data: Dict[str, str] + :param reaction_column: Key in `data` where the reaction SMILES is stored. + :type reaction_column: str + :returns: The same dictionary with standardized reaction SMILES. 
+ :rtype: Dict[str, str] """ try: - reactants, products = data[reaction_column].split(">>") - reactants = Tautomerize.fix_smiles(reactants) - products = Tautomerize.fix_smiles(products) - data[reaction_column] = f"{reactants}>>{products}" + react, prod = data[reaction_column].split(">>") + data[reaction_column] = ( + f"{Tautomerize.fix_smiles(react)}>>{Tautomerize.fix_smiles(prod)}" + ) except ValueError: - smiles = data[reaction_column] - smiles = Tautomerize.fix_smiles(smiles) - data[reaction_column] = smiles + data[reaction_column] = Tautomerize.fix_smiles(data[reaction_column]) return data @staticmethod @@ -146,21 +143,18 @@ def fix_dicts( n_jobs: int = 4, verbose: int = 0, ) -> List[Dict[str, str]]: - """ - Standardizes multiple dictionaries containing - reaction data in parallel. - - Parameters: - - data (List[Dict[str, str]]): List of dictionaries, each containing reaction - data. - - reaction_column (str): The key where the reaction SMILES strings are - stored in each dictionary. - - n_jobs (int, optional): Number of jobs to run in parallel. Defaults to 4. - - verbose (int, optional): The verbosity level. Defaults to 0. - - Returns: - - List[Dict[str, str]]: A list of updated dictionaries - with standardized SMILES strings. + """Standardize multiple reaction dictionaries in parallel. + + :param data: List of dictionaries containing reaction SMILES under `reaction_column`. + :type data: List[Dict[str, str]] + :param reaction_column: Key in each dictionary for the reaction SMILES. + :type reaction_column: str + :param n_jobs: Number of parallel jobs to run. Defaults to 4. + :type n_jobs: int + :param verbose: Verbosity level for the joblib Parallel call. Defaults to 0. + :type verbose: int + :returns: List of dictionaries with standardized SMILES. 
+ :rtype: List[Dict[str, str]] """ results = Parallel(n_jobs=n_jobs, verbose=verbose)( delayed(Tautomerize.fix_dict)(d, reaction_column) for d in data diff --git a/synkit/Chem/utils.py b/synkit/Chem/utils.py index afb05c6..763902e 100644 --- a/synkit/Chem/utils.py +++ b/synkit/Chem/utils.py @@ -1,105 +1,123 @@ from rdkit import Chem -from typing import List, Union -from synkit.IO.debug import setup_logging - -logger = setup_logging() +from rdkit.Chem.MolStandardize import rdMolStandardize +from rdkit.Chem import rdChemReactions +import re +from typing import List, Optional, Tuple, Union + + +def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: + """Enumerate possible tautomers of reactants while canonicalizing products. + + :param reaction_smiles: Reaction SMILES in 'reactants>>products' + format. + :type reaction_smiles: str + :returns: List of reaction SMILES for each reactant tautomer + (including the original), or None on error. + :rtype: Optional[List[str]] + :raises ValueError: If reactant or product SMILES are invalid. + """ + try: + reactants_smiles, products_smiles = reaction_smiles.split(">>") + reactants_mol = Chem.MolFromSmiles(reactants_smiles) + products_mol = Chem.MolFromSmiles(products_smiles) + if reactants_mol is None or products_mol is None: + raise ValueError("Invalid reactant or product SMILES.") + enumerator = rdMolStandardize.TautomerEnumerator() + reactants_tautos = enumerator.Enumerate(reactants_mol) or [reactants_mol] + prod_can = Chem.MolToSmiles(products_mol, canonical=True) + rsmi_list = [Chem.MolToSmiles(m) + ">>" + prod_can for m in reactants_tautos] + rsmi_list.insert(0, reaction_smiles) + return rsmi_list + except ValueError: + raise + except Exception: + return None + + +def mapping_success_rate(list_mapping_data: List[str]) -> float: + """Calculate percentage of entries containing atom‑mapping annotations. + + :param list_mapping_data: List of strings to search for mappings. 
+ :type list_mapping_data: List[str] + :returns: Percentage of entries containing `:` patterns, + rounded to two decimals. + :rtype: float + :raises ValueError: If input list is empty. + """ + if not list_mapping_data: + raise ValueError("The input list is empty, cannot calculate success rate.") + pattern = re.compile(r":\d+") + success = sum(1 for entry in list_mapping_data if pattern.search(entry)) + return round(100 * success / len(list_mapping_data), 2) def count_carbons(smiles: str) -> int: - """ " - Counts the number of carbon atoms in a molecule given a SMILES string. - - Parameters: - - smiles (str): SMILES representation of the molecule. - - Returns: - - int: Number of carbon atoms in the molecule if the SMILES string is valid. - - str: Error message indicating an invalid SMILES string. + """Count the number of carbon atoms in a molecule. + + :param smiles: SMILES string of the molecule. + :type smiles: str + :returns: Number of carbon atoms, or raises ValueError if SMILES + invalid. + :rtype: int + :raises ValueError: If the SMILES string is invalid. """ mol = Chem.MolFromSmiles(smiles) - - if mol: - carbon_count = sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() == "C") - return carbon_count - else: - return "Invalid SMILES string" + if mol is None: + raise ValueError(f"Invalid SMILES string: {smiles}") + return sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() == "C") def get_max_fragment(smiles: Union[str, List[str]]) -> str: - """ - Extracts and returns the SMILES string of the largest fragment from a SMILES string or a list of SMILES strings - of a compound that may contain multiple fragments. This function determines the largest fragment based on the - number of atoms. - - Parameters: - - smiles (Union[str, List[str]]): A single SMILES string or a list of SMILES strings containing potentially - multiple fragments. + """Return the largest fragment by atom count from SMILES. 
- Returns: - - str: SMILES string of the largest fragment based on the number of atoms. Returns an empty string if no valid - fragments can be processed. - - Examples: - - get_max_fragment("C.CC.CCC") returns "CCC" - - get_max_fragment(["C.CC", "CCC.C"]) returns "CCC" + :param smiles: SMILES string(s), possibly with '.' separators. + :type smiles: str or List[str] + :returns: SMILES of the fragment with the most atoms, or empty + string if none valid. + :rtype: str """ if isinstance(smiles, str): fragments = smiles.split(".") - elif isinstance(smiles, list): - fragments = [frag for s in smiles for frag in s.split(".")] else: + fragments = [frag for s in smiles for frag in s.split(".")] + mols = [Chem.MolFromSmiles(f) for f in fragments if f] + mols = [m for m in mols if m] + if not mols: return "" - - molecules = [Chem.MolFromSmiles(fragment) for fragment in fragments if fragment] - if not molecules: - return "" # Return empty string if no valid molecules are found - - max_mol = max( - molecules, key=lambda mol: mol.GetNumAtoms() if mol else 0, default=None - ) - return Chem.MolToSmiles(max_mol) if max_mol else "" + max_mol = max(mols, key=lambda m: m.GetNumAtoms()) + return Chem.MolToSmiles(max_mol) def filter_smiles(smiles_list: List[str], target_smiles: str) -> List[str]: + """Filter SMILES list to those containing carbon and not equal to a target. + + :param smiles_list: List of SMILES strings to filter. + :type smiles_list: List[str] + :param target_smiles: SMILES string to exclude. + :type target_smiles: str + :returns: Filtered list containing SMILES with at least one carbon atom + and not matching `target_smiles`. + :rtype: List[str] """ - Filters a list of SMILES strings to include only those that contain carbon atoms and are not identical - to a given target SMILES string. - - Parameters: - - smiles_list (List[str]): A list of SMILES strings to be filtered. - - target_smiles (str): The target SMILES string to exclude from the output. 
- - Returns: - - List[str]: A list of SMILES strings that contain carbon and are not the same as the target SMILES. - """ - filtered_smiles = [] - # Convert target SMILES to a molecule and standardize it for comparison target_mol = Chem.MolFromSmiles(target_smiles) - target_canonical = Chem.MolToSmiles(target_mol) if target_mol else None - - for smiles in smiles_list: - mol = Chem.MolFromSmiles(smiles) - if mol: - # Check if the molecule contains carbon - if any(atom.GetSymbol() == "C" for atom in mol.GetAtoms()): - # Standardize the SMILES for comparison - canonical_smiles = Chem.MolToSmiles(mol) - # Check that the SMILES is not the same as the target SMILES - if canonical_smiles != target_canonical: - filtered_smiles.append(smiles) - - return filtered_smiles + target_can = Chem.MolToSmiles(target_mol) if target_mol else "" + result: List[str] = [] + for smi in smiles_list: + mol = Chem.MolFromSmiles(smi) + if mol and any(atom.GetSymbol() == "C" for atom in mol.GetAtoms()): + can = Chem.MolToSmiles(mol) + if can != target_can: + result.append(smi) + return result def remove_atom_mappings(mol: Chem.Mol) -> Chem.Mol: - """ - Removes atom mapping numbers from a molecule by setting each atom's mapping number to zero. - - Parameters: - - mol (Chem.Mol): The RDKit molecule object from which to remove atom mappings. + """Strip atom‑mapping numbers from a molecule. - Returns: - - Chem.Mol: The same RDKit molecule object with atom mappings removed. + :param mol: RDKit Mol object. + :type mol: Chem.Mol + :returns: The same Mol with all atom‑map numbers set to zero. + :rtype: Chem.Mol """ for atom in mol.GetAtoms(): atom.SetAtomMapNum(0) @@ -107,81 +125,152 @@ def remove_atom_mappings(mol: Chem.Mol) -> Chem.Mol: def get_sanitized_smiles(smiles_list: List[str]) -> List[str]: - """ - Filters and returns a list of sanitizable SMILES strings from the provided list, with atom mappings removed - and excluding any SMILES containing reaction indicators ('->'). 
- - Parameters: - - smiles_list (List[str]): A list of SMILES strings to be sanitized. + """Sanitize SMILES list by removing mappings and invalid entries. - Returns: - - List[str]: A list of SMILES strings that can be successfully sanitized. + :param smiles_list: List of SMILES strings to sanitize. + :type smiles_list: List[str] + :returns: List of sanitized, isomeric SMILES of the largest + fragments only. + :rtype: List[str] """ - sanitized_smiles = [] + sanitized: List[str] = [] for smiles in smiles_list: - if "->" in smiles: # Skip SMILES with reaction indicators + if "->" in smiles: + continue + mol = Chem.MolFromSmiles(smiles) + if not mol: continue + mol = remove_atom_mappings(mol) try: - # Attempt to create a molecule from the SMILES string - mol = Chem.MolFromSmiles(smiles) - if mol: - # Remove atom mappings before sanitization - mol = remove_atom_mappings(mol) - - # Attempt to sanitize the molecule - Chem.SanitizeMol(mol) - - # If sanitization is successful, append the sanitized SMILES to the result list - sanitized_smiles.append(Chem.MolToSmiles(mol, isomericSmiles=True)) - sanitized_smiles = [get_max_fragment(sanitized_smiles)] - except (Chem.rdchem.ChemicalReactionException, ValueError) as e: - logger.error(e) + Chem.SanitizeMol(mol) + sanitized.append(Chem.MolToSmiles(mol, isomericSmiles=True)) + except Exception: continue - - return sanitized_smiles + # keep only the largest fragment across all + if sanitized: + sanitized = [get_max_fragment(sanitized)] + return sanitized def remove_duplicates(smiles_list: List[str]) -> List[str]: - """ - Removes duplicate SMILES strings from a list, maintaining the order of - first occurrences. Uses a set to track seen SMILES for efficiency. - - Parameters: - - smiles_list (List[str]): A list of SMILES strings representing - chemical reactions. + """Remove duplicate strings from a list, preserving first occurrence. - Returns: - - List[str]: A list with unique SMILES strings, preserving the original order. 
+ :param smiles_list: List of strings (e.g., SMILES) possibly with + duplicates. + :type smiles_list: List[str] + :returns: List with duplicates removed in original order. + :rtype: List[str] """ seen = set() - unique_smiles = [ - smiles for smiles in smiles_list if not (smiles in seen or seen.add(smiles)) - ] - return unique_smiles + unique: List[str] = [] + for s in smiles_list: + if s not in seen: + unique.append(s) + seen.add(s) + return unique def process_smiles_list(smiles_list: List[str]) -> List[str]: + """Split dot‑connected SMILES into individual components. + + :param smiles_list: List of SMILES strings, some containing '.' + separators. + :type smiles_list: List[str] + :returns: Flattened list of component SMILES strings. + :rtype: List[str] """ - Processes a list of SMILES (Simplified Molecular Input Line Entry System) strings, - splitting any entries that contain disconnected molecular components - (indicated by a '.'), and returns a new list with each component as a separate entry. - - Parameters: - - smiles_list (List[str]): A list of SMILES strings, where some entries may contain - disconnected components separated by dots. - - Returns: - - List[str]: A new list of SMILES strings with all components separated. This list - does not include any original strings that contained dots; instead, it - includes their split components. - """ - new_smiles_list = [] # Create a new list to store processed SMILES strings + new_list: List[str] = [] for smiles in smiles_list: if "." in smiles: - # Split the SMILES string into components and extend the new list - components = smiles.split(".") - new_smiles_list.extend(components) + new_list.extend(smiles.split(".")) else: - # Add the unchanged SMILES string to the new list - new_smiles_list.append(smiles) - return new_smiles_list + new_list.append(smiles) + return new_list + + +def remove_explicit_H_from_rsmi(rsmi: str) -> str: + """Remove explicit H atoms from a reaction SMILES, preserving AAM. 
+ + :param rsmi: Atom‑mapped reaction SMILES with explicit hydrogens. + :type rsmi: str + :returns: Simplified reaction SMILES with implicit hydrogens. + :rtype: str + """ + rxn = rdChemReactions.ReactionFromSmarts(rsmi, useSmiles=True) + + def cleaned(mols): + return ".".join( + Chem.MolToSmiles(Chem.RemoveHs(m), isomericSmiles=True) for m in mols + ) + + react = cleaned(rxn.GetReactants()) + prod = cleaned(rxn.GetProducts()) + return f"{react}>>{prod}" + + +def remove_common_reagents(reaction_smiles: str) -> Tuple[Optional[str], Optional[str]]: + """Remove reagents present on both sides of a reaction SMILES. + + :param reaction_smiles: Reaction SMILES 'reactants>>products'. + :type reaction_smiles: str + :returns: Tuple(cleaned_reaction, list_of_removed_reagents or None + if none found). + :rtype: Tuple[str, Optional[List[str]]] + """ + reactants, products = reaction_smiles.split(">>") + reactant_list = reactants.split(".") + product_list = products.split(".") + common_reagents = set(reactant_list) & set(product_list) + + filtered_reactants = [r for r in reactant_list if r not in common_reagents] + filtered_products = [p for p in product_list if p not in common_reagents] + cleaned_reaction_smiles = ( + ".".join(filtered_reactants) + ">>" + ".".join(filtered_products) + ) + + return cleaned_reaction_smiles + + +def reverse_reaction(rsmi: str) -> str: + """Reverse a reaction SMILES. + + :param rsmi: Reaction SMILES 'reactants>>products'. + :type rsmi: str + :returns: Reaction SMILES 'products>>reactants'. + :rtype: str + """ + parts = rsmi.split(">>") + return f"{parts[1]}>>{parts[0]}" if len(parts) == 2 else rsmi + + +def merge_reaction(rsmi_1: str, rsmi_2: str) -> Optional[str]: + """Merge two reaction SMILES into a single combined reaction. + + :param rsmi_1: First reaction SMILES. + :type rsmi_1: str + :param rsmi_2: Second reaction SMILES. + :type rsmi_2: str + :returns: Merged reaction SMILES or None if inputs invalid. 
+ :rtype: Optional[str] + """ + try: + r1, p1 = rsmi_1.split(">>") + r2, p2 = rsmi_2.split(">>") + except ValueError: + return None + if not all([r1, p1, r2, p2]): + return None + return f"{r1}.{r2}>>{p1}.{p2}" + + +def find_longest_fragment(input_list: List[str]) -> Optional[str]: + """Find the longest string in a list. + + :param input_list: List of strings to search. + :type input_list: List[str] + :returns: Longest string or None if list empty. + :rtype: Optional[str] + """ + if not input_list: + return None + return max(input_list, key=len) diff --git a/synkit/Data/gen_partial_aam.py b/synkit/Data/gen_partial_aam.py index 8be300a..f9d5214 100644 --- a/synkit/Data/gen_partial_aam.py +++ b/synkit/Data/gen_partial_aam.py @@ -8,8 +8,8 @@ def _get_partial_aam(smart: str) -> str: - """ - Generate a partial atom‐atom mapping (AAM) SMILES string from a reactant SMARTS. + """Generate a partial atom‐atom mapping (AAM) SMILES string from a reactant + SMARTS. This function: 1. Parses the forward (“reactant”) and backward (“product”) halves of `smart`. @@ -63,8 +63,8 @@ def _get_partial_aam(smart: str) -> str: def _remove_small_smiles(smiles: str) -> str: - """ - Return the canonical SMILES of the largest fragment from an input SMILES. + """Return the canonical SMILES of the largest fragment from an input + SMILES. This function: 1. Parses `smiles` to an RDKit Mol without sanitization. @@ -105,9 +105,8 @@ def _remove_small_smiles(smiles: str) -> str: def _create_unbalanced_aam(rsmi: str, side: str = "right") -> str: - """ - Produce an unbalanced AAM reaction SMILES by keeping only the largest fragment - on the specified side(s) of the reaction. + """Produce an unbalanced AAM reaction SMILES by keeping only the largest + fragment on the specified side(s) of the reaction. :param rsmi: A reaction SMILES "reactant>>product". 
:type rsmi: str diff --git a/synkit/Graph/Canon/canon_algs.py b/synkit/Graph/Canon/canon_algs.py index b6dd9e8..e36959b 100644 --- a/synkit/Graph/Canon/canon_algs.py +++ b/synkit/Graph/Canon/canon_algs.py @@ -26,8 +26,7 @@ def _digest(text: str) -> Digest: - """ - Compute a 32-character hexadecimal SHA-256 digest of the input string. + """Compute a 32-character hexadecimal SHA-256 digest of the input string. Parameters ---------- @@ -43,8 +42,7 @@ def _digest(text: str) -> Digest: def ring_canonical_graph(g: nx.Graph) -> Tuple[nx.Graph, Digest]: - """ - Generate a relabelled graph based on SSSR membership hierarchy and + """Generate a relabelled graph based on SSSR membership hierarchy and compute its canonical signature. Nodes are ordered by: @@ -91,8 +89,8 @@ def ring_canonical_graph(g: nx.Graph) -> Tuple[nx.Graph, Digest]: def eigen_canonical_signature(g: nx.Graph) -> Digest: - """ - Compute a graph signature from sorted eigenvalues of its weighted adjacency matrix. + """Compute a graph signature from sorted eigenvalues of its weighted + adjacency matrix. Edge weights are taken from the 'order' attribute (default=1). The adjacency matrix is symmetric for undirected graphs. @@ -126,8 +124,7 @@ def eigen_canonical_signature(g: nx.Graph) -> Digest: def pgraph_signature(g: nx.Graph, p: int = 4) -> Digest: - """ - Generate a signature by hashing all simple paths up to length p. + """Generate a signature by hashing all simple paths up to length p. Each path is represented as a hyphen-separated sequence of node 'element' attributes (or '?' if missing), and the sorted list of these sequences @@ -160,8 +157,7 @@ def pgraph_signature(g: nx.Graph, p: int = 4) -> Digest: def canon_morgan( g: nx.Graph, morgan_radius: int = 2, node_attributes: List[str] = None ) -> Tuple[nx.Graph, Digest]: - """ - Prime-based neighbourhood refinement analogous to Morgan fingerprinting. + """Prime-based neighbourhood refinement analogous to Morgan fingerprinting. 
Each node is initially assigned a unique prime number; optionally, specified node attributes are incorporated into the seed label. @@ -236,8 +232,9 @@ def canon_morgan( # Utility to normalize and hash a node label with its neighbors and edge orders # Utility to normalize and hash a node label with its neighbors and edge orders def _hash_labels(node_label: int, neigh_info: List[Tuple[int, Any]]) -> int: - """ - Combine a node's label with sorted neighbor labels and edge orders into a new hash. + """Combine a node's label with sorted neighbor labels and edge orders into + a new hash. + neigh_info is a list of tuples (neighbor_label, edge_order). """ data = [str(node_label)] diff --git a/synkit/Graph/Canon/canon_graph.py b/synkit/Graph/Canon/canon_graph.py index bcd84dc..53e19a7 100644 --- a/synkit/Graph/Canon/canon_graph.py +++ b/synkit/Graph/Canon/canon_graph.py @@ -100,7 +100,8 @@ def _default_edge_key(u: NodeId, v: NodeId, data: EdgeData) -> Tuple[Any, ...]: def _digest(text: str) -> Digest: - """First 32 hex chars of SHA‑256 – short *but* collision‑safe for up to 2¹²⁸ graphs.""" + """First 32 hex chars of SHA‑256 – short *but* collision‑safe for up to + 2¹²⁸ graphs.""" return hashlib.sha256(text.encode()).hexdigest()[:32] @@ -110,8 +111,7 @@ def _digest(text: str) -> Digest: class GraphCanonicaliser: - """ - Factory that turns arbitrary ``networkx.Graph`` objects into their + """Factory that turns arbitrary ``networkx.Graph`` objects into their *canonical* twin plus a **stable 32‑hex digest**. Parameters @@ -169,8 +169,7 @@ def __init__( # High‑level helpers # # ------------------------------------------------------------------ # def canonicalise_graph(self, graph: nx.Graph) -> "CanonicalGraph": - """ - Return a :class:`CanonicalGraph` wrapper around *graph*. + """Return a :class:`CanonicalGraph` wrapper around *graph*. 
The wrapper exposes: @@ -183,8 +182,7 @@ def canonicalise_graphs( self, graphs: Iterable[nx.Graph], ) -> Tuple["CanonicalGraph", ...]: - """ - Bulk helper that returns *all* wrappers **sorted by hash**. + """Bulk helper that returns *all* wrappers **sorted by hash**. Useful when you want fast set comparison but need the canonical graphs as well: @@ -203,8 +201,7 @@ def canonicalise_graphs( # Digest / core methods # # ------------------------------------------------------------------ # def canonical_signature(self, graph: nx.Graph) -> Digest: - """ - Return the *hash of the canonical form* of *graph*. + """Return the *hash of the canonical form* of *graph*. Equal digests ⇒ graphs are guaranteed isomorphic **under the chosen back‑end and keys**. @@ -254,8 +251,7 @@ def _canon_generic(self, g: nx.Graph) -> nx.Graph: return G2 def _canon_wl(self, g: nx.Graph) -> nx.Graph: - """ - Weisfeiler–Lehman colour-refinement back-end (pure Python). + """Weisfeiler–Lehman colour-refinement back-end (pure Python). Seeds each node’s initial colour by the tuple of attributes in `self._wl_node_attrs` (e.g. ["element","charge","hcount"]), @@ -352,8 +348,7 @@ def __repr__(self) -> str: # pragma: no cover # Value wrapper (unchanged surface – richer docs) # ============================================================================= class CanonicalGraph: - """ - *Value object* tying together: + """*Value object* tying together: * the **original** NetworkX graph (mutable, user‑supplied); * its **canonical twin** (immutable copy, nodes relabelled 1…N); @@ -428,9 +423,9 @@ def help(self) -> None: # pragma: no cover class CanonicalRule: - """ - Value object that wraps a graph transformation rule in GML string form, - providing a canonicalised GML output and a stable 32-character SHA-256 hash. + """Value object that wraps a graph transformation rule in GML string form, + providing a canonicalised GML output and a stable 32-character SHA-256 + hash. 
Internally, the GML rule is parsed into a NetworkX graph via `gml_to_its`, canonicalised using a `GraphCanonicaliser`, and re-serialized back to GML @@ -458,8 +453,7 @@ def __init__( rule: str, canon: GraphCanonicaliser = GraphCanonicaliser(), ) -> None: - """ - Instantiate a CanonicalRule. + """Instantiate a CanonicalRule. Parameters ---------- @@ -523,9 +517,7 @@ def canonical_hash(self) -> Digest: return self._canonical_hash def help(self) -> None: - """ - Print original and canonical rule texts and underlying graphs. - """ + """Print original and canonical rule texts and underlying graphs.""" print("Original GML rule:") print(self._original_rule) print("\nCanonical GML rule:") diff --git a/synkit/Graph/Canon/nauty.py b/synkit/Graph/Canon/nauty.py index 2a89193..6a0333c 100644 --- a/synkit/Graph/Canon/nauty.py +++ b/synkit/Graph/Canon/nauty.py @@ -7,18 +7,17 @@ class NautyCanonicalizer: - """ - Perform Nauty‑style canonicalization of a NetworkX graph, optionally - refining and distinguishing nodes and edges by specified attributes, - and extracting automorphisms, orbits, and canonical permutations. + """Perform Nauty‑style canonicalization of a NetworkX graph, optionally + refining and distinguishing nodes and edges by specified attributes, and + extracting automorphisms, orbits, and canonical permutations. - :param node_attrs: List of node attribute keys to include in the initial - partition refinement. Nodes sharing the same tuple of - values under these keys will start in the same cell. + :param node_attrs: List of node attribute keys to include in the + initial partition refinement. Nodes sharing the same tuple of + values under these keys will start in the same cell. :type node_attrs: list[str] | None - :param edge_attrs: List of edge attribute keys to include when distinguishing - edges in the canonical label. If an edge has none of these - keys, its contribution will be empty. 
+ :param edge_attrs: List of edge attribute keys to include when + distinguishing edges in the canonical label. If an edge has + none of these keys, its contribution will be empty. :type edge_attrs: list[str] | None """ @@ -29,12 +28,13 @@ def __init__( node_attrs: Optional[list[str]] = None, edge_attrs: Optional[list[str]] = None, ) -> None: - """ - Initialize the NautyCanonicalizer. + """Initialize the NautyCanonicalizer. - :param node_attrs: Node attribute names to use for initial partitioning. + :param node_attrs: Node attribute names to use for initial + partitioning. :type node_attrs: list[str] | None - :param edge_attrs: Edge attribute names to include in the canonical label. + :param edge_attrs: Edge attribute names to include in the + canonical label. :type edge_attrs: list[str] | None """ self.node_attrs = list(node_attrs) if node_attrs else [] @@ -59,8 +59,8 @@ def canonical_form( return_perm: bool = False, max_depth: Optional[int] = None, ): - """ - Compute canonical form of graph G with optional automorphisms, orbits, and early stopping. + """Compute canonical form of graph G with optional automorphisms, + orbits, and early stopping. :param G: NetworkX graph to canonicalize. :param return_aut: bool, whether to return list of automorphism permutations. @@ -317,4 +317,4 @@ def union_orbits(i, j): def graph_signature(self, G): G_canon = self.canonical_form(G) label = self._build_label(G_canon, sorted(G_canon.nodes())) - return hashlib.sha256(label.encode("utf-8")).hexdigest() + return hashlib.sha256(label.encode("utf-8")).hexdigest() \ No newline at end of file diff --git a/synkit/Graph/Context/hier_context.py b/synkit/Graph/Context/hier_context.py index 4714b34..e99849b 100644 --- a/synkit/Graph/Context/hier_context.py +++ b/synkit/Graph/Context/hier_context.py @@ -11,10 +11,11 @@ class HierContext(RadiusExpand): - """ - Hierarchical clustering class for reaction context graphs. 
Extends RadiusExpand to build - multi-level graph representations and clusters them based on structural features such as - Weisfeiler-Lehman hashing. + """Hierarchical clustering class for reaction context graphs. + + Extends RadiusExpand to build multi-level graph representations and + clusters them based on structural features such as Weisfeiler-Lehman + hashing. """ def __init__( @@ -24,8 +25,8 @@ def __init__( edge_attribute: str = "order", max_radius: int = 3, ) -> None: - """ - Initializes the HierContext class for hierarchical clustering of reaction context graphs. + """Initializes the HierContext class for hierarchical clustering of + reaction context graphs. Parameters: - node_label_names (List[str]): A list of node attribute names used for matching. @@ -46,8 +47,8 @@ def __init__( def _group_class( data: List[Dict[str, Any]], key: str ) -> Dict[Any, List[Dict[str, Any]]]: - """ - Groups a list of dictionaries into subgroups based on the specified key. + """Groups a list of dictionaries into subgroups based on the specified + key. Parameters: - data (List[Dict[str, Any]]): A list of dictionaries to be grouped. @@ -66,8 +67,8 @@ def _group_class( def _update_child_idx( data: List[List[Dict[str, Any]]], cls_id: str = "class" ) -> List[List[Dict[str, Any]]]: - """ - Updates hierarchical templates by assigning child IDs based on parent–cluster relationships. + """Updates hierarchical templates by assigning child IDs based on + parent–cluster relationships. Parameters: - data (List[List[Dict[str, Any]]]): A list of layers, where each layer is a list of dictionaries @@ -106,9 +107,9 @@ def _process( context_key: str, cls_func: Callable, ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: - """ - Processes a list of graph data entries by extracting context subgraphs and computing - their hashes, then classifies the data using the provided clustering function. 
+ """Processes a list of graph data entries by extracting context + subgraphs and computing their hashes, then classifies the data using + the provided clustering function. Parameters: - data (List[Dict[str, Any]]): A list of dictionaries, each representing a graph or data entry. @@ -147,9 +148,9 @@ def _process_level( cls_func: Callable, radius: int = 1, ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: - """ - Processes a specific hierarchical level by grouping data based on parent cluster IDs, - extracting context for child levels, and clustering the data. + """Processes a specific hierarchical level by grouping data based on + parent cluster IDs, extracting context for child levels, and clustering + the data. Parameters: - data (List[Dict[str, Any]]): A list of dictionaries representing graph data entries. @@ -194,10 +195,11 @@ def fit( its_key: str = "ITS", context_key: str = "K", ) -> Tuple[List[Dict[str, Any]], List[List[Dict[str, Any]]]]: - """ - Processes a list of graph data entries, classifying each based on hierarchical clustering. - The method extracts context subgraphs, computes graph hashes, and clusters the data at multiple - hierarchical levels. Finally, child node indices are updated based on parent–cluster relationships. + """Processes a list of graph data entries, classifying each based on + hierarchical clustering. The method extracts context subgraphs, + computes graph hashes, and clusters the data at multiple hierarchical + levels. Finally, child node indices are updated based on parent–cluster + relationships. 
Parameters: - original_data (List[Dict[str, Any]]): A list of dictionaries, each representing a graph data entry diff --git a/synkit/Graph/Context/radius_expand.py b/synkit/Graph/Context/radius_expand.py index 22baf09..4006068 100644 --- a/synkit/Graph/Context/radius_expand.py +++ b/synkit/Graph/Context/radius_expand.py @@ -8,9 +8,8 @@ class RadiusExpand: - """ - A utility class for extracting and expanding reaction contexts - from chemical reaction graphs. + """A utility class for extracting and expanding reaction contexts from + chemical reaction graphs. This class provides methods to: - Identify reaction center nodes based on unequal edge orders. @@ -23,17 +22,17 @@ class RadiusExpand: """ def __init__(self) -> None: - """ - Initializes an instance of the RadiusExpand class. + """Initializes an instance of the RadiusExpand class. - This class does not maintain any instance-specific state and uses only static and class methods. + This class does not maintain any instance-specific state and + uses only static and class methods. """ pass @staticmethod def find_unequal_order_edges(G: nx.Graph) -> List[int]: - """ - Identifies reaction center nodes in a graph based on the presence of unequal order edges. + """Identifies reaction center nodes in a graph based on the presence of + unequal order edges. Parameters: - G (nx.Graph): Graph to analyze for reaction centers. @@ -56,8 +55,8 @@ def find_unequal_order_edges(G: nx.Graph) -> List[int]: def find_nearest_neighbors( G: nx.Graph, center_nodes: List[int], n_knn: int = 1 ) -> Set[int]: - """ - Finds the n-level nearest neighbors around the specified center nodes in a graph. + """Finds the n-level nearest neighbors around the specified center + nodes in a graph. Parameters: - G (nx.Graph): The graph in which to search for neighboring nodes. 
@@ -78,8 +77,8 @@ def find_nearest_neighbors( @staticmethod def extract_subgraph(G: nx.Graph, node_indices: List[int]) -> nx.Graph: - """ - Extracts a subgraph from the original graph containing the specified node indices. + """Extracts a subgraph from the original graph containing the specified + node indices. Parameters: - G (nx.Graph): The original graph. @@ -93,10 +92,9 @@ def extract_subgraph(G: nx.Graph, node_indices: List[int]) -> nx.Graph: @staticmethod def extract_k(its: nx.Graph, n_knn: int = 0) -> Tuple[nx.Graph, Any]: - """ - Constructs the context subgraph (K graph) from an ITS graph - based on reaction centers, and computes the longest extension path - from these centers constrained by 'standard_order' edges. + """Constructs the context subgraph (K graph) from an ITS graph based on + reaction centers, and computes the longest extension path from these + centers constrained by 'standard_order' edges. Parameters: - its (nx.Graph): The ITS graph representing the reaction network. @@ -127,9 +125,8 @@ def context_extraction( context_key: str = "K", n_knn: int = 0, ) -> Dict[str, Any]: - """ - Extracts the reaction context for a single reaction dictionary by computing - both the context subgraph and the longest extension path. + """Extracts the reaction context for a single reaction dictionary by + computing both the context subgraph and the longest extension path. Parameters: - data (Dict[str, Any]): Reaction data containing at least an ITS graph. @@ -143,7 +140,6 @@ def context_extraction( Returns: - Dict[str, Any]: The updated reaction data dictionary including the extracted context subgraph under the key specified by context_key. - """ context_data: Dict[str, Any] = copy.copy(data) its = context_data[its_key] @@ -161,8 +157,8 @@ def paralle_context_extraction( verbose: int = 0, n_knn: int = 0, ) -> List[Dict[str, Any]]: - """ - Performs parallel extraction of reaction contexts for multiple reaction dictionaries. 
+ """Performs parallel extraction of reaction contexts for multiple + reaction dictionaries. Parameters: - data (List[Dict[str, Any]]): A list of reaction data dictionaries, each containing an ITS graph. @@ -186,8 +182,8 @@ def paralle_context_extraction( @staticmethod def remove_normal_edges(graph: nx.Graph, property_key: str) -> nx.Graph: - """ - Removes edges from a graph where the specified edge attribute has a value of 0. + """Removes edges from a graph where the specified edge attribute has a + value of 0. Parameters: - graph (nx.Graph): The input graph to modify. @@ -208,9 +204,9 @@ def remove_normal_edges(graph: nx.Graph, property_key: str) -> nx.Graph: @staticmethod def longest_radius_extension(G: nx.Graph, rc_nodes: List[int]) -> List[int]: - """ - Computes the longest unique extension path in the graph starting from the given reaction center nodes, - constrained by traversing only those edges where the 'standard_order' attribute equals 0. + """Computes the longest unique extension path in the graph starting + from the given reaction center nodes, constrained by traversing only + those edges where the 'standard_order' attribute equals 0. This method uses a depth-first search (DFS) strategy to explore all possible unique paths and returns the longest one. diff --git a/synkit/Graph/Feature/graph_descriptors.py b/synkit/Graph/Feature/graph_descriptors.py index c9b025e..04391e0 100644 --- a/synkit/Graph/Feature/graph_descriptors.py +++ b/synkit/Graph/Feature/graph_descriptors.py @@ -14,8 +14,7 @@ def __init__(self) -> None: @staticmethod def is_graph_empty(graph: Union[nx.Graph, dict, list, Any]) -> bool: - """ - Determine if a graph representation is empty. + """Determine if a graph representation is empty. 
Parameters: - graph (Union[nx.Graph, dict, list, Any]): A graph representation which can be @@ -40,8 +39,7 @@ def is_graph_empty(graph: Union[nx.Graph, dict, list, Any]) -> bool: @staticmethod def is_acyclic_graph(G: nx.Graph) -> bool: - """ - Determines if the given graph is acyclic. + """Determines if the given graph is acyclic. Parameters: - G (nx.Graph): The graph to be checked. @@ -54,8 +52,7 @@ def is_acyclic_graph(G: nx.Graph) -> bool: @staticmethod def is_single_cyclic_graph(G: nx.Graph) -> bool: - """ - Determines if the given graph has exactly one cycle. + """Determines if the given graph has exactly one cycle. Parameters: - G (nx.Graph): The graph to be checked. @@ -74,8 +71,7 @@ def is_single_cyclic_graph(G: nx.Graph) -> bool: @staticmethod def is_complex_cyclic_graph(G: nx.Graph) -> bool: - """ - Determines if the graph is complex cyclic with multiple cycles. + """Determines if the graph is complex cyclic with multiple cycles. Parameters: - G (nx.Graph): The graph to be checked. @@ -93,8 +89,7 @@ def is_complex_cyclic_graph(G: nx.Graph) -> bool: @staticmethod def check_graph_type(G: nx.Graph) -> str: - """ - Classifies the graph as acyclic, single cyclic, or complex cyclic. + """Classifies the graph as acyclic, single cyclic, or complex cyclic. Parameters: - G (nx.Graph): The graph to be checked. @@ -116,10 +111,9 @@ def check_graph_type(G: nx.Graph) -> str: @staticmethod def get_cycle_member_rings(G: nx.Graph, type="minimal") -> List[int]: - """ - Identifies all cycles in the given graph using cycle bases to ensure no overlap - and returns a list of the sizes of these cycles (member rings), - sorted in ascending order. + """Identifies all cycles in the given graph using cycle bases to ensure + no overlap and returns a list of the sizes of these cycles (member + rings), sorted in ascending order. Parameters: - G (nx.Graph): The NetworkX graph to be analyzed. 
@@ -142,8 +136,7 @@ def get_cycle_member_rings(G: nx.Graph, type="minimal") -> List[int]: @staticmethod def get_element_count(graph: nx.Graph) -> Dict[str, int]: - """ - Counts occurrences of each element in the graph nodes. + """Counts occurrences of each element in the graph nodes. Parameters: - graph (nx.Graph): A NetworkX graph with 'element' attribute in nodes. @@ -161,8 +154,8 @@ def get_descriptors( its: str = "ITS", condensed: bool = True, ) -> Dict: - """ - Enhance an entry dictionary with topology type and reaction type descriptors. + """Enhance an entry dictionary with topology type and reaction type + descriptors. Parameters: - entry (Dict): A dictionary with reaction data. @@ -208,8 +201,8 @@ def get_descriptors( @staticmethod def _extract_graph(entry: Dict, key: str) -> Union[nx.Graph, None]: - """ - Extracts a graph from an entry dictionary based on the specified key. + """Extracts a graph from an entry dictionary based on the specified + key. Parameters: - entry (Dict): The dictionary containing graph data. @@ -234,8 +227,7 @@ def _extract_graph(entry: Dict, key: str) -> Union[nx.Graph, None]: def _adjust_cycle_and_step( entry: Dict, cycle_key: str, topo_type: str, its_prefix: str = "" ) -> None: - """ - Adjusts cycle and step descriptors based on the graph topology type. + """Adjusts cycle and step descriptors based on the graph topology type. Parameters: - entry (Dict): The entry dictionary to update. @@ -258,8 +250,7 @@ def _adjust_cycle_and_step( @staticmethod def _validate_graph_input(G: nx.Graph) -> None: - """ - Validates that the input is a NetworkX graph. + """Validates that the input is a NetworkX graph. Parameters: - G (nx.Graph): The graph to validate. @@ -279,8 +270,8 @@ def process_entries_in_parallel( n_jobs: int = 4, verbose: int = 0, ) -> List[Dict]: - """ - Processes a list of entries in parallel to enhance each entry with descriptors. + """Processes a list of entries in parallel to enhance each entry with + descriptors. 
Parameters: - entries (List[Dict]): List of dictionaries containing reaction data to enhance. @@ -304,8 +295,7 @@ def process_entries_in_parallel( def check_graph_connectivity(graph: nx.Graph) -> str: - """ - Check the connectivity of a NetworkX graph. + """Check the connectivity of a NetworkX graph. This function assesses whether all nodes in the graph are connected by some path, applicable to undirected graphs. diff --git a/synkit/Graph/Feature/graph_fps.py b/synkit/Graph/Feature/graph_fps.py index 58b1214..b1ef798 100644 --- a/synkit/Graph/Feature/graph_fps.py +++ b/synkit/Graph/Feature/graph_fps.py @@ -7,9 +7,8 @@ class GraphFP: def __init__( self, graph: nx.Graph, nBits: int = 1024, hash_alg: str = "sha256" ) -> None: - """ - Initialize the GraphFP class to create binary fingerprints based on various graph - characteristics. + """Initialize the GraphFP class to create binary fingerprints based on + various graph characteristics. Parameters: - graph (nx.Graph): Graph on which to perform analysis. @@ -22,8 +21,8 @@ def __init__( self.hash_function = getattr(hashlib, self.hash_alg) def fingerprint(self, method: str) -> str: - """ - Generate a binary string fingerprint of the graph using the specified method. + """Generate a binary string fingerprint of the graph using the + specified method. Parameters: - method (str): The method to use for fingerprinting @@ -78,9 +77,8 @@ def _motif_count_fp(self) -> str: return triangle_str[: self.nBits] def iterative_deepening(self, remaining_bits: int) -> str: - """ - Extend the hash length using iterative hashing until the desired bit length is - achieved. + """Extend the hash length using iterative hashing until the desired bit + length is achieved. 
Parameters: - remaining_bits (int): Number of bits needed to complete the fingerprint diff --git a/synkit/Graph/Feature/graph_signature.py b/synkit/Graph/Feature/graph_signature.py index 0f558c2..99c9ce3 100644 --- a/synkit/Graph/Feature/graph_signature.py +++ b/synkit/Graph/Feature/graph_signature.py @@ -3,15 +3,17 @@ class GraphSignature: - """ - Provides methods to generate canonical signatures for graph edges (with flexible 'order' and 'state' attributes, - and node degrees/neighbor information), various spectral invariants, adjacency matrix, and complete graphs. - Aims for high uniqueness without relying solely on isomorphism checks. + """Provides methods to generate canonical signatures for graph edges (with + flexible 'order' and 'state' attributes, and node degrees/neighbor + information), various spectral invariants, adjacency matrix, and complete + graphs. + + Aims for high uniqueness without relying solely on isomorphism + checks. """ def __init__(self, graph: nx.Graph): - """ - Initializes the GraphSignature class with a specified graph. + """Initializes the GraphSignature class with a specified graph. Parameters: - graph (nx.Graph): A NetworkX graph instance. @@ -20,10 +22,9 @@ def __init__(self, graph: nx.Graph): self._validate_graph() def _validate_graph(self): - """ - Validates that all nodes have the required attributes ('element' and 'charge'), - and all edges have the required 'order' attribute as int, float, or tuple of two floats, - and optionally the 'state' attribute. + """Validates that all nodes have the required attributes ('element' and + 'charge'), and all edges have the required 'order' attribute as int, + float, or tuple of two floats, and optionally the 'state' attribute. 
Raises: - ValueError: If any node is missing the 'element' or 'charge' attribute, @@ -61,9 +62,10 @@ def _validate_graph(self): def create_edge_signature( self, include_neighbors: bool = False, max_hop: int = 2 ) -> str: - """ - Generates a canonical edge signature by formatting each edge with sorted node elements (including charge), - node degrees, bond order, bond state, and optionally including neighbor information and topological context. + """Generates a canonical edge signature by formatting each edge with + sorted node elements (including charge), node degrees, bond order, bond + state, and optionally including neighbor information and topological + context. Parameters: - include_neighbors (bool): Whether to include neighbors' details in the edge signature. @@ -139,8 +141,7 @@ def create_edge_signature( return "/".join(sorted(edge_signature_parts)) def _get_khop_neighbors(self, node, max_hop): - """ - Retrieves the k-hop neighborhood information for a given node. + """Retrieves the k-hop neighborhood information for a given node. Parameters: - node (int): The node for which to get neighborhood information. @@ -171,8 +172,8 @@ def _get_khop_neighbors(self, node, max_hop): ) def create_wl_hash(self, iterations: int = 3) -> str: - """ - Generates a Weisfeiler-Lehman (WL) hash for the graph to capture its structural features. + """Generates a Weisfeiler-Lehman (WL) hash for the graph to capture its + structural features. Parameters: - iterations (int): Number of WL iterations to perform. @@ -210,8 +211,8 @@ def create_graph_signature( include_neighbors: bool = True, max_hop: int = 1, ) -> str: - """ - Combines edge, various spectral invariants, and WL hash into a single comprehensive graph signature. + """Combines edge, various spectral invariants, and WL hash into a + single comprehensive graph signature. Parameters: - include_wl_hash (bool): Whether to include the Weisfeiler-Lehman hash. 
diff --git a/synkit/Graph/Feature/hash_fps.py b/synkit/Graph/Feature/hash_fps.py index 3febabb..8adbebd 100644 --- a/synkit/Graph/Feature/hash_fps.py +++ b/synkit/Graph/Feature/hash_fps.py @@ -7,8 +7,8 @@ class HashFPs: def __init__( self, graph: nx.Graph, numBits: int = 256, hash_alg: str = "sha256" ) -> None: - """ - Initialize the HashFPs class with a graph and configuration settings. + """Initialize the HashFPs class with a graph and configuration + settings. Parameters: - graph (nx.Graph): The graph to be fingerprinted. @@ -37,8 +37,8 @@ def hash_fps( end_node: Optional[int] = None, max_path_length: Optional[int] = None, ) -> str: - """ - Generate a binary hash fingerprint of the graph based on its paths and cycles. + """Generate a binary hash fingerprint of the graph based on its paths + and cycles. Parameters: - start_node (Optional[int]): The starting node index for path detection. @@ -55,7 +55,8 @@ def hash_fps( return full_hash_binary def initialize_hash(self) -> Any: - """Initialize and return the hash object based on the specified algorithm.""" + """Initialize and return the hash object based on the specified + algorithm.""" return getattr(hashlib, self.hash_alg)() def extract_features( @@ -64,8 +65,7 @@ def extract_features( end_node: Optional[int], max_path_length: Optional[int], ) -> str: - """ - Extract features from the graph based on paths and cycles. + """Extract features from the graph based on paths and cycles. Parameters: - start_node (Optional[int]): The starting node for path detection. @@ -90,9 +90,8 @@ def extract_features( return "".join(map(str, features)) def finalize_hash(self, hash_object: Any, features: str) -> str: - """ - Finalize the hash using the features extracted and return the hash as a binary - string. + """Finalize the hash using the features extracted and return the hash + as a binary string. Parameters: - hash_object (Any): The hash object. 
@@ -110,9 +109,8 @@ def finalize_hash(self, hash_object: Any, features: str) -> str: return full_hash_binary[: self.numBits] def iterative_deepening(self, hash_object: Any, remaining_bits: int) -> str: - """ - Extend hash length using iterative hashing until the desired bit length is - achieved. + """Extend hash length using iterative hashing until the desired bit + length is achieved. Parameters: - hash_object (hashlib._Hash): The hash object for iterative deepening. diff --git a/synkit/Graph/Feature/morgan_fps.py b/synkit/Graph/Feature/morgan_fps.py index 4bef1fa..b04f389 100644 --- a/synkit/Graph/Feature/morgan_fps.py +++ b/synkit/Graph/Feature/morgan_fps.py @@ -11,9 +11,9 @@ def __init__( nBits: int = 1024, hash_alg: str = "sha256", ): - """ - Initialize the MorganFPs class to generate fingerprints based on the Morgan - algorithm, approximating Extended Connectivity Fingerprints (ECFPs). + """Initialize the MorganFPs class to generate fingerprints based on the + Morgan algorithm, approximating Extended Connectivity Fingerprints + (ECFPs). Parameters: - graph (nx.Graph): The graph to analyze. @@ -29,10 +29,9 @@ def __init__( self.hash_function = getattr(hashlib, self.hash_alg) def generate_fingerprint(self) -> str: - """ - Generate a binary string fingerprint of the graph based on the local environments - of nodes. Ensures the output is exactly `nBits` in length using iterative - deepening if necessary. + """Generate a binary string fingerprint of the graph based on the local + environments of nodes. Ensures the output is exactly `nBits` in length + using iterative deepening if necessary. Returns: - str: A binary string of length `nBits` representing the fingerprint of the @@ -68,9 +67,8 @@ def generate_fingerprint(self) -> str: return fingerprint def iterative_deepening(self, hash_object: Any, remaining_bits: int) -> str: - """ - Extend the hash length using iterative hashing until the desired bit length is - achieved. 
+ """Extend the hash length using iterative hashing until the desired bit + length is achieved. Parameters: - hash_object (hashlib._Hash): The hash object used for iterative deepening. diff --git a/synkit/Graph/Feature/path_fps.py b/synkit/Graph/Feature/path_fps.py index b5e4e92..6d9c488 100644 --- a/synkit/Graph/Feature/path_fps.py +++ b/synkit/Graph/Feature/path_fps.py @@ -11,9 +11,8 @@ def __init__( nBits: int = 1024, hash_alg: str = "sha256", ) -> None: - """ - Initialize the PathFPs class to create a binary fingerprint based on paths in a - graph. + """Initialize the PathFPs class to create a binary fingerprint based on + paths in a graph. Parameters: - graph (nx.Graph): Graph on which to perform analysis. @@ -29,9 +28,8 @@ def __init__( self.hash_function = getattr(hashlib, self.hash_alg) def generate_fingerprint(self) -> str: - """ - Generate a binary string fingerprint of the graph by hashing paths up to a certain - length and combining them. + """Generate a binary string fingerprint of the graph by hashing paths + up to a certain length and combining them. Returns: - str: A binary string of length `nBits` that represents the fingerprint of the @@ -63,9 +61,8 @@ def generate_fingerprint(self) -> str: return fingerprint def iterative_deepening(self, hash_object: Any, remaining_bits: int) -> str: - """ - Extend the hash length using iterative hashing until the desired bit length is - achieved. + """Extend the hash length using iterative hashing until the desired bit + length is achieved. Parameters: - hash_object (hashlib._Hash): The hash object used for iterative deepening. 
diff --git a/synkit/Graph/Feature/wl_hash.py b/synkit/Graph/Feature/wl_hash.py index 5f889dd..89317e3 100644 --- a/synkit/Graph/Feature/wl_hash.py +++ b/synkit/Graph/Feature/wl_hash.py @@ -4,8 +4,7 @@ class WLHash: - """ - A class that implements the Weisfeiler-Lehman graph hashing algorithm, + """A class that implements the Weisfeiler-Lehman graph hashing algorithm, supporting multiple node/edge attributes for hashing. Attributes: @@ -22,8 +21,7 @@ def __init__( iterations: int = 5, digest_size: int = 16, ): - """ - Initializes the WLHash class with configuration for hashing. + """Initializes the WLHash class with configuration for hashing. Parameters: - node: A node attribute name or list of node attribute names. @@ -39,8 +37,8 @@ def __init__( def _prepare_graph( self, graph: nx.Graph ) -> Tuple[nx.Graph, Union[str, None], Union[str, None]]: - """ - Prepare a deep copy of the graph with combined/missing node and edge attributes. + """Prepare a deep copy of the graph with combined/missing node and edge + attributes. Returns (H, node_attr_name, edge_attr_name). """ @@ -86,9 +84,7 @@ def _prepare_graph( return H, node_attr_name, edge_attr_name def weisfeiler_lehman_graph_hash(self, graph: nx.Graph) -> str: - """ - Computes the WL graph hash for the entire graph. - """ + """Computes the WL graph hash for the entire graph.""" G, node_attr, edge_attr = self._prepare_graph(graph) return nx.weisfeiler_lehman_graph_hash( G, @@ -101,9 +97,7 @@ def weisfeiler_lehman_graph_hash(self, graph: nx.Graph) -> str: def weisfeiler_lehman_subgraph_hashes( self, graph: nx.Graph ) -> Dict[Union[int, str], List[str]]: - """ - Computes the WL subgraph hashes for each node in the graph. 
- """ + """Computes the WL subgraph hashes for each node in the graph.""" G, node_attr, edge_attr = self._prepare_graph(graph) return nx.weisfeiler_lehman_subgraph_hashes( G, @@ -119,8 +113,8 @@ def process_data( graph_key: str = "ITS", subgraph: bool = False, ) -> List[Dict[str, Union[str, None]]]: - """ - Applies WL hashing (or subgraph hashing) to a list of data entries. + """Applies WL hashing (or subgraph hashing) to a list of data entries. + Each entry must contain a graph under 'graph_key'. """ for entry in data: diff --git a/synkit/Graph/Hyrogen/_misc.py b/synkit/Graph/Hyrogen/_misc.py index 98bfea8..866b4b9 100644 --- a/synkit/Graph/Hyrogen/_misc.py +++ b/synkit/Graph/Hyrogen/_misc.py @@ -8,8 +8,7 @@ def has_XH(G: nx.Graph) -> bool: - """ - Check whether the graph contains any heavy atom–hydrogen bond. + """Check whether the graph contains any heavy atom–hydrogen bond. A heavy atom is any atom whose 'element' attribute is not 'H'. This function searches for any edge that connects a heavy atom to a hydrogen atom. @@ -34,8 +33,7 @@ def has_XH(G: nx.Graph) -> bool: def has_HH(G: nx.Graph) -> bool: - """ - Check whether the graph contains any heavy atom–hydrogen bond. + """Check whether the graph contains any heavy atom–hydrogen bond. A heavy atom is any atom whose 'element' attribute is not 'H'. This function searches for any edge that connects a heavy atom to a hydrogen atom. @@ -60,8 +58,7 @@ def has_HH(G: nx.Graph) -> bool: def h_to_implicit(G: nx.Graph) -> nx.Graph: - """ - Convert explicit hydrogen atoms to implicit counts on heavy atoms. + """Convert explicit hydrogen atoms to implicit counts on heavy atoms. For each hydrogen atom ('element' == 'H'), its neighbor (assumed to be a heavy atom) will have its 'hcount' attribute incremented. The hydrogen nodes are then removed. 
@@ -107,8 +104,8 @@ def normalize_edge_orders(G: nx.Graph) -> None: def h_to_explicit(G: nx.Graph, nodes: List[int] = None, its: bool = False) -> nx.Graph: - """ - Convert implicit hydrogen counts on heavy atoms into explicit hydrogen nodes. + """Convert implicit hydrogen counts on heavy atoms into explicit hydrogen + nodes. For each node ID in `nodes`, this function reads the node's 'hcount', adds that many new hydrogen nodes, connects them to the node with a single bond (order=1.0), and @@ -170,11 +167,11 @@ def h_to_explicit(G: nx.Graph, nodes: List[int] = None, its: bool = False) -> nx def implicit_hydrogen( graph: nx.Graph, preserve_atom_maps: Set[int], reindex: bool = False ) -> nx.Graph: - """ - Adds implicit hydrogens to a molecular graph and removes non-preserved hydrogens. - This function operates on a deep copy of the input graph to avoid in-place modifications. - It counts hydrogen neighbors for each non-hydrogen node and adjusts based on - hydrogens that need to be preserved. Non-preserved hydrogen nodes are removed from the graph. + """Adds implicit hydrogens to a molecular graph and removes non-preserved + hydrogens. This function operates on a deep copy of the input graph to + avoid in-place modifications. It counts hydrogen neighbors for each non- + hydrogen node and adjusts based on hydrogens that need to be preserved. + Non-preserved hydrogen nodes are removed from the graph. Parameters: - graph (nx.Graph): A NetworkX graph representing the molecule, where each node has an 'element' @@ -330,8 +327,7 @@ def implicit_hydrogen( def check_equivariant_graph( its_graphs: List[nx.Graph], ) -> Tuple[List[Tuple[int, int]], int]: - """ - Checks for isomorphism among a list of ITS graphs. + """Checks for isomorphism among a list of ITS graphs. Parameters: - its_graphs (List[nx.Graph]): A list of ITS graphs. 
@@ -358,8 +354,8 @@ def check_equivariant_graph( def check_explicit_hydrogen(graph: nx.Graph) -> tuple: - """ - Counts the explicit hydrogen nodes in the given graph and collects their IDs. + """Counts the explicit hydrogen nodes in the given graph and collects their + IDs. Parameters: - graph (nx.Graph): The graph to inspect. @@ -376,9 +372,9 @@ def check_explicit_hydrogen(graph: nx.Graph) -> tuple: def check_hcount_change(react_graph: nx.Graph, prod_graph: nx.Graph) -> int: - """ - Computes the maximum change in hydrogen count ('hcount') between corresponding nodes - in the reactant and product graphs. It considers both hydrogen formation and breakage. + """Computes the maximum change in hydrogen count ('hcount') between + corresponding nodes in the reactant and product graphs. It considers both + hydrogen formation and breakage. Parameters: - react_graph (nx.Graph): The graph representing reactants. @@ -409,9 +405,8 @@ def check_hcount_change(react_graph: nx.Graph, prod_graph: nx.Graph) -> int: def get_cycle_member_rings(G: nx.Graph, type="minimal") -> List[int]: - """ - Identifies all cycles in the given graph using cycle bases to ensure no overlap - and returns a list of the sizes of these cycles (member rings), + """Identifies all cycles in the given graph using cycle bases to ensure no + overlap and returns a list of the sizes of these cycles (member rings), sorted in ascending order. Parameters: @@ -435,10 +430,9 @@ def get_cycle_member_rings(G: nx.Graph, type="minimal") -> List[int]: def get_priority(reaction_centers: List[Any]) -> List[int]: - """ - Evaluate reaction centers for specific graph characteristics, selecting indices based - on the shortest reaction paths and maximum ring sizes, and adjusting for certain - graph types by modifying the ring information. 
+ """Evaluate reaction centers for specific graph characteristics, selecting + indices based on the shortest reaction paths and maximum ring sizes, and + adjusting for certain graph types by modifying the ring information. Parameters: - reaction_centers: List[Any], a list of reaction centers where each center should be diff --git a/synkit/Graph/Hyrogen/hcomplete.py b/synkit/Graph/Hyrogen/hcomplete.py index 9ab3839..32e6c1b 100644 --- a/synkit/Graph/Hyrogen/hcomplete.py +++ b/synkit/Graph/Hyrogen/hcomplete.py @@ -20,9 +20,8 @@ class HComplete: - """ - A class for infering hydrogen to complete reaction center or ITS graph. - """ + """A class for infering hydrogen to complete reaction center or ITS + graph.""" @staticmethod def process_single_graph_data( @@ -34,9 +33,8 @@ def process_single_graph_data( get_priority_graph: bool = False, max_hydrogen: int = 7, ) -> Dict[str, Optional[nx.Graph]]: - """ - Processes a single graph data dictionary by modifying hydrogen counts - and other features based on configuration settings. + """Processes a single graph data dictionary by modifying hydrogen + counts and other features based on configuration settings. Parameters: - graph_data (Dict[str, nx.Graph]): Dictionary containing the graph data. @@ -95,8 +93,7 @@ def process_graph_data_parallel( get_priority_graph: bool = False, max_hydrogen: int = 7, ) -> List[Dict[str, Optional[nx.Graph]]]: - """ - Processes a list of graph data dictionaries in parallel to optimize + """Processes a list of graph data dictionaries in parallel to optimize the hydrogen completion and other graph modifications. Parameters: @@ -140,9 +137,9 @@ def process_multiple_hydrogens( balance_its: bool, get_priority_graph: bool = False, ) -> Dict[str, Optional[nx.Graph]]: - """ - Handles significant hydrogen count changes between reactant and product graphs, - adjusting hydrogen nodes accordingly and assessing graph equivalence. 
+ """Handles significant hydrogen count changes between reactant and + product graphs, adjusting hydrogen nodes accordingly and assessing + graph equivalence. Parameters: - graph_data (Dict[str, nx.Graph]): Dictionary containing the graph data. @@ -218,9 +215,9 @@ def add_hydrogen_nodes_multiple( balance_its: bool, get_priority_graph: bool = False, ) -> List[Tuple[nx.Graph, nx.Graph]]: - """ - Generates multiple permutations of reactant and product graphs by adjusting hydrogen counts, - exploring all possible configurations of hydrogen node additions or removals. + """Generates multiple permutations of reactant and product graphs by + adjusting hydrogen counts, exploring all possible configurations of + hydrogen node additions or removals. Parameters: - react_graph (nx.Graph): The reactant graph. @@ -310,9 +307,8 @@ def add_hydrogen_nodes_multiple_utils( node_id_pairs: Iterable[Tuple[int, int]], atom_map_update: bool = True, ) -> nx.Graph: - """ - Creates and returns a new graph with added hydrogen nodes based on the input graph - and node ID pairs. + """Creates and returns a new graph with added hydrogen nodes based on + the input graph and node ID pairs. Parameters: - graph (nx.Graph): The base graph to which the nodes will be added. diff --git a/synkit/Graph/Hyrogen/hextend.py b/synkit/Graph/Hyrogen/hextend.py index 43ddd55..c3e6ad8 100644 --- a/synkit/Graph/Hyrogen/hextend.py +++ b/synkit/Graph/Hyrogen/hextend.py @@ -19,8 +19,8 @@ class HExtend(HComplete): def get_unique_graphs_for_clusters( graphs: List[nx.Graph], cluster_indices: List[set] ) -> List[nx.Graph]: - """ - Retrieve a unique graph for each cluster from a list of graphs based on cluster indices. + """Retrieve a unique graph for each cluster from a list of graphs based + on cluster indices. This method selects one graph per cluster based on the first index found in each cluster set. 
Note: Clusters are expected to be represented @@ -60,8 +60,7 @@ def _extend( ignore_aromaticity: bool, balance_its: bool, ) -> Tuple[List[nx.Graph], List[nx.Graph], List[str]]: - """ - Process equivalent maps by adding hydrogen nodes and constructing + """Process equivalent maps by adding hydrogen nodes and constructing ITS graphs based on the balance and aromaticity settings. Parameters: @@ -107,9 +106,8 @@ def _process( ignore_aromaticity: bool, balance_its: bool, ) -> Dict: - """ - Processes a dictionary of graphs using specific graph processing functions - and updates the dictionary with new graph data. + """Processes a dictionary of graphs using specific graph processing + functions and updates the dictionary with new graph data. Parameters: - data_dict (Dict): Dictionary containing the graphs and their keys. @@ -143,9 +141,8 @@ def fit( n_jobs: int = 1, verbose: int = 0, ) -> List: - """ - Fit the model to the data in parallel, processing each entry to generate - new graph data based on the ITS and reaction graph keys. + """Fit the model to the data in parallel, processing each entry to + generate new graph data based on the ITS and reaction graph keys. Parameters: - data (iterable): Data to be processed. diff --git a/synkit/Graph/ITS/its_builder.py b/synkit/Graph/ITS/its_builder.py index 8869886..2ab5591 100644 --- a/synkit/Graph/ITS/its_builder.py +++ b/synkit/Graph/ITS/its_builder.py @@ -3,17 +3,17 @@ class ITSBuilder: - """ - Build and annotate an Imaginary Transition State (ITS) graph from a base graph - and a reaction-center (RC) graph. + """Build and annotate an Imaginary Transition State (ITS) graph from a base + graph and a reaction-center (RC) graph. - :cvar None: This class only provides static methods and does not maintain state. + :cvar None: This class only provides static methods and does not + maintain state. 
""" @staticmethod def update_atom_map(graph: nx.Graph) -> None: - """ - Reset and renumber the 'atom_map' attribute of every node to match its node index. + """Reset and renumber the 'atom_map' attribute of every node to match + its node index. :param graph: The graph whose nodes will be renumbered. :type graph: nx.Graph @@ -31,9 +31,9 @@ def update_atom_map(graph: nx.Graph) -> None: @staticmethod def ITSGraph(G: nx.Graph, RC: nx.Graph) -> nx.Graph: - """ - Create an ITS graph by merging attributes from a reaction-center graph (RC) - into a copy of the base graph G and initializing transition-state metadata. + """Create an ITS graph by merging attributes from a reaction-center + graph (RC) into a copy of the base graph G and initializing transition- + state metadata. The returned ITS graph will have: 1. A deep copy of G’s nodes and edges. @@ -111,4 +111,4 @@ def ITSGraph(G: nx.Graph, RC: nx.Graph) -> nx.Graph: # 7) Renumber atom_map to node indices ITSBuilder.update_atom_map(its) - return its + return its \ No newline at end of file diff --git a/synkit/Graph/ITS/its_construction.py b/synkit/Graph/ITS/its_construction.py index d56bd15..bfa51f7 100644 --- a/synkit/Graph/ITS/its_construction.py +++ b/synkit/Graph/ITS/its_construction.py @@ -12,8 +12,8 @@ def ITSGraph( attributes_defaults: Optional[Dict[str, Any]] = None, balance_its: bool = True, ) -> nx.Graph: - """ - Create a Combined Graph Representation (CGR) by merging nodes and edges of G and H. + """Create a Combined Graph Representation (CGR) by merging nodes and + edges of G and H. The resulting ITS graph: - Uses a deep copy of the smaller (or larger, if balance_its is False) input graph. @@ -81,8 +81,7 @@ def ITSGraph( def get_node_attribute( graph: nx.Graph, node: Hashable, attribute: str, default: Any ) -> Any: - """ - Retrieve a node attribute or return a default if missing. + """Retrieve a node attribute or return a default if missing. :param graph: The graph containing the node. 
:type graph: nx.Graph @@ -90,7 +89,8 @@ def get_node_attribute( :type node: hashable :param attribute: The name of the attribute to retrieve. :type attribute: str - :param default: The value to return if the attribute is not present. + :param default: The value to return if the attribute is not + present. :type default: Any :returns: The attribute value or the default. :rtype: Any @@ -104,8 +104,7 @@ def get_node_attribute( def get_node_attributes_with_defaults( graph: nx.Graph, node: int, attributes_defaults: Dict[str, Any] = None ) -> Tuple: - """ - Retrieve multiple node attributes, applying defaults where missing. + """Retrieve multiple node attributes, applying defaults where missing. :param graph: The graph containing the node. :type graph: nx.Graph @@ -134,8 +133,7 @@ def get_node_attributes_with_defaults( def add_edges_to_ITS( ITS: nx.Graph, G: nx.Graph, H: nx.Graph, ignore_aromaticity: bool = False ) -> nx.Graph: - """ - Add and label edges in the ITS graph based on presence in G and H. + """Add and label edges in the ITS graph based on presence in G and H. For each edge (u,v) in G or H: - If present in both, label ``order=(order_G, order_H)``. @@ -190,8 +188,8 @@ def add_edges_to_ITS( def add_standard_order_attribute( graph: nx.Graph, ignore_aromaticity: bool = False ) -> nx.Graph: - """ - Compute and attach 'standard_order' to each edge as difference of orders. + """Compute and attach 'standard_order' to each edge as difference of + orders. :param graph: Graph whose edges have ``order=(o_G, o_H)``. :type graph: nx.Graph @@ -286,12 +284,13 @@ def construct( return its def typesGH(self) -> Dict[str, Dict[str, Tuple[Any, Any]]]: - """ - Returns the types and default values for selected node and edge attributes, useful for - interpreting the 'typesGH' annotation on ITS graphs. + """Returns the types and default values for selected node and edge + attributes, useful for interpreting the 'typesGH' annotation on ITS + graphs. 
- :returns: Dictionary with node and edge attribute types and defaults, e.g. - {"node": {attr: (type, 0)}, "edge": {attr: (type, 0)}} + :returns: Dictionary with node and edge attribute types and + defaults, e.g. {"node": {attr: (type, 0)}, "edge": {attr: + (type, 0)}} :rtype: dict[str, dict[str, tuple[type, Any]]] """ node_prop_types: Dict[str, Any] = { @@ -323,4 +322,4 @@ def typesGH(self) -> Dict[str, Dict[str, Tuple[Any, Any]]]: } node_defaults = {k: (tp, 0) for k, tp in sel_nodes.items()} edge_defaults = {k: (tp, 0) for k, tp in sel_edges.items()} - return {"node": node_defaults, "edge": edge_defaults} + return {"node": node_defaults, "edge": edge_defaults} \ No newline at end of file diff --git a/synkit/Graph/ITS/its_decompose.py b/synkit/Graph/ITS/its_decompose.py index a021c08..c608acb 100644 --- a/synkit/Graph/ITS/its_decompose.py +++ b/synkit/Graph/ITS/its_decompose.py @@ -15,8 +15,7 @@ def get_rc( standard_key: str = "standard_order", disconnected: bool = False, ) -> nx.Graph: - """ - Extract the reaction-center (RC) subgraph from an ITS graph. + """Extract the reaction-center (RC) subgraph from an ITS graph. This function identifies: 1. All bonds whose standard order (difference between ITS orders) is non-zero. @@ -307,8 +306,8 @@ def _add_bond_order_changes( def its_decompose(its_graph: nx.Graph, nodes_share="typesGH", edges_share="order"): - """ - Decompose an ITS graph into two separate reactant (G) and product (H) graphs. + """Decompose an ITS graph into two separate reactant (G) and product (H) + graphs. Nodes and edges in `its_graph` carry composite attributes: - Each node has `its_graph.nodes[nodes_share] = (node_attrs_G, node_attrs_H)`. @@ -379,8 +378,7 @@ def compare_graphs( node_attrs: list = ["element", "aromatic", "hcount", "charge", "neighbors"], edge_attrs: list = ["order"], ) -> bool: - """ - Compare two graphs based on specified node and edge attributes. + """Compare two graphs based on specified node and edge attributes. 
Parameters: - graph1 (nx.Graph): The first graph to compare. @@ -428,12 +426,12 @@ def compare_graphs( def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: - """ - Enumerates possible tautomers for reactants while canonicalizing the products in a - reaction SMILES string. This function first splits the reaction SMILES string into - reactants and products. It then generates all possible tautomers for the reactants and - canonicalizes the product molecule. The function returns a list of reaction SMILES - strings for each tautomer of the reactants combined with the canonical product. + """Enumerates possible tautomers for reactants while canonicalizing the + products in a reaction SMILES string. This function first splits the + reaction SMILES string into reactants and products. It then generates all + possible tautomers for the reactants and canonicalizes the product + molecule. The function returns a list of reaction SMILES strings for each + tautomer of the reactants combined with the canonical product. Parameters: - reaction_smiles (str): A SMILES string of the reaction formatted as @@ -493,9 +491,8 @@ def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: def mapping_success_rate(list_mapping_data): - """ - Calculate the success rate of entries containing atom mappings in a list of data - strings. + """Calculate the success rate of entries containing atom mappings in a list + of data strings. 
Parameters: - list_mapping_in_data (list of str): List containing strings to be searched for atom @@ -516,4 +513,4 @@ def mapping_success_rate(list_mapping_data): ) rate = 100 * (success / len(list_mapping_data)) - return round(rate, 2) + return round(rate, 2) \ No newline at end of file diff --git a/synkit/Graph/ITS/its_expand.py b/synkit/Graph/ITS/its_expand.py index 260786a..bb3cf2f 100644 --- a/synkit/Graph/ITS/its_expand.py +++ b/synkit/Graph/ITS/its_expand.py @@ -11,21 +11,23 @@ class ITSExpand: - """ - Partially expand a reaction SMILES (RSMI) by reconstructing intermediate transition states - (ITS) and applying transformation rules based on the reaction center graph. + """Partially expand a reaction SMILES (RSMI) by reconstructing intermediate + transition states (ITS) and applying transformation rules based on the + reaction center graph. - This class identifies the reaction center from an RSMI, builds and reconstructs the ITS graph, - decomposes it back into reactants and products, and standardizes atom mappings to produce - a fully mapped AAM RSMI. + This class identifies the reaction center from an RSMI, builds and + reconstructs the ITS graph, decomposes it back into reactants and + products, and standardizes atom mappings to produce a fully mapped + AAM RSMI. :cvar std: Standardize instance for reaction SMILES standardization. :type std: Standardize """ def __init__(self) -> None: - """ - Initialize ITSExpand. No instance-specific attributes are required. + """Initialize ITSExpand. + + No instance-specific attributes are required. """ pass @@ -35,8 +37,8 @@ def expand_aam_with_its( relabel: bool = False, use_G: bool = True, ) -> str: - """ - Expand a partial reaction SMILES to a full AAM RSMI using ITS reconstruction. + """Expand a partial reaction SMILES to a full AAM RSMI using ITS + reconstruction. :param rsmi: Reaction SMILES string in the format 'reactant>>product'. 
:type rsmi: str @@ -81,4 +83,4 @@ def expand_aam_with_its( # Convert graphs back to RSMI and standardize atom mappings expanded_rsmi = graph_to_rsmi(new_react, new_prod, its_graph, True, False) - return std.fit(expanded_rsmi, remove_aam=False) + return std.fit(expanded_rsmi, remove_aam=False) \ No newline at end of file diff --git a/synkit/Graph/ITS/its_relabel.py b/synkit/Graph/ITS/its_relabel.py index d3cca6c..be71882 100644 --- a/synkit/Graph/ITS/its_relabel.py +++ b/synkit/Graph/ITS/its_relabel.py @@ -12,26 +12,22 @@ class ITSRelabel: - """ - Extend reaction SMILES through atom-map alignment between reactant and product SynGraphs. + """Extend reaction SMILES through atom-map alignment between reactant and + product SynGraphs. :cvar logger: Logger instance for debug and info messages. :type logger: logging.Logger - :ivar graph_to_mol: Converter from SynGraph to RDKit Mol. :type graph_to_mol: GraphToMol """ def __init__(self) -> None: - """ - Initialize ITSRelabel with default GraphToMol converter. - """ + """Initialize ITSRelabel with default GraphToMol converter.""" self.graph_to_mol = GraphToMol() @staticmethod def _get_nodes_with_atom_map(graph: SynGraph) -> List[Any]: - """ - Extract node IDs with a non-zero atom_map attribute from a SynGraph. + """Extract node IDs with a non-zero atom_map attribute from a SynGraph. :param graph: Input SynGraph with 'atom_map' on nodes. :type graph: SynGraph @@ -46,8 +42,7 @@ def _get_nodes_with_atom_map(graph: SynGraph) -> List[Any]: @staticmethod def _remove_internal_edges(graph: SynGraph, nodes: List[Any]) -> SynGraph: - """ - Remove edges connecting nodes in the given list from a SynGraph. + """Remove edges connecting nodes in the given list from a SynGraph. :param graph: Input SynGraph to prune. 
:type graph: SynGraph @@ -68,8 +63,7 @@ def _remove_internal_edges(graph: SynGraph, nodes: List[Any]) -> SynGraph: def _dict_to_tuple_list( mapping: Dict[Any, Any], sort_by_key: bool = False, sort_by_value: bool = False ) -> List[Tuple[Any, Any]]: - """ - Convert a mapping dict into a sorted list of tuples. + """Convert a mapping dict into a sorted list of tuples. :param mapping: Dictionary to convert. :type mapping: Dict[Any, Any] @@ -94,19 +88,20 @@ def _update_mapping( mapping: Iterable[Tuple[Any, Any]], aam_key: str = "atom_map", ) -> Tuple[SynGraph, SynGraph]: - """ - Update node attributes in two SynGraphs based on a sequential mapping. + """Update node attributes in two SynGraphs based on a sequential + mapping. - This method resets the specified atom-map attribute for all nodes in both - graphs to 0, then assigns a new atom-map value (i+1) for each mapped pair: - G.nodes[g_node][aam_key] = i + 1 - H.nodes[h_node][aam_key] = i + 1 + This method resets the specified atom-map attribute for all + nodes in both graphs to 0, then assigns a new atom-map value + (i+1) for each mapped pair: G.nodes[g_node][aam_key] = i + 1 + H.nodes[h_node][aam_key] = i + 1 :param G: First SynGraph to update (reactant). :type G: SynGraph :param H: Second SynGraph to update (product). :type H: SynGraph - :param mapping: Iterable of (g_node, h_node) tuples defining node correspondence. + :param mapping: Iterable of (g_node, h_node) tuples defining + node correspondence. :type mapping: Iterable[Tuple[Any, Any]] :param aam_key: Name of the atom-map attribute on each node. :type aam_key: str @@ -138,8 +133,8 @@ def _update_mapping( return G_copy, H_copy def fit(self, rsmi: str) -> str: - """ - Generate an extended reaction SMILES by aligning atom maps of reactant and product. + """Generate an extended reaction SMILES by aligning atom maps of + reactant and product. :param rsmi: Reaction SMILES string formatted as 'reactant>>product'. 
:type rsmi: str diff --git a/synkit/Graph/ITS/normalize_aam.py b/synkit/Graph/ITS/normalize_aam.py index 4c5fe8f..238fa28 100644 --- a/synkit/Graph/ITS/normalize_aam.py +++ b/synkit/Graph/ITS/normalize_aam.py @@ -11,22 +11,17 @@ class NormalizeAAM: - """ - Provides functionalities to normalize atom mappings in SMILES representations, - extract and process reaction centers from ITS graphs, and convert between - graph representations and molecular models. - """ + """Provides functionalities to normalize atom mappings in SMILES + representations, extract and process reaction centers from ITS graphs, and + convert between graph representations and molecular models.""" def __init__(self) -> None: - """ - Initializes the NormalizeAAM class. - """ + """Initializes the NormalizeAAM class.""" pass @staticmethod def fix_rsmi_kekulize(rsmi: str) -> str: - """ - Filters the reactants and products of a reaction SMILES string. + """Filters the reactants and products of a reaction SMILES string. Parameters: - rsmi (str): A string representing the reaction SMILES in the form of "reactants >> products". @@ -46,8 +41,8 @@ def fix_rsmi_kekulize(rsmi: str) -> str: @staticmethod def fix_kekulize(smiles: str) -> str: - """ - Filters and returns valid SMILES strings from a string of SMILES, joined by '.'. + """Filters and returns valid SMILES strings from a string of SMILES, + joined by '.'. This function processes a string of SMILES separated by periods (e.g., "CCO.CC=O"), filters out invalid SMILES, and returns a string of valid SMILES joined by periods. @@ -73,8 +68,8 @@ def fix_kekulize(smiles: str) -> str: @staticmethod def extract_subgraph(graph: nx.Graph, indices: List[int]) -> nx.Graph: - """ - Extracts a subgraph from a given graph based on a list of node indices. + """Extracts a subgraph from a given graph based on a list of node + indices. Parameters: graph (nx.Graph): The original graph from which to extract the subgraph. 
@@ -88,8 +83,8 @@ def extract_subgraph(graph: nx.Graph, indices: List[int]) -> nx.Graph: def reset_indices_and_atom_map( self, subgraph: nx.Graph, aam_key: str = "atom_map" ) -> nx.Graph: - """ - Resets the node indices and the atom_map of the subgraph to be continuous from 1 onwards. + """Resets the node indices and the atom_map of the subgraph to be + continuous from 1 onwards. Parameters: subgraph (nx.Graph): The subgraph with possibly non-continuous indices. @@ -111,9 +106,9 @@ def reset_indices_and_atom_map( return new_graph def fit(self, rsmi: str, fix_aam_indice: bool = True) -> str: - """ - Processes a reaction SMILES (RSMI) to adjust atom mappings, extract reaction centers, - decompose into separate reactant and product graphs, and generate the corresponding SMILES. + """Processes a reaction SMILES (RSMI) to adjust atom mappings, extract + reaction centers, decompose into separate reactant and product graphs, + and generate the corresponding SMILES. Parameters: - rsmi (str): The reaction SMILES string to be processed. diff --git a/synkit/Graph/MTG/group_comp.py b/synkit/Graph/MTG/group_comp.py index 3fb23f6..dc83c53 100644 --- a/synkit/Graph/MTG/group_comp.py +++ b/synkit/Graph/MTG/group_comp.py @@ -58,8 +58,7 @@ def get_mapping_from_nodes( edges2: Iterable[Edge], ) -> MappingList: """Return *single‑node* mappings ``[{v₁: v₂}, …]`` that obey the - groupoid order rule w.r.t **all** incident edges on each side. 
- """ + groupoid order rule w.r.t **all** incident edges on each side.""" # Index incident edges once – O(|E|) inc1: Dict[NodeId, List[Edge]] = defaultdict(list) inc2: Dict[NodeId, List[Edge]] = defaultdict(list) diff --git a/synkit/Graph/MTG/groupoid.py b/synkit/Graph/MTG/groupoid.py index ce21740..b5d473b 100644 --- a/synkit/Graph/MTG/groupoid.py +++ b/synkit/Graph/MTG/groupoid.py @@ -50,7 +50,8 @@ def node_constraint( nodes1: Iterable[Node], nodes2: Iterable[Node], ) -> Dict[NodeId, List[NodeId]]: - """Compute candidate node mappings based on element and groupoid charge rule. + """Compute candidate node mappings based on element and groupoid charge + rule. For each node v1 in nodes1 and v2 in nodes2, v2 is a candidate if: 1. v1.attrs['element'] == v2.attrs['element'], and @@ -169,9 +170,9 @@ def _edge_constraint_vf2( edges2: Iterable[Edge], node_mapping: Optional[Mapping[NodeId, List[NodeId]]] = None, ) -> MappingList: - """ - VF2‐style routine, fully in Python (no NetworkX), seeded like VF3 but + """VF2‐style routine, fully in Python (no NetworkX), seeded like VF3 but relaxed so it returns the same maximal‐common‐subgraph mappings. + The returned dicts will always have their keys sorted ascending. 
""" # --- build adjacency lists with valid 'order' tuples --- diff --git a/synkit/Graph/MTG/mcs_matcher.py b/synkit/Graph/MTG/mcs_matcher.py index fe5cfdd..e8a6228 100644 --- a/synkit/Graph/MTG/mcs_matcher.py +++ b/synkit/Graph/MTG/mcs_matcher.py @@ -183,7 +183,8 @@ def find_rc_mapping(self, rc1, rc2, *, mcs: bool = False) -> None: # type: igno # Properties and dunders # ------------------------------------------------------------------ def get_mappings(self) -> List[Dict[int, int]]: - """Return the cached mapping list (empty if `find_*` not yet called).""" + """Return the cached mapping list (empty if `find_*` not yet + called).""" return self._mappings.copy() @property diff --git a/synkit/Graph/MTG/mtg.py b/synkit/Graph/MTG/mtg.py index 31b9fab..613ce68 100644 --- a/synkit/Graph/MTG/mtg.py +++ b/synkit/Graph/MTG/mtg.py @@ -117,7 +117,8 @@ def _fuse_nodes(self): def _insert_edges_from( self, edge_iter, node_map: Dict[NodeID, NodeID], existing: List[Edge] = None ) -> List[Edge]: - """Insert edges into *existing* applying the groupoid rule when possible.""" + """Insert edges into *existing* applying the groupoid rule when + possible.""" existing = [] if existing is None else existing.copy() # Remap and append new edges diff --git a/synkit/Graph/Matcher/batch_cluster.py b/synkit/Graph/Matcher/batch_cluster.py index 3657458..db58290 100644 --- a/synkit/Graph/Matcher/batch_cluster.py +++ b/synkit/Graph/Matcher/batch_cluster.py @@ -19,9 +19,8 @@ def __init__( edge_attribute: str = "order", backend: str = "nx", ): - """ - Initializes an AutoCat instance which uses isomorphism checks for categorizing - new graphs or rules. + """Initializes an AutoCat instance which uses isomorphism checks for + categorizing new graphs or rules. 
Parameters: - node_label_names (List[str]): Names of the node attributes to use in @@ -75,8 +74,8 @@ def lib_check( nodeMatch: Optional[Callable] = None, edgeMatch: Optional[Callable] = None, ) -> Dict: - """ - Checks and classifies a graph or rule based on existing templates using either graph or rule isomorphism. + """Checks and classifies a graph or rule based on existing templates + using either graph or rule isomorphism. Parameters: - data (Dict): A dictionary representing a graph or rule with its attributes and @@ -138,8 +137,7 @@ def lib_check( @staticmethod def batch_dicts(input_list, batch_size): - """ - Splits a list of dictionaries into batches of a specified size. + """Splits a list of dictionaries into batches of a specified size. Args: input_list (list of dict): The list of dictionaries to be batched. @@ -175,8 +173,8 @@ def cluster( rule_key: str = "gml", attribute_key: str = "WLHash", ) -> Tuple[List[Dict], List[Dict]]: - """ - Processes a list of graph data entries, classifying each based on existing templates. + """Processes a list of graph data entries, classifying each based on + existing templates. Parameters: - data (List[Dict]): A list of dictionaries, each representing a graph or rule @@ -199,10 +197,10 @@ def fit( attribute_key: str = "WLHash", batch_size: Optional[int] = None, ) -> Tuple[List[Dict], List[Dict]]: - """ - Processes and classifies data in batches. Uses GraphCluster for initial processing - and a stratified sampling technique to update templates if there is only one batch - and no initial templates are provided. + """Processes and classifies data in batches. Uses GraphCluster for + initial processing and a stratified sampling technique to update + templates if there is only one batch and no initial templates are + provided. Parameters: - data (List[Dict]): Data to process. 
diff --git a/synkit/Graph/Matcher/graph_cluster.py b/synkit/Graph/Matcher/graph_cluster.py index aa5c936..7777371 100644 --- a/synkit/Graph/Matcher/graph_cluster.py +++ b/synkit/Graph/Matcher/graph_cluster.py @@ -21,10 +21,10 @@ def __init__( edge_attribute: str = "order", backend: str = "nx", ): - """ - Initializes the GraphCluster with customization options for node and edge - matching functions. This class is designed to facilitate clustering of graph nodes - and edges based on specified attributes and their matching criteria. + """Initializes the GraphCluster with customization options for node and + edge matching functions. This class is designed to facilitate + clustering of graph nodes and edges based on specified attributes and + their matching criteria. Parameters: - node_label_names (List[str]): A list of node attribute names to be considered @@ -84,9 +84,9 @@ def iterative_cluster( nodeMatch: Optional[Callable] = None, edgeMatch: Optional[Callable] = None, ) -> Tuple[List[Set[int]], Dict[int, int]]: - """ - Clusters rules based on their similarities, which could include structural or - attribute-based similarities depending on the given attributes. + """Clusters rules based on their similarities, which could include + structural or attribute-based similarities depending on the given + attributes. Parameters: - rules (List[str]): List of rules, potentially serialized strings of rule @@ -159,10 +159,9 @@ def fit( attribute_key: str = "WLHash", strip: bool = False, ) -> List[Dict]: - """ - Automatically clusters the rules and assigns them cluster indices based on the - similarity, potentially using provided templates for clustering, or generating - new templates. + """Automatically clusters the rules and assigns them cluster indices + based on the similarity, potentially using provided templates for + clustering, or generating new templates. 
Parameters: - data (List[Dict]): A list containing dictionaries, each representing a diff --git a/synkit/Graph/Matcher/graph_morphism.py b/synkit/Graph/Matcher/graph_morphism.py index 137a295..c608acb 100644 --- a/synkit/Graph/Matcher/graph_morphism.py +++ b/synkit/Graph/Matcher/graph_morphism.py @@ -1,380 +1,516 @@ -import logging -import itertools -from operator import eq -from typing import Callable, Optional, Union, List, Any, Dict +import re import networkx as nx -from networkx.algorithms import isomorphism -from networkx.algorithms.isomorphism import GraphMatcher -from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match +from typing import Optional, List +from rdkit import Chem +from rdkit.Chem.MolStandardize import rdMolStandardize -# Alias for any NetworkX graph type -graph_types = Union[nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] +__all__ = ["get_rc", "its_decompose"] -def find_graph_isomorphism( - G1: graph_types, - G2: graph_types, - node_match: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]] = None, - edge_match: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]] = None, - use_defaults: bool = True, - fast_invariant_check: bool = True, - logger: Optional[logging.Logger] = None, -) -> Optional[Dict[Any, Any]]: - """ - Check whether two graphs are isomorphic and return the node-mapping. - - :param G1: The first NetworkX graph to compare. - :type G1: nx.Graph | nx.DiGraph | nx.MultiGraph | nx.MultiDiGraph - :param G2: The second NetworkX graph to compare. - :type G2: nx.Graph | nx.DiGraph | nx.MultiGraph | nx.MultiDiGraph - :param node_match: Optional function taking two node attribute dicts and - returning True if they match. - :type node_match: callable or None - :param edge_match: Optional function taking two edge attribute dicts and - returning True if they match. - :type edge_match: callable or None - :param use_defaults: Whether to use default matchers when None. 
- :type use_defaults: bool - :param fast_invariant_check: Perform quick node/edge count and degree - sequence checks prior to matcher. - :type fast_invariant_check: bool - :param logger: Logger for debug messages. Defaults to root logger. - :type logger: logging.Logger or None - - :returns: A dict mapping nodes in G1 to nodes in G2 if isomorphic; otherwise None. - :rtype: dict[Any, Any] or None +def get_rc( + ITS: nx.Graph, + element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], + bond_key: str = "order", + standard_key: str = "standard_order", + disconnected: bool = False, +) -> nx.Graph: + """Extract the reaction-center (RC) subgraph from an ITS graph. + + This function identifies: + 1. All bonds whose standard order (difference between ITS orders) is non-zero. + 2. All H–H bonds, ensuring they are included even if no order change is detected. + 3. (Optional) Additional nodes with charge changes and reconnection of edges + if `disconnected=True`. + + :param ITS: The integrated transition-state graph with composite node/edge attributes. + :type ITS: nx.Graph + :param element_key: List of node‐attribute keys to copy into the RC graph. + :type element_key: List[str] + :param bond_key: Edge attribute key representing the tuple of bond orders. + :type bond_key: str + :param standard_key: Edge attribute key for the computed standard_order. + :type standard_key: str + :param disconnected: If True, also include nodes with charge changes and + reconnect any ITS edges between RC nodes. + :type disconnected: bool + :returns: A new graph containing only the reaction-center nodes and edges. + :rtype: nx.Graph + + :example: + >>> ITS = nx.Graph() + >>> # ... populate ITS with 'order', 'standard_order', 'typesGH', etc. ... 
+ >>> RC = get_rc(ITS, disconnected=True) + >>> isinstance(RC, nx.Graph) + True """ - log = logger or logging.getLogger(__name__) - - # 1) Ensure same graph type - if type(G1) is not type(G2): - log.debug("Graph types differ: %r vs %r", type(G1), type(G2)) - return None - - # 2) Quick invariants - if fast_invariant_check: - if G1.number_of_nodes() != G2.number_of_nodes(): - log.debug( - "Node counts differ: %d vs %d", - G1.number_of_nodes(), - G2.number_of_nodes(), - ) - return None - if G1.number_of_edges() != G2.number_of_edges(): - log.debug( - "Edge counts differ: %d vs %d", - G1.number_of_edges(), - G2.number_of_edges(), + rc = nx.Graph() + _add_bond_order_changes(ITS, rc, element_key, bond_key, standard_key) + + # 1.5) H-H bonds (force inclusion, with fallback typesGH) + for u, v, data in ITS.edges(data=True): + elem_u = ITS.nodes[u].get("element") + elem_v = ITS.nodes[v].get("element") + if elem_u == "H" and elem_v == "H": + for n in (u, v): + node_data = dict(ITS.nodes[n]) + if "typesGH" not in node_data: + node_data["typesGH"] = ( + ("H", False, 0, 0, []), + ("*", False, 0, 0, []), + ) + # Ensure typesGH is available even if not in original element_key + final_attrs = {k: node_data[k] for k in element_key if k in node_data} + final_attrs["typesGH"] = node_data["typesGH"] + rc.add_node(n, **final_attrs) + + rc.add_edge( + u, + v, + **{ + bond_key: data.get(bond_key), + standard_key: data.get(standard_key), + }, ) - return None - degs1 = sorted(d for _, d in G1.degree()) - degs2 = sorted(d for _, d in G2.degree()) - if degs1 != degs2: - log.debug("Degree sequences differ") - return None - - # 3) Default matchers - if use_defaults: - if node_match is None: - node_match = isomorphism.categorical_node_match( - ["element", "atom_map", "hcount"], ["*", 0, 0] + if disconnected: + _add_charge_change_nodes(ITS, rc, element_key) + _reconnect_rc_edges(ITS, rc, bond_key, standard_key) + + return rc + + +def _carry_node_attrs(src: nx.Graph, dst: nx.Graph, n: int, keys: 
List[str]) -> None: + """Copy node *n* from *src* to *dst* with only *keys* attributes.""" + if dst.has_node(n): + return + attrs = {k: src.nodes[n][k] for k in keys if k in src.nodes[n]} + dst.add_node(n, **attrs) + + +def _add_charge_change_nodes( + ITS: nx.Graph, + rc: nx.Graph, + keys: List[str], +) -> None: + """Step 3a – add nodes whose *typesGH* shows a charge change.""" + for n, data in ITS.nodes(data=True): + gh = data.get("typesGH") + if ( + isinstance(gh, (list, tuple)) + and len(gh) >= 2 + and gh[0][3] != gh[1][3] + and not rc.has_node(n) + ): + _carry_node_attrs(ITS, rc, n, keys) + + +def _reconnect_rc_edges( + ITS: nx.Graph, + rc: nx.Graph, + bond_key: str, + standard_key: str, +) -> None: + """Step 3b – re-add any original ITS edge between nodes already in RC.""" + for u, v, data in ITS.edges(data=True): + if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): + rc.add_edge( + u, + v, + **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, ) - if edge_match is None: - edge_match = isomorphism.categorical_edge_match("order", 1) - # 4) Select the correct matcher - if isinstance(G1, (nx.MultiGraph, nx.MultiDiGraph)): - if isinstance(G1, nx.MultiGraph): - Matcher = nx.algorithms.isomorphism.MultiGraphMatcher - else: - Matcher = nx.algorithms.isomorphism.MultiDiGraphMatcher - else: - if isinstance(G1, nx.Graph): - Matcher = nx.algorithms.isomorphism.GraphMatcher - else: - Matcher = nx.algorithms.isomorphism.DiGraphMatcher - - matcher = Matcher(G1, G2, node_match=node_match, edge_match=edge_match) - if matcher.is_isomorphic(): - log.debug("Graphs are isomorphic; mapping found") - return matcher.mapping - else: - log.debug("Graphs are not isomorphic") - return None - - -def graph_isomorphism( - graph_1: nx.Graph, - graph_2: nx.Graph, - node_match: Optional[Callable] = None, - edge_match: Optional[Callable] = None, - use_defaults: bool = False, -) -> bool: - """ - Determines if two graphs are isomorphic, considering provided node 
and edge matching - functions. Uses default matching settings if none are provided. - Parameters: - - graph_1 (nx.Graph): The first graph to compare. - - graph_2 (nx.Graph): The second graph to compare. - - node_match (Optional[Callable]): The function used to match nodes. - Uses default if None. - - edge_match (Optional[Callable]): The function used to match edges. - Uses default if None. +def _add_bond_order_changes( + ITS: nx.Graph, + rc: nx.Graph, + keys: List[str], + bond_key: str, + standard_key: str, +) -> None: + """Step 1 – bond-order-change edges and their nodes.""" + for u, v, data in ITS.edges(data=True): + old, new = data.get(bond_key, (None, None)) + if old == new: + continue + for n in (u, v): + _carry_node_attrs(ITS, rc, n, keys) + rc.add_edge( + u, v, **{bond_key: data[bond_key], standard_key: data.get(standard_key)} + ) - Returns: - - bool: True if the graphs are isomorphic, False otherwise. + +# def get_rc( +# ITS: nx.Graph, +# element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], +# bond_key: str = "order", +# standard_key: str = "standard_order", +# disconnected: bool = False, +# ) -> nx.Graph: +# """ +# Extract the reaction center (RC) from ITS graph. + +# Enhancements: +# - Adds nodes and edges where bond order changes (core logic). +# - If disconnected=True: +# - Adds nodes with charge change based on typesGH. +# - Reconnects any ITS edge between two RC nodes. +# - NEW: Always includes H-H bonds in RC. Adds default typesGH if missing. 
+# """ +# rc = nx.Graph() + +# # 1) edges with bond-order change +# for u, v, data in ITS.edges(data=True): +# old, new = data.get(bond_key, [None, None]) +# if old != new: +# for n in (u, v): +# if not rc.has_node(n): +# rc.add_node( +# n, +# **{ +# k: ITS.nodes[n][k] for k in element_key if k in ITS.nodes[n] +# }, +# ) +# rc.add_edge( +# u, +# v, +# **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, +# ) + +# # 1.5) H-H bonds (force inclusion, with fallback typesGH) +# for u, v, data in ITS.edges(data=True): +# elem_u = ITS.nodes[u].get("element") +# elem_v = ITS.nodes[v].get("element") +# if elem_u == "H" and elem_v == "H": +# for n in (u, v): +# node_data = dict(ITS.nodes[n]) +# if "typesGH" not in node_data: +# node_data["typesGH"] = ( +# ("H", False, 0, 0, []), +# ("*", False, 0, 0, []), +# ) +# # Ensure typesGH is available even if not in original element_key +# final_attrs = {k: node_data[k] for k in element_key if k in node_data} +# final_attrs["typesGH"] = node_data["typesGH"] +# rc.add_node(n, **final_attrs) + +# rc.add_edge( +# u, +# v, +# **{ +# bond_key: data.get(bond_key), +# standard_key: data.get(standard_key), +# }, +# ) + +# if disconnected: +# # 2) nodes with typesGH-based charge change +# for n, data in ITS.nodes(data=True): +# gh = data.get("typesGH") +# if ( +# isinstance(gh, (list, tuple)) +# and len(gh) >= 2 +# and len(gh[0]) > 3 +# and len(gh[1]) > 3 +# and gh[0][3] != gh[1][3] +# ): +# if not rc.has_node(n): +# rc.add_node(n, **{k: data[k] for k in element_key if k in data}) + +# # 3) reconnect RC nodes +# for u, v, data in ITS.edges(data=True): +# if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): +# rc.add_edge( +# u, +# v, +# **{ +# bond_key: data.get(bond_key), +# standard_key: data.get(standard_key), +# }, +# ) + +# return rc + + +# def get_rc( +# ITS: nx.Graph, +# element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], +# bond_key: str = "order", +# standard_key: str = 
"standard_order", +# disconnected: bool = False, +# ) -> nx.Graph: +# """ +# Extract the reaction center (RC) from ITS by: + +# 1. Always adding any edge whose bond order changes +# (bond_key[0] != bond_key[1]), plus its two end-nodes. +# 2. [if disconnected=True] Adding any node whose 'typesGH' record shows a charge change +# (typesGH[0][3] != typesGH[1][3]), even if isolated. +# 3. [if disconnected=True] Re-adding any ITS edge between two nodes already in RC +# (to preserve connectivity), carrying over bond_key & standard_key. + +# Parameters: +# - ITS (nx.Graph): input ITS graph. +# - element_key (List[str]): node attrs to carry over. +# - bond_key (str): edge attr key for bond order. +# - standard_key (str): edge attr key for standard order. +# - disconnected (bool): if True, include “charge-change” nodes (step 2) and +# reconnect any edges among RC nodes (step 3). If False, only performs step 1. +# """ +# rc = nx.Graph() + +# # 1) edges with bond-order change +# for u, v, data in ITS.edges(data=True): +# old, new = data.get(bond_key, [None, None]) +# if old != new: +# for n in (u, v): +# if not rc.has_node(n): +# rc.add_node( +# n, +# **{ +# k: ITS.nodes[n][k] for k in element_key if k in ITS.nodes[n] +# }, +# ) +# rc.add_edge( +# u, +# v, +# **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, +# ) + +# if disconnected: +# # 2) nodes with a typesGH-based charge change +# for n, data in ITS.nodes(data=True): +# gh = data.get("typesGH") +# if ( +# isinstance(gh, (list, tuple)) +# and len(gh) >= 2 +# and len(gh[0]) > 3 +# and len(gh[1]) > 3 +# and gh[0][3] != gh[1][3] +# ): +# if not rc.has_node(n): +# rc.add_node(n, **{k: data[k] for k in element_key if k in data}) + +# # 3) re-add any ITS edge between RC nodes to preserve connectivity +# for u, v, data in ITS.edges(data=True): +# if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): +# rc.add_edge( +# u, +# v, +# **{ +# bond_key: data.get(bond_key), +# standard_key: 
data.get(standard_key), +# }, +# ) + +# return rc + + +def its_decompose(its_graph: nx.Graph, nodes_share="typesGH", edges_share="order"): + """Decompose an ITS graph into two separate reactant (G) and product (H) + graphs. + + Nodes and edges in `its_graph` carry composite attributes: + - Each node has `its_graph.nodes[nodes_share] = (node_attrs_G, node_attrs_H)`. + - Each edge has `its_graph.edges[edges_share] = (order_G, order_H)`. + + This function splits those tuples to reconstruct the original G and H graphs. + + :param its_graph: The ITS graph with composite node/edge attributes. + :type its_graph: nx.Graph + :param nodes_share: Node attribute key storing (G_attrs, H_attrs) tuples. + :type nodes_share: str + :param edges_share: Edge attribute key storing (order_G, order_H) tuples. + :type edges_share: str + :returns: A tuple of two graphs (G, H) reconstructed from the ITS. + :rtype: Tuple[nx.Graph, nx.Graph] + + :example: + >>> its = nx.Graph() + >>> # ... set its.nodes[n]['typesGH'] and its.edges[e]['order'] ... 
+ >>> G, H = its_decompose(its) + >>> isinstance(G, nx.Graph) and isinstance(H, nx.Graph) + True """ - # Define default node and edge attributes and match settings - if use_defaults: - node_label_names = ["element", "charge"] - node_label_default = ["*", 0] - edge_attribute = "order" - - # Default node and edge match functions if not provided - if node_match is None: - node_match = generic_node_match( - node_label_names, node_label_default, [eq] * len(node_label_names) + G = nx.Graph() + H = nx.Graph() + + # Decompose nodes + for node, data in its_graph.nodes(data=True): + if nodes_share in data: + node_attr_g, node_attr_h = data[nodes_share] + # Unpack node attributes for G + G.add_node( + node, + element=node_attr_g[0], + aromatic=node_attr_g[1], + hcount=node_attr_g[2], + charge=node_attr_g[3], + neighbors=node_attr_g[4], + atom_map=node, ) - if edge_match is None: - edge_match = generic_edge_match(edge_attribute, 1, eq) - - # Perform the isomorphism check using NetworkX - return nx.is_isomorphic( - graph_1, graph_2, node_match=node_match, edge_match=edge_match - ) - - -def subgraph_isomorphism( - child_graph: nx.Graph, - parent_graph: nx.Graph, - node_label_names: List[str] = ["element", "charge"], - node_label_default: List[Any] = ["*", 0], - edge_attribute: str = "order", - use_filter: bool = False, - check_type: str = "induced", # "induced" or "monomorphism" - node_comparator: Optional[Callable[[Any, Any], bool]] = None, - edge_comparator: Optional[Callable[[Any, Any], bool]] = None, + if len(node_attr_h) > 0: + # Unpack node attributes for H + H.add_node( + node, + element=node_attr_h[0], + aromatic=node_attr_h[1], + hcount=node_attr_h[2], + charge=node_attr_h[3], + neighbors=node_attr_h[4], + atom_map=node, + ) + + # Decompose edges + for u, v, data in its_graph.edges(data=True): + if edges_share in data: + order_g, order_h = data[edges_share] + if order_g > 0: # Assuming 0 means no edge in G + G.add_edge(u, v, order=order_g) + if order_h > 0: # Assuming 0 
means no edge in H + H.add_edge(u, v, order=order_h) + + return G, H + + +def compare_graphs( + graph1: nx.Graph, + graph2: nx.Graph, + node_attrs: list = ["element", "aromatic", "hcount", "charge", "neighbors"], + edge_attrs: list = ["order"], ) -> bool: - """ - Enhanced checks if the child graph is a subgraph isomorphic to the parent graph based on - customizable node and edge attributes. + """Compare two graphs based on specified node and edge attributes. Parameters: - - child_graph (nx.Graph): The child graph. - - parent_graph (nx.Graph): The parent graph. - - node_label_names (List[str]): Labels to compare. - - node_label_default (List[Any]): Defaults for missing node labels. - - edge_attribute (str): The edge attribute to compare. - - use_filter (bool): Whether to use pre-filters based on node and edge count. - - check_type (str): "induced" (default) or "monomorphism" for the type of subgraph matching. - - node_comparator (Callable[[Any, Any], bool]): Custom comparator for node attributes. - - edge_comparator (Callable[[Any, Any], bool]): Custom comparator for edge attributes. + - graph1 (nx.Graph): The first graph to compare. + - graph2 (nx.Graph): The second graph to compare. + - node_attrs (list): A list of node attribute names to include in the comparison. + - edge_attrs (list): A list of edge attribute names to include in the comparison. Returns: - - bool: True if subgraph isomorphism is found, False otherwise. + - bool: True if both graphs are identical with respect to the specified attributes, + otherwise False. 
""" - if use_filter: - # Initial quick filters based on node and edge counts - if len(child_graph) > len(parent_graph) or len(child_graph.edges) > len( - parent_graph.edges - ): - return False + # Compare node sets + if set(graph1.nodes()) != set(graph2.nodes()): + return False - # Step 2: Node label filter - Only consider 'element' and 'charge' attributes - for _, child_data in child_graph.nodes(data=True): - found_match = False - for _, parent_data in parent_graph.nodes(data=True): - match = True - # Compare only the 'element' and 'charge' attributes - for label, default in zip(node_label_names, node_label_default): - child_value = child_data.get(label, default) - parent_value = parent_data.get(label, default) - if child_value != parent_value: - match = False - break - if match: - found_match = True - break - if not found_match: - return False - - # Step 3: Edge label filter - Ensure that the edge attribute 'order' matches if provided - if edge_attribute: - for child_edge in child_graph.edges(data=True): - child_node1, child_node2, child_data = child_edge - if child_node1 in parent_graph and child_node2 in parent_graph: - # Ensure the edge exists in the parent graph - if not parent_graph.has_edge(child_node1, child_node2): - return False - # Check if the 'order' attribute matches - parent_edge_data = parent_graph[child_node1][child_node2] - child_order = child_data.get(edge_attribute) - parent_order = parent_edge_data.get(edge_attribute) - - # Handle comparison of tuple values for 'order' attribute - if isinstance(child_order, tuple) and isinstance( - parent_order, tuple - ): - if child_order != parent_order: - return False - elif child_order != parent_order: - return False - else: - return False - - # Setting up attribute comparison functions - node_comparator = node_comparator if node_comparator else eq - edge_comparator = edge_comparator if edge_comparator else eq - - # Creating match conditions for nodes and edges based on custom or default comparators - 
node_match = generic_node_match( - node_label_names, node_label_default, [node_comparator] * len(node_label_names) - ) - edge_match = ( - generic_edge_match(edge_attribute, None, edge_comparator) - if edge_attribute - else None - ) - - # Graph matching setup - matcher = GraphMatcher( - parent_graph, child_graph, node_match=node_match, edge_match=edge_match - ) + # Compare nodes based on attributes + for node in graph1.nodes(): + if node not in graph2: + return False + node_data1 = {attr: graph1.nodes[node].get(attr, None) for attr in node_attrs} + node_data2 = {attr: graph2.nodes[node].get(attr, None) for attr in node_attrs} + if node_data1 != node_data2: + return False - # Executing the matching based on specified type - if check_type == "induced": - return matcher.subgraph_is_isomorphic() - else: - return matcher.subgraph_is_monomorphic() + # Compare edge sets with sorted tuples + if set(tuple(sorted(edge)) for edge in graph1.edges()) != set( + tuple(sorted(edge)) for edge in graph2.edges() + ): + return False + + # Compare edges based on attributes + for edge in graph1.edges(): + # Sort the edge for consistent comparison + sorted_edge = tuple(sorted(edge)) + if sorted_edge not in graph2.edges(): + return False + edge_data1 = {attr: graph1.edges[edge].get(attr, None) for attr in edge_attrs} + edge_data2 = { + attr: graph2.edges[sorted_edge].get(attr, None) for attr in edge_attrs + } + if edge_data1 != edge_data2: + return False + return True -def maximum_connected_common_subgraph( - graph_1: nx.Graph, - graph_2: nx.Graph, - node_label_names: List[str] = ["element", "charge"], - node_label_default: List[Any] = ["*", 0], - edge_attribute: str = "standard_order", -) -> nx.Graph: - """ - Computes the largest connected common subgraph (MCS) between two graphs using - subgraph isomorphism based on customizable node and edge attributes. 
- The function iterates over subsets of nodes from the smaller graph—starting from the largest - possible subgraph size down to 1—and returns the first (largest) candidate that is connected - and is isomorphic to a subgraph of the larger graph. +def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: + """Enumerates possible tautomers for reactants while canonicalizing the + products in a reaction SMILES string. This function first splits the + reaction SMILES string into reactants and products. It then generates all + possible tautomers for the reactants and canonicalizes the product + molecule. The function returns a list of reaction SMILES strings for each + tautomer of the reactants combined with the canonical product. Parameters: - - graph_1 (nx.Graph): The first graph for comparison. - - graph_2 (nx.Graph): The second graph for comparison. - - node_label_names (List[str]): List of node attribute names used for matching. - - node_label_default (List[Any]): Default values for missing node attributes. - - edge_attribute (str): The edge attribute to compare. + - reaction_smiles (str): A SMILES string of the reaction formatted as + 'reactants>>products'. Returns: - - nx.Graph: A graph representing the largest connected common subgraph found; if none exists, - returns an empty graph. + - List[str] | None: A list of SMILES strings for the reaction, with each string + representing a different + - tautomer of the reactants combined with the canonicalized products. Returns None if + an error occurs or if invalid SMILES strings are provided. + + Raises: + - ValueError: If the provided SMILES strings cannot be converted to molecule objects, + indicating invalid input. """ - node_match = generic_node_match( - node_label_names, node_label_default, [eq] * len(node_label_names) - ) - edge_match = generic_edge_match(edge_attribute, 1, eq) - - # Determine which graph is smaller for efficiency. 
- if graph_1.number_of_nodes() <= graph_2.number_of_nodes(): - smaller_graph, larger_graph = graph_1, graph_2 - else: - smaller_graph, larger_graph = graph_2, graph_1 - - num_nodes_smaller = smaller_graph.number_of_nodes() - # Iterate over possible subgraph sizes from the largest to 1. - for subgraph_size in range(num_nodes_smaller, 0, -1): - for nodes_subset in itertools.combinations( - smaller_graph.nodes(), subgraph_size - ): - candidate_subgraph = smaller_graph.subgraph(nodes_subset) - # If the subgraph has more than one node, check it is connected. - if candidate_subgraph.number_of_nodes() > 1 and not nx.is_connected( - candidate_subgraph - ): - continue - - # Check for subgraph isomorphism in the larger graph. - matcher = GraphMatcher( - larger_graph, - candidate_subgraph, - node_match=node_match, - edge_match=edge_match, + try: + # Split the input reaction SMILES string into reactants and products + reactants_smiles, products_smiles = reaction_smiles.split(">>") + + # Convert SMILES strings to molecule objects + reactants_mol = Chem.MolFromSmiles(reactants_smiles) + products_mol = Chem.MolFromSmiles(products_smiles) + + if reactants_mol is None or products_mol is None: + raise ValueError( + "Invalid SMILES string provided for reactants or products." ) - if matcher.subgraph_is_isomorphic(): - return candidate_subgraph.copy() - return nx.Graph() + # Initialize tautomer enumerator + enumerator = rdMolStandardize.TautomerEnumerator() -def heuristics_MCCS( - graphs: List[nx.Graph], - node_label_names: List[str] = ["element", "charge"], - node_label_default: List[Any] = ["*", 0], - edge_attribute: str = "standard_order", -) -> nx.Graph: - """ - Computes the Maximum Connected Common Subgraph (MCCS) over a list of graphs using a heuristic approach. 
+ # Enumerate tautomers for the reactants and canonicalize the products + try: + reactants_can = enumerator.Enumerate(reactants_mol) + except Exception as e: + print(f"An error occurred: {e}") + reactants_can = [reactants_mol] + products_can = products_mol + + # Convert molecule objects back to SMILES strings + reactants_can_smiles = [Chem.MolToSmiles(i) for i in reactants_can] + products_can_smiles = Chem.MolToSmiles(products_can) + + # Combine each reactant tautomer with the canonical product in SMILES format + rsmi_list = [i + ">>" + products_can_smiles for i in reactants_can_smiles] + if len(rsmi_list) == 0: + return [reaction_smiles] + else: + # rsmi_list.remove(reaction_smiles) + rsmi_list.insert(0, reaction_smiles) + return rsmi_list - This function computes the MCCS between the first two graphs using the - `maximum_connected_common_subgraph` function based on customizable node and edge attributes. - For more than two graphs, it iteratively updates the common subgraph by calculating the MCCS - between the current common subgraph and each subsequent graph. An early exit occurs if the - intermediate common subgraph becomes empty. + except Exception as e: + print(f"An error occurred: {e}") + return [reaction_smiles] + + +def mapping_success_rate(list_mapping_data): + """Calculate the success rate of entries containing atom mappings in a list + of data strings. Parameters: - - graphs (List[nx.Graph]): A list of networkx graphs for which the common subgraph is to be computed. - - node_label_names (List[str]): List of node attribute names used for matching. - - node_label_default (List[Any]): Default values for missing node attributes. - - edge_attribute (str): The edge attribute to compare. + - list_mapping_in_data (list of str): List containing strings to be searched for atom + mappings. Returns: - - nx.Graph: The maximum connected common subgraph common to all provided graphs. If no common - subgraph exists, an empty graph is returned. 
+ - float: The success rate of finding atom mappings in the list as a percentage. Raises: - - ValueError: If the input list of graphs is empty. + - ValueError: If the input list is empty. """ - if not graphs: - raise ValueError("Input list of graphs is empty.") - - if len(graphs) == 1: - return graphs[0].copy() - - # Handle the two-graph case explicitly. - if len(graphs) == 2: - return maximum_connected_common_subgraph( - graphs[0], - graphs[1], - node_label_names=node_label_names, - node_label_default=node_label_default, - edge_attribute=edge_attribute, - ) + atom_map_pattern = re.compile(r":\d+") + if not list_mapping_data: + raise ValueError("The input list is empty, cannot calculate success rate.") - # Iteratively compute the MCCS for more than two graphs. - current_mcs = maximum_connected_common_subgraph( - graphs[0], - graphs[1], - node_label_names=node_label_names, - node_label_default=node_label_default, - edge_attribute=edge_attribute, + success = sum( + 1 for entry in list_mapping_data if re.search(atom_map_pattern, entry) ) + rate = 100 * (success / len(list_mapping_data)) - for graph in graphs[2:]: - if current_mcs.number_of_nodes() == 0: - break # Early exit if no common subgraph remains. - current_mcs = maximum_connected_common_subgraph( - current_mcs, - graph, - node_label_names=node_label_names, - node_label_default=node_label_default, - edge_attribute=edge_attribute, - ) - - return current_mcs + return round(rate, 2) \ No newline at end of file diff --git a/synkit/Graph/Matcher/multi_turbo_iso.py b/synkit/Graph/Matcher/multi_turbo_iso.py index bcc2a85..db1e37e 100644 --- a/synkit/Graph/Matcher/multi_turbo_iso.py +++ b/synkit/Graph/Matcher/multi_turbo_iso.py @@ -9,21 +9,22 @@ class MultiTurboISO: - """ - Accelerated sub-graph search across a batch of host graphs. + """Accelerated sub-graph search across a batch of host graphs. Builds a single global signature bucket over all hosts and reuses a - lightweight TurboISO matcher per host. 
For each query graph, hosts are - first pruned by a signature + degree filter, and then TurboISO’s + lightweight TurboISO matcher per host. For each query graph, hosts + are first pruned by a signature + degree filter, and then TurboISO’s backtracking is run only on the surviving hosts. :param hosts: List of host graphs to index. :type hosts: List[nx.Graph] :param node_label: Node attribute(s) used for signature matching. :type node_label: str or list[str] - :param edge_label: Edge attribute(s) to match; pass None to ignore edges. + :param edge_label: Edge attribute(s) to match; pass None to ignore + edges. :type edge_label: str or list[str] or None - :param distance_threshold: Skip distance filtering if candidate pool is smaller. + :param distance_threshold: Skip distance filtering if candidate pool + is smaller. :type distance_threshold: int :returns: An instance of MultiTurboISO with global index built. :rtype: MultiTurboISO @@ -99,7 +100,10 @@ def node_label(self) -> List[str]: @property def edge_label(self) -> List[str]: - """Edge‑attribute selector(s). Empty list means ‘ignore’.""" + """Edge‑attribute selector(s). + + Empty list means ‘ignore’. + """ return list(self._edge_attr) # -------------------------------------------------------------- helpers @@ -172,7 +176,7 @@ def search_many( ) -> List[Dict[int, Union[bool, List[Dict[Any, Any]]]]]: """Match a list of pattern graphs. - Returns a list of per‑pattern dictionaries in the same order as the - input list. + Returns a list of per‑pattern dictionaries in the same order as + the input list. """ return [self.search_one(p, prune=prune) for p in patterns] diff --git a/synkit/Graph/Matcher/sing.py b/synkit/Graph/Matcher/sing.py index 3e4f1b9..39c9a8b 100644 --- a/synkit/Graph/Matcher/sing.py +++ b/synkit/Graph/Matcher/sing.py @@ -5,9 +5,10 @@ class SING: """Subgraph search In Non-homogeneous Graphs (SING) - A lightweight Python implementation adopting a *filter-and-refine* strategy - with path-based features. 
This version supports **heterogeneous graphs** - through flexible **node and edge attribute selections**. + A lightweight Python implementation adopting a *filter-and-refine* + strategy with path-based features. This version supports + **heterogeneous graphs** through flexible **node and edge attribute + selections**. """ # --------------------------------------------------------------------- @@ -62,14 +63,16 @@ def __init__( # ------------------------------------------------------------------ def _node_signature(self, v: Any, G: nx.Graph) -> str: - """Return a string signature for *v* in *G* based on ``self.node_att``.""" + """Return a string signature for *v* in *G* based on + ``self.node_att``.""" vals = [str(G.nodes[v].get(a, "#")) for a in self.node_att] return "|".join(vals) def _edge_signature(self, u: Any, v: Any, G: nx.Graph) -> str: """Return a string signature for edge *(u,v)* in *G* based on - ``self.edge_att``. If no edge attributes were requested, returns an - empty string. + ``self.edge_att``. + + If no edge attributes were requested, returns an empty string. """ if not self.edge_att: return "" @@ -85,7 +88,9 @@ def _extract_path_features( ) -> Set[str]: """Enumerate *all* simple paths starting at *node* up to ``self.max_path_length`` edges (inclusive), represented as label - sequences. Works for both data and query graphs. + sequences. + + Works for both data and query graphs. """ features: Set[str] = set() max_len = self.max_path_length @@ -126,8 +131,7 @@ def _build_index(self) -> None: def _candidate_vertices(self, query_graph: nx.Graph) -> Dict[Any, Set[Any]]: """Return *per-query-vertex* candidate sets using posting-list - intersections. 
- """ + intersections.""" cand: Dict[Any, Set[Any]] = {} for qv in query_graph.nodes: q_feats = self._extract_path_features(qv, query_graph, is_query=True) diff --git a/synkit/Graph/Matcher/subgraph_matcher.py b/synkit/Graph/Matcher/subgraph_matcher.py index aa70236..1eb5678 100644 --- a/synkit/Graph/Matcher/subgraph_matcher.py +++ b/synkit/Graph/Matcher/subgraph_matcher.py @@ -109,16 +109,16 @@ # Core engine class # --------------------------------------------------------------------------- class SubgraphMatch: - """ - Boolean-only checks for graph isomorphism and subgraph (induced or monomorphic) matching. + """Boolean-only checks for graph isomorphism and subgraph (induced or + monomorphic) matching. - Provides static methods for NetworkX-based checks and optional GML "rule" backend. + Provides static methods for NetworkX-based checks and optional GML + "rule" backend. """ @staticmethod def _get_edge_labels(graph: Any) -> list: - """ - Extracts the bond types (edge labels) from a given graph. + """Extracts the bond types (edge labels) from a given graph. Parameters: - graph: The graph object containing the edges. @@ -130,8 +130,7 @@ def _get_edge_labels(graph: Any) -> list: @staticmethod def _get_node_labels(graph: Any) -> list: - """ - Extracts the atom IDs (node labels) from a given graph. + """Extracts the atom IDs (node labels) from a given graph. Parameters: - graph: The graph object containing the vertices. @@ -145,9 +144,8 @@ def _get_node_labels(graph: Any) -> list: def rule_subgraph_morphism( rule_1: str, rule_2: str, use_filter: bool = False ) -> bool: - """ - Evaluates if two GML-formatted rule representations are isomorphic or one is a - subgraph of the other (monomorphic). + """Evaluates if two GML-formatted rule representations are isomorphic + or one is a subgraph of the other (monomorphic). Parameters: - rule_1 (str): GML string of the first rule. 
@@ -191,10 +189,8 @@ def subgraph_isomorphism( node_comparator: Optional[Callable[[Any, Any], bool]] = None, edge_comparator: Optional[Callable[[Any, Any], bool]] = None, ) -> bool: - """ - Enhanced checks if the child graph is a subgraph isomorphic to the parent graph based on - customizable node and edge attributes. - """ + """Enhanced checks if the child graph is a subgraph isomorphic to the + parent graph based on customizable node and edge attributes.""" if use_filter: if ( child_graph.number_of_nodes() > parent_graph.number_of_nodes() @@ -263,9 +259,8 @@ def is_subgraph( check_type: str = "induced", backend: str = "nx", ) -> bool: - """ - Unified API for subgraph/isomorphism either via NX or GML backend. - """ + """Unified API for subgraph/isomorphism either via NX or GML + backend.""" if backend == "nx": return SubgraphMatch.subgraph_isomorphism( pattern, @@ -303,8 +298,8 @@ def find_subgraph_mappings( max_results: Optional[int] = None, strict_cc_count: bool = True, ) -> List[MappingDict]: - """ - Dispatch to a subgraph-matching strategy and return all pattern→host mappings. + """Dispatch to a subgraph-matching strategy and return all pattern→host + mappings. Depending on `strategy`, this will call: - ALL → `_all_monomorphisms` @@ -406,8 +401,7 @@ def _find_component_aware_subgraph_mappings( max_results: Optional[int] = None, strict_cc_count: bool = False, ) -> List[MappingDict]: - """ - Component‑aware VF2 without any attribute/degree/WL‑1 pre‑filters. + """Component‑aware VF2 without any attribute/degree/WL‑1 pre‑filters. 
The only constraints are: • each pattern‑CC must fit in a *distinct* host‑CC diff --git a/synkit/Graph/Matcher/turbo_iso.py b/synkit/Graph/Matcher/turbo_iso.py index 22dfa77..4a01601 100644 --- a/synkit/Graph/Matcher/turbo_iso.py +++ b/synkit/Graph/Matcher/turbo_iso.py @@ -81,7 +81,9 @@ def _init_candidates(self, Q: nx.Graph) -> Dict[Any, Set[Any]]: # ------------------------------------------------ distance consistency def _within_dist(self, src: Any, dsts: Set[Any], limit: int) -> bool: """Check whether *any* dst in *dsts* lies within *limit* hops of src. - Stops BFS early once found. Returns True/False.""" + + Stops BFS early once found. Returns True/False. + """ if not dsts: return False if limit == float("inf"): diff --git a/synkit/Graph/Wildcard/fuse_graph.py b/synkit/Graph/Wildcard/fuse_graph.py index dbc93d2..4007a0b 100644 --- a/synkit/Graph/Wildcard/fuse_graph.py +++ b/synkit/Graph/Wildcard/fuse_graph.py @@ -15,11 +15,10 @@ def find_wc_graph_isomorphism( edge_match: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]] = None, logger: Optional[logging.Logger] = None, ) -> Optional[Dict[Any, Any]]: - """ - Wildcard‑aware sub‑graph isomorphism. Returns a mapping from every node in - the **smaller** graph to a node in the **larger** graph, allowing any node - whose ``element == "*"`` to match *any* concrete node (or group of nodes) - on the host side. + """Wildcard‑aware sub‑graph isomorphism. Returns a mapping from every node + in the **smaller** graph to a node in the **larger** graph, allowing any + node whose ``element == "*"`` to match *any* concrete node (or group of + nodes) on the host side. :param G1: First input graph. :type G1: nx.Graph | nx.DiGraph | nx.MultiGraph | nx.MultiDiGraph @@ -87,8 +86,7 @@ def fuse_wc_graphs( wildcard: str = "*", logger: Optional[logging.Logger] = None, ) -> GraphType: - """ - Fuse a wildcard‑pattern graph *G1* into the concrete host *G2*. + """Fuse a wildcard‑pattern graph *G1* into the concrete host *G2*. 
The result lives **entirely in G2’s node‑ID space** and contains: diff --git a/synkit/Graph/canon_graph.py b/synkit/Graph/canon_graph.py index b0c9043..f134d88 100644 --- a/synkit/Graph/canon_graph.py +++ b/synkit/Graph/canon_graph.py @@ -100,7 +100,8 @@ def _default_edge_key(u: NodeId, v: NodeId, data: EdgeData) -> Tuple[Any, ...]: def _digest(text: str) -> Digest: - """First 32 hex chars of SHA‑256 – short *but* collision‑safe for up to 2¹²⁸ graphs.""" + """First 32 hex chars of SHA‑256 – short *but* collision‑safe for up to + 2¹²⁸ graphs.""" return hashlib.sha256(text.encode()).hexdigest()[:32] @@ -110,8 +111,7 @@ def _digest(text: str) -> Digest: class GraphCanonicaliser: - """ - Factory that turns arbitrary ``networkx.Graph`` objects into their + """Factory that turns arbitrary ``networkx.Graph`` objects into their *canonical* twin plus a **stable 32‑hex digest**. Parameters @@ -169,8 +169,7 @@ def __init__( # High‑level helpers # # ------------------------------------------------------------------ # def canonicalise_graph(self, graph: nx.Graph) -> "CanonicalGraph": - """ - Return a :class:`CanonicalGraph` wrapper around *graph*. + """Return a :class:`CanonicalGraph` wrapper around *graph*. The wrapper exposes: @@ -183,8 +182,7 @@ def canonicalise_graphs( self, graphs: Iterable[nx.Graph], ) -> Tuple["CanonicalGraph", ...]: - """ - Bulk helper that returns *all* wrappers **sorted by hash**. + """Bulk helper that returns *all* wrappers **sorted by hash**. Useful when you want fast set comparison but need the canonical graphs as well: @@ -203,8 +201,7 @@ def canonicalise_graphs( # Digest / core methods # # ------------------------------------------------------------------ # def canonical_signature(self, graph: nx.Graph) -> Digest: - """ - Return the *hash of the canonical form* of *graph*. + """Return the *hash of the canonical form* of *graph*. Equal digests ⇒ graphs are guaranteed isomorphic **under the chosen back‑end and keys**. 
@@ -254,8 +251,7 @@ def _canon_generic(self, g: nx.Graph) -> nx.Graph: return G2 def _canon_wl(self, g: nx.Graph) -> nx.Graph: - """ - Weisfeiler–Lehman colour-refinement back-end (pure Python). + """Weisfeiler–Lehman colour-refinement back-end (pure Python). Seeds each node’s initial colour by the tuple of attributes in `self._wl_node_attrs` (e.g. ["element","charge","hcount"]), @@ -352,8 +348,7 @@ def __repr__(self) -> str: # pragma: no cover # Value wrapper (unchanged surface – richer docs) # ============================================================================= class CanonicalGraph: - """ - *Value object* tying together: + """*Value object* tying together: * the **original** NetworkX graph (mutable, user‑supplied); * its **canonical twin** (immutable copy, nodes relabelled 1…N); @@ -428,9 +423,9 @@ def help(self) -> None: # pragma: no cover class CanonicalRule: - """ - Value object that wraps a graph transformation rule in GML string form, - providing a canonicalised GML output and a stable 32-character SHA-256 hash. + """Value object that wraps a graph transformation rule in GML string form, + providing a canonicalised GML output and a stable 32-character SHA-256 + hash. Internally, the GML rule is parsed into a NetworkX graph via `gml_to_its`, canonicalised using a `GraphCanonicaliser`, and re-serialized back to GML @@ -458,8 +453,7 @@ def __init__( rule: str, canon: GraphCanonicaliser = GraphCanonicaliser(), ) -> None: - """ - Instantiate a CanonicalRule. + """Instantiate a CanonicalRule. Parameters ---------- @@ -523,9 +517,7 @@ def canonical_hash(self) -> Digest: return self._canonical_hash def help(self) -> None: - """ - Print original and canonical rule texts and underlying graphs. 
- """ + """Print original and canonical rule texts and underlying graphs.""" print("Original GML rule:") print(self._original_rule) print("\nCanonical GML rule:") diff --git a/synkit/Graph/syn_graph.py b/synkit/Graph/syn_graph.py index 685f264..3b8d5e0 100644 --- a/synkit/Graph/syn_graph.py +++ b/synkit/Graph/syn_graph.py @@ -34,9 +34,8 @@ class SynGraph: - """ - Wrapper around networkx.Graph providing both its original and (optionally) - canonicalized form, plus a SHA-256 signature. + """Wrapper around networkx.Graph providing both its original and + (optionally) canonicalized form, plus a SHA-256 signature. Parameters: - graph (nx.Graph): The NetworkX graph to wrap. @@ -62,8 +61,7 @@ def __init__( canonicaliser: Optional[GraphCanonicaliser] = None, canon: bool = True, ) -> None: - """ - Initialize a SynGraph wrapper. + """Initialize a SynGraph wrapper. Parameters: - graph (nx.Graph): Input graph. @@ -82,23 +80,18 @@ def __init__( self._canonical = None def __getattr__(self, name: str) -> Any: - """ - Delegate any unknown attribute lookup to the underlying ._raw graph. - """ + """Delegate any unknown attribute lookup to the underlying ._raw + graph.""" return getattr(self._raw, name) def __eq__(self, other: object) -> bool: - """ - Two SynGraph instances are equal iff their signatures match. - """ + """Two SynGraph instances are equal iff their signatures match.""" if not isinstance(other, SynGraph): return False return self.signature == other.signature def __hash__(self) -> int: - """ - Hash on the signature, allowing use in sets and as dict keys. - """ + """Hash on the signature, allowing use in sets and as dict keys.""" return hash(self.signature) @property @@ -119,8 +112,7 @@ def signature(self) -> Optional[str]: def get_nodes( self, data: bool = True ) -> Iterable[Union[Any, Tuple[Any, Dict[str, Any]]]]: - """ - Yield nodes from the original graph. + """Yield nodes from the original graph. 
Parameters ---------- @@ -132,8 +124,7 @@ def get_nodes( def get_edges( self, data: bool = True ) -> Iterable[Union[Tuple[Any, Any], Tuple[Any, Any, Dict[str, Any]]]]: - """ - Yield edges from the original graph. + """Yield edges from the original graph. Parameters ---------- @@ -151,9 +142,7 @@ def __repr__(self) -> str: return f"" def help(self) -> None: - """ - Print a summary of the SynGraph API. - """ + """Print a summary of the SynGraph API.""" print( "SynGraph Help\n" "----------\n" diff --git a/synkit/Graph/utils.py b/synkit/Graph/utils.py index 0e5fa5f..bca1aac 100644 --- a/synkit/Graph/utils.py +++ b/synkit/Graph/utils.py @@ -3,8 +3,7 @@ def print_graph_attributes(G: nx.Graph) -> None: - """ - Print all node and edge attributes from a NetworkX graph. + """Print all node and edge attributes from a NetworkX graph. Parameters: G (nx.Graph): A NetworkX graph (Graph, DiGraph, MultiGraph, etc.). @@ -23,8 +22,7 @@ def print_graph_attributes(G: nx.Graph) -> None: def remove_wildcard_nodes(G: nx.Graph, inplace: bool = True) -> nx.Graph: - """ - Remove all wildcard nodes from the graph. + """Remove all wildcard nodes from the graph. A wildcard node is identified by having its 'element' attribute equal to '*'. @@ -59,8 +57,7 @@ def add_wildcard_subgraph_for_unmapped( edge_keys: List[str] = ["order"], inplace: bool = False, ) -> Tuple[nx.Graph, Dict[Any, Any]]: - """ - Extend G with wildcard nodes/edges for every L-node not already mapped, + """Extend G with wildcard nodes/edges for every L-node not already mapped, preserving original L->G mapping and returning the full mapping. Parameters @@ -123,10 +120,9 @@ def add_wildcard_subgraph_for_unmapped( def clean_graph_keep_largest_component(graph: nx.Graph) -> nx.Graph: - """ - Return a shallow copy of the input graph with all edges removed - where the 'standard_order' attribute is exactly 0, then retain only - the largest connected component. 
+ """Return a shallow copy of the input graph with all edges removed where + the 'standard_order' attribute is exactly 0, then retain only the largest + connected component. Parameters ---------- diff --git a/synkit/IO/chem_converter.py b/synkit/IO/chem_converter.py index 31feb61..db60bf1 100644 --- a/synkit/IO/chem_converter.py +++ b/synkit/IO/chem_converter.py @@ -31,20 +31,22 @@ def smiles_to_graph( ], edge_attrs: Optional[List[str]] = ["order"], ) -> Optional[nx.Graph]: - """ - Helper function to convert a SMILES string to a NetworkX graph. + """Helper function to convert a SMILES string to a NetworkX graph. :param smiles: SMILES representation of the molecule. :type smiles: str - :param drop_non_aam: Whether to drop nodes without atom mapping numbers. + :param drop_non_aam: Whether to drop nodes without atom mapping + numbers. :type drop_non_aam: bool :param light_weight: Whether to create a light-weight graph. :type light_weight: bool :param sanitize: Whether to sanitize the molecule during conversion. :type sanitize: bool - :param use_index_as_atom_map: Whether to use atom indices as atom-map numbers. + :param use_index_as_atom_map: Whether to use atom indices as atom- + map numbers. :type use_index_as_atom_map: bool - :returns: The NetworkX graph representation, or None if conversion fails. + :returns: The NetworkX graph representation, or None if conversion + fails. :rtype: networkx.Graph or None """ @@ -97,20 +99,22 @@ def rsmi_to_graph( ], edge_attrs: Optional[List[str]] = ["order"], ) -> Tuple[Optional[nx.Graph], Optional[nx.Graph]]: - """ - Convert a reaction SMILES (RSMI) into reactant and product graphs. + """Convert a reaction SMILES (RSMI) into reactant and product graphs. :param rsmi: Reaction SMILES string in “reactants>>products” format. :type rsmi: str - :param drop_non_aam: If True, drop nodes without atom mapping numbers. + :param drop_non_aam: If True, drop nodes without atom mapping + numbers. 
:type drop_non_aam: bool :param light_weight: If True, create a light-weight graph. :type light_weight: bool :param sanitize: If True, sanitize molecules during conversion. :type sanitize: bool - :param use_index_as_atom_map: Whether to use atom indices as atom-map numbers. + :param use_index_as_atom_map: Whether to use atom indices as atom- + map numbers. :type use_index_as_atom_map: bool - :returns: A tuple `(reactant_graph, product_graph)`, each a NetworkX graph or None. + :returns: A tuple `(reactant_graph, product_graph)`, each a NetworkX + graph or None. :rtype: tuple of (networkx.Graph or None, networkx.Graph or None) """ try: @@ -142,15 +146,16 @@ def graph_to_smi( sanitize: bool = True, preserve_atom_maps: Optional[List[int]] = None, ) -> Optional[str]: - """ - Convert a NetworkX molecular graph to a SMILES string. + """Convert a NetworkX molecular graph to a SMILES string. - :param graph: Graph representation of the molecule. - Nodes must carry chemical attributes (e.g. ‘element’, atom maps). + :param graph: Graph representation of the molecule. Nodes must carry + chemical attributes (e.g. ‘element’, atom maps). :type graph: networkx.Graph - :param sanitize: Whether to perform RDKit sanitization on the resulting molecule. + :param sanitize: Whether to perform RDKit sanitization on the + resulting molecule. :type sanitize: bool - :param preserve_atom_maps: List of atom-map numbers for which hydrogens remain explicit. + :param preserve_atom_maps: List of atom-map numbers for which + hydrogens remain explicit. :type preserve_atom_maps: list of int or None :returns: SMILES string, or None if conversion fails. :rtype: str or None @@ -177,20 +182,22 @@ def graph_to_rsmi( sanitize: bool = True, explicit_hydrogen: bool = False, ) -> Optional[str]: - """ - Convert reactant and product graphs into a reaction SMILES string. + """Convert reactant and product graphs into a reaction SMILES string. :param r: Graph representing the reactants. 
:type r: networkx.Graph :param p: Graph representing the products. :type p: networkx.Graph - :param its: Imaginary transition state graph. If None, it will be constructed. + :param its: Imaginary transition state graph. If None, it will be + constructed. :type its: networkx.Graph or None :param sanitize: Whether to sanitize molecules during conversion. :type sanitize: bool - :param explicit_hydrogen: Whether to preserve explicit hydrogens in the SMILES. + :param explicit_hydrogen: Whether to preserve explicit hydrogens in + the SMILES. :type explicit_hydrogen: bool - :returns: Reaction SMILES string in 'reactants>>products' format or None on failure. + :returns: Reaction SMILES string in 'reactants>>products' format or + None on failure. :rtype: str or None """ try: @@ -229,22 +236,28 @@ def smart_to_gml( explicit_hydrogen: bool = False, useSmiles: bool = True, ) -> str: - """ - Convert a reaction SMARTS (or SMILES) template into a GML‐encoded DPO rule. + """Convert a reaction SMARTS (or SMILES) template into a GML‐encoded DPO + rule. :param smart: The reaction SMARTS or SMILES string. :type smart: str - :param core: If True, include only the reaction core in the GML. Defaults to True. + :param core: If True, include only the reaction core in the GML. + Defaults to True. :type core: bool - :param sanitize: If True, sanitize molecules during conversion. Defaults to True. + :param sanitize: If True, sanitize molecules during conversion. + Defaults to True. :type sanitize: bool - :param rule_name: Identifier for the output rule. Defaults to "rule". + :param rule_name: Identifier for the output rule. Defaults to + "rule". :type rule_name: str - :param reindex: If True, reindex graph nodes before exporting. Defaults to False. + :param reindex: If True, reindex graph nodes before exporting. + Defaults to False. :type reindex: bool - :param explicit_hydrogen: If True, include explicit hydrogen atoms. Defaults to False. 
+ :param explicit_hydrogen: If True, include explicit hydrogen atoms. + Defaults to False. :type explicit_hydrogen: bool - :param useSmiles: If True, treat input as SMILES; if False, as SMARTS. Defaults to True. + :param useSmiles: If True, treat input as SMILES; if False, as + SMARTS. Defaults to True. :type useSmiles: bool :returns: The GML representation of the reaction rule. :rtype: str @@ -271,14 +284,14 @@ def gml_to_smart( explicit_hydrogen: bool = False, useSmiles: bool = True, ) -> Tuple[str, nx.Graph]: - """ - Convert a GML string back to a SMARTS string and ITS graph. + """Convert a GML string back to a SMARTS string and ITS graph. :param gml: The GML string to convert. :type gml: str :param sanitize: Whether to sanitize molecules upon conversion. :type sanitize: bool - :param explicit_hydrogen: Whether hydrogens are explicitly represented. + :param explicit_hydrogen: Whether hydrogens are explicitly + represented. :type explicit_hydrogen: bool :param useSmiles: If True, output SMILES; otherwise SMARTS. :type useSmiles: bool @@ -303,18 +316,19 @@ def its_to_gml( reindex: bool = True, explicit_hydrogen: bool = False, ) -> str: - """ - Convert an ITS graph (reaction graph) to GML format. + """Convert an ITS graph (reaction graph) to GML format. :param its: The input ITS graph representing the reaction. :type its: networkx.Graph - :param core: If True, focus only on the reaction center. Defaults to True. + :param core: If True, focus only on the reaction center. Defaults to + True. :type core: bool :param rule_name: Name of the reaction rule. Defaults to "rule". :type rule_name: str :param reindex: If True, reindex graph nodes. Defaults to True. :type reindex: bool - :param explicit_hydrogen: If True, include explicit hydrogens. Defaults to False. + :param explicit_hydrogen: If True, include explicit hydrogens. + Defaults to False. :type explicit_hydrogen: bool :returns: The GML representation of the ITS graph. 
:rtype: str @@ -335,8 +349,8 @@ def its_to_gml( def gml_to_its(gml: str) -> nx.Graph: - """ - Convert a GML string representation of a reaction back into an ITS graph. + """Convert a GML string representation of a reaction back into an ITS + graph. :param gml: The GML string representing the reaction. :type gml: str @@ -367,28 +381,37 @@ def rsmi_to_its( edge_attrs: Optional[List[str]] = ["order"], explicit_hydrogen: bool = False, ) -> nx.Graph: - """ - Convert a reaction SMILES (rSMI) to an ITS (Imaginary Transition State) graph. + """Convert a reaction SMILES (rSMI) to an ITS (Imaginary Transition State) + graph. - :param rsmi: The reaction SMILES string, optionally containing atom-map labels. + :param rsmi: The reaction SMILES string, optionally containing atom- + map labels. :type rsmi: str - :param drop_non_aam: If True, discard any molecular fragments without atom-atom maps. + :param drop_non_aam: If True, discard any molecular fragments + without atom-atom maps. :type drop_non_aam: bool - :param sanitize: If True, perform molecule sanitization (valence checks, kekulization). + :param sanitize: If True, perform molecule sanitization (valence + checks, kekulization). :type sanitize: bool - :param use_index_as_atom_map: If True, override atom-map labels by atom indices. + :param use_index_as_atom_map: If True, override atom-map labels by + atom indices. :type use_index_as_atom_map: bool - :param core: If True, return only the reaction-center subgraph of the ITS. + :param core: If True, return only the reaction-center subgraph of + the ITS. :type core: bool - :param node_attrs: Node attributes to include in the ITS graph (e.g., element, charge). + :param node_attrs: Node attributes to include in the ITS graph + (e.g., element, charge). :type node_attrs: list[str] - :param edge_attrs: Edge attributes to include in the ITS graph (e.g., order). + :param edge_attrs: Edge attributes to include in the ITS graph + (e.g., order). 
:type edge_attrs: list[str] - :param explicit_hydrogen: If True, convert implicit hydrogens to explicit nodes. + :param explicit_hydrogen: If True, convert implicit hydrogens to + explicit nodes. :type explicit_hydrogen: bool :returns: A NetworkX graph representing the complete or core ITS. :rtype: networkx.Graph - :raises ValueError: If the SMILES string is invalid or graph construction fails. + :raises ValueError: If the SMILES string is invalid or graph + construction fails. """ r, p = rsmi_to_graph( rsmi, @@ -411,26 +434,27 @@ def its_to_rsmi( sanitize: bool = True, explicit_hydrogen: bool = False, ) -> str: - """ - Convert an ITS graph into a reaction SMILES (rSMI) string. + """Convert an ITS graph into a reaction SMILES (rSMI) string. - :param its: A fully annotated ITS graph (nodes with atom-map attributes). + :param its: A fully annotated ITS graph (nodes with atom-map + attributes). :type its: networkx.Graph :param sanitize: If True, sanitize prior to SMILES generation. :type sanitize: bool :param explicit_hydrogen: If True, include explicit hydrogens. :type explicit_hydrogen: bool - :returns: A canonical reaction-SMILES string ('reactants>agents>products'). + :returns: A canonical reaction-SMILES string + ('reactants>agents>products'). :rtype: str - :raises ValueError: If graph cannot be decomposed or sanitisation fails. + :raises ValueError: If graph cannot be decomposed or sanitisation + fails. """ r, p = its_decompose(its) return graph_to_rsmi(r, p, its, sanitize, explicit_hydrogen) def rsmi_to_rsmarts(rsmi: str) -> str: - """ - Convert a mapped reaction SMILES to a reaction SMARTS string. + """Convert a mapped reaction SMILES to a reaction SMARTS string. :param rsmi: Reaction SMILES input. :type rsmi: str @@ -446,8 +470,7 @@ def rsmi_to_rsmarts(rsmi: str) -> str: def rsmarts_to_rsmi(rsmarts: str) -> str: - """ - Convert a reaction SMARTS to a reaction SMILES string. + """Convert a reaction SMARTS to a reaction SMILES string. 
:param rsmarts: Reaction SMARTS input. :type rsmarts: str diff --git a/synkit/IO/data_io.py b/synkit/IO/data_io.py index a66bd7e..c817c60 100644 --- a/synkit/IO/data_io.py +++ b/synkit/IO/data_io.py @@ -11,14 +11,13 @@ def save_database(database: List[Dict], pathname: str = "./Data/database.json") -> None: - """ - Save a database (a list of dictionaries) to a JSON file. + """Save a database (a list of dictionaries) to a JSON file. :param database: The database to be saved. :type database: list[dict] - :param pathname: The path where the database will be saved. Defaults to './Data/database.json'. + :param pathname: The path where the database will be saved. Defaults + to './Data/database.json'. :type pathname: str - :raises TypeError: If the database is not a list of dictionaries. :raises ValueError: If there is an error writing the file. """ @@ -32,15 +31,13 @@ def save_database(database: List[Dict], pathname: str = "./Data/database.json") def load_database(pathname: str = "./Data/database.json") -> List[Dict]: - """ - Load a database (a list of dictionaries) from a JSON file. + """Load a database (a list of dictionaries) from a JSON file. - :param pathname: The path from where the database will be loaded. Defaults to './Data/database.json'. + :param pathname: The path from where the database will be loaded. + Defaults to './Data/database.json'. :type pathname: str - :returns: The loaded database. :rtype: list[dict] - :raises ValueError: If there is an error reading the file. """ try: @@ -52,8 +49,7 @@ def load_database(pathname: str = "./Data/database.json") -> List[Dict]: def save_to_pickle(data: List[Dict[str, Any]], filename: str) -> None: - """ - Save a list of dictionaries to a pickle file. + """Save a list of dictionaries to a pickle file. :param data: A list of dictionaries to be saved. 
:type data: list[dict] @@ -65,12 +61,10 @@ def save_to_pickle(data: List[Dict[str, Any]], filename: str) -> None: def load_from_pickle(filename: str) -> List[Any]: - """ - Load data from a pickle file. + """Load data from a pickle file. :param filename: The name of the pickle file to load data from. :type filename: str - :returns: The data loaded from the pickle file. :rtype: list """ @@ -79,13 +73,12 @@ def load_from_pickle(filename: str) -> List[Any]: def load_gml_as_text(gml_file_path: str) -> Optional[str]: - """ - Load the contents of a GML file as a text string. + """Load the contents of a GML file as a text string. :param gml_file_path: The file path to the GML file. :type gml_file_path: str - - :returns: The text content of the GML file, or None if the file does not exist or an error occurs. + :returns: The text content of the GML file, or None if the file does + not exist or an error occurs. :rtype: str or None """ try: @@ -100,14 +93,12 @@ def load_gml_as_text(gml_file_path: str) -> Optional[str]: def save_text_as_gml(gml_text: str, file_path: str) -> bool: - """ - Save a GML text string to a file. + """Save a GML text string to a file. :param gml_text: The GML content as a text string. :type gml_text: str :param file_path: The file path where the GML text will be saved. :type file_path: str - :returns: True if saving was successful, False otherwise. :rtype: bool """ @@ -122,28 +113,26 @@ def save_text_as_gml(gml_text: str, file_path: str) -> bool: def save_compressed(array: ndarray, filename: str) -> None: - """ - Saves a NumPy array in a compressed format using .npz extension. + """Saves a NumPy array in a compressed format using .npz extension. :param array: The NumPy array to be saved. :type array: numpy.ndarray - :param filename: The file path or name to save the array to, with a '.npz' extension. + :param filename: The file path or name to save the array to, with a + '.npz' extension. 
:type filename: str """ np.savez_compressed(filename, array=array) def load_compressed(filename: str) -> ndarray: - """ - Loads a NumPy array from a compressed .npz file. + """Loads a NumPy array from a compressed .npz file. :param filename: The path of the .npz file to load. :type filename: str - :returns: The loaded NumPy array. :rtype: numpy.ndarray - - :raises KeyError: If the .npz file does not contain an array with the key 'array'. + :raises KeyError: If the .npz file does not contain an array with + the key 'array'. """ with np.load(filename) as data: if "array" in data: @@ -155,8 +144,7 @@ def load_compressed(filename: str) -> ndarray: def save_model(model: Any, filename: str) -> None: - """ - Save a machine learning model to a file using joblib. + """Save a machine learning model to a file using joblib. :param model: The machine learning model to save. :type model: object @@ -168,12 +156,11 @@ def save_model(model: Any, filename: str) -> None: def load_model(filename: str) -> Any: - """ - Load a machine learning model from a file using joblib. + """Load a machine learning model from a file using joblib. - :param filename: The path to the file from which the model will be loaded. + :param filename: The path to the file from which the model will be + loaded. :type filename: str - :returns: The loaded machine learning model. :rtype: object """ @@ -183,12 +170,12 @@ def load_model(filename: str) -> Any: def save_dict_to_json(data: dict, file_path: str) -> None: - """ - Save a dictionary to a JSON file. + """Save a dictionary to a JSON file. :param data: The dictionary to be saved. :type data: dict - :param file_path: The path to the file where the dictionary should be saved. + :param file_path: The path to the file where the dictionary should + be saved. 
:type file_path: str """ with open(file_path, "w") as json_file: @@ -197,13 +184,13 @@ def save_dict_to_json(data: dict, file_path: str) -> None: def load_dict_from_json(file_path: str) -> Optional[dict]: - """ - Load a dictionary from a JSON file. + """Load a dictionary from a JSON file. - :param file_path: The path to the JSON file from which to load the dictionary. + :param file_path: The path to the JSON file from which to load the + dictionary. :type file_path: str - - :returns: The dictionary loaded from the JSON file, or None if an error occurs. + :returns: The dictionary loaded from the JSON file, or None if an + error occurs. :rtype: dict or None """ try: @@ -217,13 +204,13 @@ def load_dict_from_json(file_path: str) -> Optional[dict]: def load_from_pickle_generator(file_path: str) -> Generator[Any, None, None]: - """ - A generator that yields items from a pickle file where each pickle load returns a list of dictionaries. + """A generator that yields items from a pickle file where each pickle load + returns a list of dictionaries. :param file_path: The path to the pickle file to load. :type file_path: str - - :yields: A single item from the list of dictionaries stored in the pickle file. + :yields: A single item from the list of dictionaries stored in the + pickle file. :rtype: Any """ with open(file_path, "rb") as file: @@ -237,16 +224,16 @@ def load_from_pickle_generator(file_path: str) -> Generator[Any, None, None]: def collect_data(num_batches: int, temp_dir: str, file_template: str) -> List[Any]: - """ - Collects and aggregates data from multiple pickle files into a single list. + """Collects and aggregates data from multiple pickle files into a single + list. :param num_batches: The number of batch files to process. :type num_batches: int :param temp_dir: The directory where the batch files are stored. :type temp_dir: str - :param file_template: The template string for batch file names, expecting an integer formatter. 
+ :param file_template: The template string for batch file names, + expecting an integer formatter. :type file_template: str - :returns: A list of aggregated data items from all batch files. :rtype: list """ @@ -259,8 +246,7 @@ def collect_data(num_batches: int, temp_dir: str, file_template: str) -> List[An def save_list_to_file(data_list: list, file_path: str) -> None: - """ - Save a list to a file in JSON format. + """Save a list to a file in JSON format. :param data_list: The list to save. :type data_list: list @@ -272,12 +258,10 @@ def save_list_to_file(data_list: list, file_path: str) -> None: def load_list_from_file(file_path: str) -> list: - """ - Load a list from a JSON-formatted file. + """Load a list from a JSON-formatted file. :param file_path: The path to the file to read the list from. :type file_path: str - :returns: The list loaded from the file. :rtype: list """ @@ -286,17 +270,14 @@ def load_list_from_file(file_path: str) -> list: def save_dg(dg, path: str) -> str: - """ - Save a DG instance to disk using MØD's dump method. + """Save a DG instance to disk using MØD's dump method. :param dg: The derivation graph to save. :type dg: DG :param path: The file path where the graph will be dumped. :type path: str - :returns: The path of the dumped file. :rtype: str - :raises Exception: If saving fails. """ try: @@ -309,19 +290,17 @@ def save_dg(dg, path: str) -> str: def load_dg(path: str, graph_db: list, rule_db: list): - """ - Load a DG instance from a dumped file. + """Load a DG instance from a dumped file. :param path: The file path of the dumped graph. :type path: str - :param graph_db: List of Graph objects representing the graph database. + :param graph_db: List of Graph objects representing the graph + database. :type graph_db: list :param rule_db: List of Rule objects required for loading the DG. :type rule_db: list - :returns: The loaded derivation graph instance. :rtype: DG - :raises Exception: If loading fails. 
""" from mod import DG diff --git a/synkit/IO/data_process.py b/synkit/IO/data_process.py index 3475674..ec88283 100644 --- a/synkit/IO/data_process.py +++ b/synkit/IO/data_process.py @@ -7,10 +7,9 @@ def merge_dicts( key: str, intersection: bool = True, ) -> List[Dict[str, Any]]: - """ - Merges two lists of dictionaries based on a specified key, with an option to - either merge only dictionaries with matching key values (intersection) or - all dictionaries (union). + """Merges two lists of dictionaries based on a specified key, with an + option to either merge only dictionaries with matching key values + (intersection) or all dictionaries (union). Parameters: - list1 (List[Dict[str, Any]]): The first list of dictionaries. diff --git a/synkit/IO/debug.py b/synkit/IO/debug.py index 08deb6d..fdba57c 100644 --- a/synkit/IO/debug.py +++ b/synkit/IO/debug.py @@ -7,8 +7,8 @@ def setup_logging( log_level: str = "INFO", log_filename: str = None, task_type: str = None ) -> logging.Logger: - """ - Configures logging to either the console or a file, based on provided parameters. + """Configures logging to either the console or a file, based on provided + parameters. :param log_level: Logging level to set. Defaults to 'INFO'. Options: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'. @@ -47,19 +47,17 @@ def setup_logging( def configure_warnings_and_logs( ignore_warnings: bool = False, disable_rdkit_logs: bool = False ) -> None: - """ - Configures Python warnings and RDKit log behavior based on input flags. + """Configures Python warnings and RDKit log behavior based on input flags. - :param ignore_warnings: Whether to suppress all Python warnings. Default is False. + :param ignore_warnings: Whether to suppress all Python warnings. + Default is False. :type ignore_warnings: bool - :param disable_rdkit_logs: Whether to disable RDKit error and warning logs. Default is False. + :param disable_rdkit_logs: Whether to disable RDKit error and + warning logs. Default is False. 
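:type disable_rdkit_logs: bool - - :returns: None - - :usage: - Use this function to control verbosity (e.g. in production or testing), but use with - caution during development to avoid missing critical issues. + :returns: None. Usage: use this function to control verbosity + (e.g. in production or testing), but use with caution during + development to avoid missing critical issues. """ if ignore_warnings: warnings.filterwarnings("ignore") diff --git a/synkit/IO/dg_to_gml.py b/synkit/IO/dg_to_gml.py index 00a3902..c290ec6 100644 --- a/synkit/IO/dg_to_gml.py +++ b/synkit/IO/dg_to_gml.py @@ -104,9 +104,8 @@ def printGraph(g): return s, Rule.fromGMLString(s, add=False) def fit(self, dg, origSmiles): - """ - Matches the original SMILES to a list of generated reaction SMILES and - returns the parsed reaction. + """Matches the original SMILES to a list of generated reaction SMILES + and returns the parsed reaction. Parameters: - dg (DataGenerator): The data generator instance containing the reactions. diff --git a/synkit/IO/gml_to_nx.py b/synkit/IO/gml_to_nx.py index 4c28587..a186068 100644 --- a/synkit/IO/gml_to_nx.py +++ b/synkit/IO/gml_to_nx.py @@ -5,21 +5,21 @@ class GMLToNX: - """ - Parses GML-like text and transforms it into three NetworkX graphs - representing the left, right, and context graphs of a chemical reaction step. + """Parses GML-like text and transforms it into three NetworkX graphs + representing the left, right, and context graphs of a chemical reaction + step. :param gml_text: The GML-like text to parse. :type gml_text: str - - :ivar graphs: A dictionary containing 'left', 'right', and 'context' NetworkX graphs. + :ivar graphs: A dictionary containing 'left', 'right', and 'context' + NetworkX graphs. :vartype graphs: dict[str, nx.Graph] """ def __init__(self, gml_text: str): - """ - Initializes a GMLToNX object that can parse GML-like text into separate - NetworkX graphs representing different stages or components of a chemical reaction. + """Initializes a GMLToNX object that can parse GML-like text into + separate NetworkX graphs representing different stages or components of + a chemical reaction. :param gml_text: The GML-like text to be parsed. :type gml_text: str @@ -28,13 +28,13 @@ def __init__(self, gml_text: str): self.graphs = {"left": nx.Graph(), "context": nx.Graph(), "right": nx.Graph()} def _parse_element(self, line: str, current_section: str): - """ - Parses a line of GML-like text to extract node or edge data and adds it to the - current section's graph. + """Parses a line of GML-like text to extract node or edge data and adds + it to the current section's graph. :param line: A single line of GML-like text. :type line: str - :param current_section: Which section ('left', 'right', 'context') to add the node/edge to. + :param current_section: Which section ('left', 'right', + 'context') to add the node/edge to. :type current_section: str """ label_to_order = {"-": 1, ":": 1.5, "=": 2, "#": 3} @@ -60,12 +60,11 @@ def _parse_element(self, line: str, current_section: str): self.graphs[current_section].add_edge(source, target, order=order) def _extract_element_and_charge(self, label: str) -> Tuple[str, int]: - """ - Extracts the chemical element and its charge from a node label. + """Extracts the chemical element and its charge from a node label. - :param label: The label string from a GML node (e.g., 'N+', 'O2-', etc.). + :param label: The label string from a GML node (e.g., 'N+', + 'O2-', etc.). :type label: str - :returns: A tuple of (element symbol, formal charge). :rtype: tuple[str, int] """ @@ -82,10 +81,10 @@ def _extract_element_and_charge(self, label: str) -> Tuple[str, int]: return element, charge def _synchronize_nodes_and_edges(self): - """ - Ensures that all nodes and edges in 'context' appear in both 'left' and 'right'. - We do not remove edges from left or right if they are not in context. - We only add missing context nodes and edges to left and right.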
+ """Initializes a GMLToNX object that can parse GML-like text into + separate NetworkX graphs representing different stages or components of + a chemical reaction. :param gml_text: The GML-like text to be parsed. :type gml_text: str @@ -28,13 +28,13 @@ def __init__(self, gml_text: str): self.graphs = {"left": nx.Graph(), "context": nx.Graph(), "right": nx.Graph()} def _parse_element(self, line: str, current_section: str): - """ - Parses a line of GML-like text to extract node or edge data and adds it to the - current section's graph. + """Parses a line of GML-like text to extract node or edge data and adds + it to the current section's graph. :param line: A single line of GML-like text. :type line: str - :param current_section: Which section ('left', 'right', 'context') to add the node/edge to. + :param current_section: Which section ('left', 'right', + 'context') to add the node/edge to. :type current_section: str """ label_to_order = {"-": 1, ":": 1.5, "=": 2, "#": 3} @@ -60,12 +60,11 @@ def _parse_element(self, line: str, current_section: str): self.graphs[current_section].add_edge(source, target, order=order) def _extract_element_and_charge(self, label: str) -> Tuple[str, int]: - """ - Extracts the chemical element and its charge from a node label. + """Extracts the chemical element and its charge from a node label. - :param label: The label string from a GML node (e.g., 'N+', 'O2-', etc.). + :param label: The label string from a GML node (e.g., 'N+', + 'O2-', etc.). :type label: str - :returns: A tuple of (element symbol, formal charge). :rtype: tuple[str, int] """ @@ -82,10 +81,10 @@ def _extract_element_and_charge(self, label: str) -> Tuple[str, int]: return element, charge def _synchronize_nodes_and_edges(self): - """ - Ensures that all nodes and edges in 'context' appear in both 'left' and 'right'. - We do not remove edges from left or right if they are not in context. - We only add missing context nodes and edges to left and right. 
+ """Ensures that all nodes and edges in 'context' appear in both 'left' + and 'right'. We do not remove edges from left or right if they are not + in context. We only add missing context nodes and edges to left and + right. :returns: None """ diff --git a/synkit/IO/graph_to_mol.py b/synkit/IO/graph_to_mol.py index 3c8f012..8049f0d 100644 --- a/synkit/IO/graph_to_mol.py +++ b/synkit/IO/graph_to_mol.py @@ -4,18 +4,20 @@ class GraphToMol: - """ - Converts a NetworkX graph representation of a molecule into an RDKit molecule object. + """Converts a NetworkX graph representation of a molecule into an RDKit + molecule object. - This class reconstructs RDKit molecules from node and edge attributes in a graph, - correctly interpreting atom types, charges, mapping numbers, bond orders, and optionally - explicit hydrogen counts. + This class reconstructs RDKit molecules from node and edge + attributes in a graph, correctly interpreting atom types, charges, + mapping numbers, bond orders, and optionally explicit hydrogen + counts. - :param node_attributes: Mapping of expected attribute names to node keys in the graph. For example, - {"element": "element", "charge": "charge", "atom_map": "atom_map"}. + :param node_attributes: Mapping of expected attribute names to node + keys in the graph. For example, {"element": "element", "charge": + "charge", "atom_map": "atom_map"}. :type node_attributes: Dict[str, str] - :param edge_attributes: Mapping of expected attribute names to edge keys in the graph. - For example, {"order": "order"}. + :param edge_attributes: Mapping of expected attribute names to edge + keys in the graph. For example, {"order": "order"}. :type edge_attributes: Dict[str, str] """ @@ -28,14 +30,15 @@ def __init__( }, edge_attributes: Dict[str, str] = {"order": "order"}, ): - """ - Initializes the GraphToMol object with mappings for node and edge attributes. + """Initializes the GraphToMol object with mappings for node and edge + attributes. 
- :param node_attributes: Mapping from desired atom attribute names to graph node keys. - E.g. {"element": "element", "charge": "charge", "atom_map": "atom_map"} + :param node_attributes: Mapping from desired atom attribute + names to graph node keys. E.g. {"element": "element", + "charge": "charge", "atom_map": "atom_map"} :type node_attributes: Dict[str, str] - :param edge_attributes: Mapping from desired bond attribute names to graph edge keys. - E.g. {"order": "order"} + :param edge_attributes: Mapping from desired bond attribute + names to graph edge keys. E.g. {"order": "order"} :type edge_attributes: Dict[str, str] """ self.node_attributes = node_attributes @@ -48,22 +51,23 @@ def graph_to_mol( sanitize: bool = True, use_h_count: bool = False, ) -> Chem.Mol: - """ - Converts a NetworkX graph into an RDKit molecule. + """Converts a NetworkX graph into an RDKit molecule. :param graph: The NetworkX graph representing the molecule. :type graph: nx.Graph - :param ignore_bond_order: If True, all bonds are created as single bonds regardless of edge attributes. - Defaults to False. + :param ignore_bond_order: If True, all bonds are created as + single bonds regardless of edge attributes. Defaults to + False. :type ignore_bond_order: bool - :param sanitize: If True, the resulting RDKit molecule will be sanitized after construction. - Defaults to True. + :param sanitize: If True, the resulting RDKit molecule will be + sanitized after construction. Defaults to True. :type sanitize: bool - :param use_h_count: If True, the 'hcount' attribute (if present) will be used to set explicit hydrogen counts - on atoms. Defaults to False. + :param use_h_count: If True, the 'hcount' attribute (if present) + will be used to set explicit hydrogen counts on atoms. + Defaults to False. :type use_h_count: bool - - :returns: An RDKit molecule constructed from the graph's nodes and edges. + :returns: An RDKit molecule constructed from the graph's nodes + and edges. 
:rtype: Chem.Mol """ mol = Chem.RWMol() @@ -110,12 +114,13 @@ def graph_to_mol( @staticmethod def get_bond_type_from_order(order: float) -> Chem.BondType: - """ - Converts a numerical bond order into the corresponding RDKit BondType. + """Converts a numerical bond order into the corresponding RDKit + BondType. :param order: The numerical bond order (typically 1, 2, or 3). :type order: float - :returns: The corresponding RDKit bond type (single, double, triple, or aromatic). + :returns: The corresponding RDKit bond type (single, double, + triple, or aromatic). :rtype: Chem.BondType """ if order == 1: diff --git a/synkit/IO/mol_to_graph.py b/synkit/IO/mol_to_graph.py index 55524c0..312bdf0 100644 --- a/synkit/IO/mol_to_graph.py +++ b/synkit/IO/mol_to_graph.py @@ -35,12 +35,13 @@ def __init__( ], edge_attrs: Optional[List[str]] = ["order"], ) -> None: - """ - Initialize the MolToGraph helper. + """Initialize the MolToGraph helper. - :param node_attrs: Names of node attributes to keep when transforming. + :param node_attrs: Names of node attributes to keep when + transforming. :type node_attrs: List[str] - :param edge_attrs: Names of edge attributes to keep when transforming. + :param edge_attrs: Names of edge attributes to keep when + transforming. :type edge_attrs: List[str] """ self.node_attrs: List[str] = node_attrs or [] @@ -52,18 +53,21 @@ def transform( drop_non_aam: bool = False, use_index_as_atom_map: bool = False, ) -> nx.Graph: - """ - Build a graph directly from a molecule, including only selected attributes. + """Build a graph directly from a molecule, including only selected + attributes. :param mol: The RDKit molecule to convert. :type mol: Chem.Mol - :param drop_non_aam: If True, skips atoms without atom mapping numbers - (requires use_index_as_atom_map=True). Defaults to False. + :param drop_non_aam: If True, skips atoms without atom mapping + numbers (requires use_index_as_atom_map=True). Defaults to + False. 
:type drop_non_aam: bool - :param use_index_as_atom_map: If True, uses atom mapping numbers as node IDs when present; - otherwise uses atom index+1. Defaults to False. + :param use_index_as_atom_map: If True, uses atom mapping numbers + as node IDs when present; otherwise uses atom index+1. + Defaults to False. :type use_index_as_atom_map: bool - :returns: A NetworkX graph containing only the specified node and edge attributes. + :returns: A NetworkX graph containing only the specified node + and edge attributes. :rtype: nx.Graph """ if drop_non_aam and not use_index_as_atom_map: @@ -107,8 +111,7 @@ def transform( @staticmethod def _gather_atom_properties(atom: Chem.Atom) -> Dict[str, Any]: - """ - Collect the full set of atom attributes for graph nodes. + """Collect the full set of atom attributes for graph nodes. :param atom: The RDKit Atom object. :type atom: Chem.Atom @@ -137,8 +140,7 @@ def _gather_atom_properties(atom: Chem.Atom) -> Dict[str, Any]: @staticmethod def _gather_bond_properties(bond: Chem.Bond) -> Dict[str, Any]: - """ - Collect the full set of bond attributes for graph edges. + """Collect the full set of bond attributes for graph edges. :param bond: The RDKit Bond object. :type bond: Chem.Bond @@ -155,8 +157,7 @@ def _gather_bond_properties(bond: Chem.Bond) -> Dict[str, Any]: @staticmethod def get_stereochemistry(atom: Chem.Atom) -> str: - """ - Determine the stereochemistry (R/S) of a chiral atom. + """Determine the stereochemistry (R/S) of a chiral atom. :param atom: The RDKit Atom object. :type atom: Chem.Atom @@ -172,12 +173,12 @@ def get_stereochemistry(atom: Chem.Atom) -> str: @staticmethod def get_bond_stereochemistry(bond: Chem.Bond) -> str: - """ - Determine the stereochemistry (E/Z) of a double bond. + """Determine the stereochemistry (E/Z) of a double bond. :param bond: The RDKit Bond object. :type bond: Chem.Bond - :returns: 'E', 'Z', or 'N' for non-stereospecific or non-double bond. 
+ :returns: 'E', 'Z', or 'N' for non-stereospecific or non-double + bond. :rtype: str """ if bond.GetBondType() != Chem.BondType.DOUBLE: @@ -191,8 +192,7 @@ def get_bond_stereochemistry(bond: Chem.Bond) -> str: @staticmethod def has_atom_mapping(mol: Chem.Mol) -> bool: - """ - Check if any atom in the molecule has an atom mapping number. + """Check if any atom in the molecule has an atom mapping number. :param mol: The RDKit molecule. :type mol: Chem.Mol @@ -203,8 +203,7 @@ def has_atom_mapping(mol: Chem.Mol) -> bool: @staticmethod def random_atom_mapping(mol: Chem.Mol) -> Chem.Mol: - """ - Assign random atom mapping numbers to all atoms in the molecule. + """Assign random atom mapping numbers to all atoms in the molecule. :param mol: The RDKit molecule. :type mol: Chem.Mol @@ -225,17 +224,18 @@ def mol_to_graph( light_weight: bool = False, use_index_as_atom_map: bool = False, ) -> nx.Graph: - """ - Convert a molecule to a full-featured NetworkX graph. + """Convert a molecule to a full-featured NetworkX graph. :param mol: The RDKit molecule to convert. :type mol: Chem.Mol :param drop_non_aam: If True, drop atoms without mapping numbers - (requires use_index_as_atom_map=True). Defaults to False. + (requires use_index_as_atom_map=True). Defaults to False. :type drop_non_aam: bool - :param light_weight: If True, create a lightweight graph with minimal attributes. Defaults to False. + :param light_weight: If True, create a lightweight graph with + minimal attributes. Defaults to False. :type light_weight: bool - :param use_index_as_atom_map: If True, prefer atom maps as node IDs. Defaults to False. + :param use_index_as_atom_map: If True, prefer atom maps as node + IDs. Defaults to False. :type use_index_as_atom_map: bool :returns: A NetworkX graph of the molecule with all attributes. 
:rtype: nx.Graph @@ -257,14 +257,15 @@ def _create_light_weight_graph( drop_non_aam: bool = False, use_index_as_atom_map: bool = False, ) -> nx.Graph: - """ - Create a lightweight graph with basic atom and bond info. + """Create a lightweight graph with basic atom and bond info. :param mol: The RDKit molecule. :type mol: Chem.Mol - :param drop_non_aam: If True, skip atoms without mapping numbers. Defaults to False. + :param drop_non_aam: If True, skip atoms without mapping + numbers. Defaults to False. :type drop_non_aam: bool - :param use_index_as_atom_map: If True, use atom maps as node IDs when present. Defaults to False. + :param use_index_as_atom_map: If True, use atom maps as node IDs + when present. Defaults to False. :type use_index_as_atom_map: bool :returns: A NetworkX graph with minimal node/edge attributes. :rtype: nx.Graph @@ -306,14 +307,15 @@ def _create_detailed_graph( drop_non_aam: bool = True, use_index_as_atom_map: bool = True, ) -> nx.Graph: - """ - Create a detailed graph with full atom and bond attributes. + """Create a detailed graph with full atom and bond attributes. :param mol: The RDKit molecule. :type mol: Chem.Mol - :param drop_non_aam: If True, skip atoms without mapping numbers. Defaults to True. + :param drop_non_aam: If True, skip atoms without mapping + numbers. Defaults to True. :type drop_non_aam: bool - :param use_index_as_atom_map: If True, use atom maps as node IDs when present. Defaults to True. + :param use_index_as_atom_map: If True, use atom maps as node IDs + when present. Defaults to True. :type use_index_as_atom_map: bool :returns: A NetworkX graph with full node/edge attributes. :rtype: nx.Graph @@ -341,8 +343,7 @@ def _create_detailed_graph( @staticmethod def add_partial_charges(mol: Chem.Mol) -> None: - """ - Compute and assign Gasteiger charges to all atoms in the molecule. + """Compute and assign Gasteiger charges to all atoms in the molecule. :param mol: The RDKit molecule. 
:type mol: Chem.Mol diff --git a/synkit/IO/nx_to_gml.py b/synkit/IO/nx_to_gml.py index 5f66ef6..ed63280 100644 --- a/synkit/IO/nx_to_gml.py +++ b/synkit/IO/nx_to_gml.py @@ -4,29 +4,28 @@ class NXToGML: - """ - Converts NetworkX graph representations of chemical reactions to GML (Graph Modelling Language) strings. - Useful for exporting reaction rules in a standard graph format. + """Converts NetworkX graph representations of chemical reactions to GML + (Graph Modelling Language) strings. Useful for exporting reaction rules in + a standard graph format. - This class provides static methods for converting individual graphs, sets of reaction graphs, and - managing charge/attribute changes in the export process. + This class provides static methods for converting individual graphs, + sets of reaction graphs, and managing charge/attribute changes in + the export process. """ def __init__(self) -> None: - """ - Initializes an NXToGML object. - """ + """Initializes an NXToGML object.""" pass @staticmethod def _charge_to_string(charge: int) -> str: - """ - Converts an integer charge into a string representation. + """Converts an integer charge into a string representation. - :param charge: The charge value, which can be positive, negative, or zero. + :param charge: The charge value, which can be positive, + negative, or zero. :type charge: int - - :returns: The string representation of the charge (e.g. '+', '2+', '-', '3-', ''). + :returns: The string representation of the charge (e.g. '+', + '2+', '-', '3-', ''). :rtype: str """ if charge > 0: @@ -40,8 +39,8 @@ def _charge_to_string(charge: int) -> str: def _find_changed_nodes( graph1: nx.Graph, graph2: nx.Graph, attributes: List[str] = ["charge"] ) -> List[int]: - """ - Identifies nodes with changes in specified attributes between two NetworkX graphs. + """Identifies nodes with changes in specified attributes between two + NetworkX graphs. :param graph1: The first NetworkX graph. 
:type graph1: nx.Graph @@ -49,8 +48,8 @@ def _find_changed_nodes( :type graph2: nx.Graph :param attributes: List of attribute names to check for changes. :type attributes: list[str] - - :returns: Node identifiers that have changes in the specified attributes. + :returns: Node identifiers that have changes in the specified + attributes. :rtype: list[int] """ changed_nodes = [] @@ -71,19 +70,21 @@ def _convert_graph_to_gml( changed_node_ids: List[int], explicit_hydrogen: bool = False, ) -> str: - """ - Converts a NetworkX graph to a GML string for a specific reaction section. + """Converts a NetworkX graph to a GML string for a specific reaction + section. :param graph: The NetworkX graph to be converted. :type graph: nx.Graph - :param section: The section name in the GML output ('left', 'right', or 'context'). + :param section: The section name in the GML output ('left', + 'right', or 'context'). :type section: str :param changed_node_ids: List of nodes with changed attributes. :type changed_node_ids: list[int] - :param explicit_hydrogen: Whether to explicitly include hydrogen atoms in the output. + :param explicit_hydrogen: Whether to explicitly include hydrogen + atoms in the output. :type explicit_hydrogen: bool - - :returns: The GML string representation of the graph for the specified section. + :returns: The GML string representation of the graph for the + specified section. :rtype: str """ order_to_label = {1: "-", 1.5: ":", 2: "=", 3: "#"} @@ -134,8 +135,8 @@ def _rule_grammar( changed_node_ids: List[int], explicit_hydrogen: bool, ) -> str: - """ - Generates a GML string for a chemical rule, including left, context, and right graphs. + """Generates a GML string for a chemical rule, including left, context, + and right graphs. :param L: The left graph. :type L: nx.Graph @@ -147,9 +148,9 @@ def _rule_grammar( :type rule_name: str :param changed_node_ids: List of nodes with changed attributes. 
:type changed_node_ids: list[int] - :param explicit_hydrogen: Whether to explicitly include hydrogen atoms in the output. + :param explicit_hydrogen: Whether to explicitly include hydrogen + atoms in the output. :type explicit_hydrogen: bool - :returns: The GML string representation of the rule. :rtype: str """ @@ -171,21 +172,22 @@ def transform( attributes: List[str] = ["charge"], explicit_hydrogen: bool = False, ) -> str: - """ - Processes a triple of reaction graphs to generate a GML string rule, with options for node - reindexing and explicit hydrogen expansion. + """Processes a triple of reaction graphs to generate a GML string rule, + with options for node reindexing and explicit hydrogen expansion. :param graph_rules: Tuple containing (L, R, K) reaction graphs. :type graph_rules: tuple[nx.Graph, nx.Graph, nx.Graph] :param rule_name: The rule name to use in the output. :type rule_name: str - :param reindex: Whether to reindex node IDs based on the L graph sequence. + :param reindex: Whether to reindex node IDs based on the L graph + sequence. :type reindex: bool - :param attributes: List of attribute names to check for node changes. + :param attributes: List of attribute names to check for node + changes. :type attributes: list[str] - :param explicit_hydrogen: Whether to explicitly include hydrogen atoms in the output. + :param explicit_hydrogen: Whether to explicitly include hydrogen + atoms in the output. :type explicit_hydrogen: bool - :returns: The GML string representing the chemical rule. :rtype: str """ diff --git a/synkit/IO/smiles_to_id.py b/synkit/IO/smiles_to_id.py index bd7bb71..26d6afa 100644 --- a/synkit/IO/smiles_to_id.py +++ b/synkit/IO/smiles_to_id.py @@ -6,8 +6,8 @@ def smiles_to_iupac(smiles_string: str, timeout: int = 1): - """ - Converts a SMILES string to its corresponding IUPAC name(s) using the PubChem PUG REST API. + """Converts a SMILES string to its corresponding IUPAC name(s) using the + PubChem PUG REST API. 
Parameters: - smiles_string (str): The SMILES string of the compound (e.g., "C=O" for formaldehyde). @@ -64,8 +64,7 @@ def smiles_to_iupac(smiles_string: str, timeout: int = 1): def batch_process_smiles(smiles_batch: List[str], timeout=1): - """ - Processes a batch of SMILES strings to get IUPAC names. + """Processes a batch of SMILES strings to get IUPAC names. Parameters: - smiles_batch (list): A list of SMILES strings to process. @@ -80,8 +79,8 @@ def batch_process_smiles(smiles_batch: List[str], timeout=1): def get_iupac_for_smiles_list( smiles_list: List[str], batch_size=10, n_jobs=4, timeout=1 ): - """ - Convert a list of SMILES strings to their corresponding IUPAC names using the PubChem API with batch processing. + """Convert a list of SMILES strings to their corresponding IUPAC names + using the PubChem API with batch processing. Parameters: smiles_list (list): A list of SMILES strings to be converted to IUPAC names. diff --git a/synkit/Rule/Apply/reactor_rule.py b/synkit/Rule/Apply/reactor_rule.py index ff49a4b..74b60c5 100644 --- a/synkit/Rule/Apply/reactor_rule.py +++ b/synkit/Rule/Apply/reactor_rule.py @@ -3,7 +3,7 @@ from synkit.IO.chem_converter import gml_to_smart from synkit.Chem.Reaction.standardize import Standardize -from synkit.Chem.Reaction.rsmi_utils import reverse_reaction +from synkit.Chem.utils import reverse_reaction from synkit.Graph.ITS.normalize_aam import NormalizeAAM from synkit.Graph.ITS.its_expand import ITSExpand @@ -27,16 +27,15 @@ class ReactorRule: - """ - Handles the transformation of SMILES strings to reaction SMILES (RSMI) by applying - chemical reaction rules defined in GML strings. It can optionally reverse the reaction, - exclude atom mappings, and include unchanged reagents in the output. + """Handles the transformation of SMILES strings to reaction SMILES (RSMI) + by applying chemical reaction rules defined in GML strings. 
+ + It can optionally reverse the reaction, exclude atom mappings, and + include unchanged reagents in the output. """ def __init__(self) -> None: - """ - Initializes the ReactorRule object. - """ + """Initializes the ReactorRule object.""" pass def _process( @@ -47,9 +46,9 @@ def _process( exclude_aam: bool = False, include_reagents: bool = False, ) -> List[str]: - """ - Processes a reaction SMILES (RSMI) to adjust atom mappings, extract reaction centers, - decompose into separate reactant and product graphs, and generate the corresponding SMILES. + """Processes a reaction SMILES (RSMI) to adjust atom mappings, extract + reaction centers, decompose into separate reactant and product graphs, + and generate the corresponding SMILES. Parameters: - smiles (str): The SMILES string of the molecule to be transformed. diff --git a/synkit/Rule/Apply/retro_reactor.py b/synkit/Rule/Apply/retro_reactor.py index caba446..9a29cfc 100644 --- a/synkit/Rule/Apply/retro_reactor.py +++ b/synkit/Rule/Apply/retro_reactor.py @@ -23,8 +23,7 @@ class RetroReactor: def __init__(self) -> None: - """ - Initialize the RuleFrag class with caches and null initial values. + """Initialize the RuleFrag class with caches and null initial values. Attributes: - backward_cache: A dictionary cache (keyed by (smiles, rule)) to avoid redundant computations. @@ -32,9 +31,9 @@ def __init__(self) -> None: self.backward_cache: Dict[Tuple[str, str], List[str]] = {} def _apply_backward(self, smiles: str, rule: str) -> List[str]: - """ - Apply a transformation rule in backward mode to a SMILES string, returning possible precursors. - Uses caching to avoid redundant computations. + """Apply a transformation rule in backward mode to a SMILES string, + returning possible precursors. Uses caching to avoid redundant + computations. Parameters: - smiles (str): SMILES string to transform. 
@@ -71,10 +70,9 @@ def _apply_backward(self, smiles: str, rule: str) -> List[str]: return self.backward_cache[cache_key] def _heuristic(self, current_smiles: str, precursor_smiles: str) -> int: - """ - Heuristic function for A* search. Here, we define the "distance" as the - absolute difference in the carbon count between the current SMILES and - the known precursor SMILES. + """Heuristic function for A* search. Here, we define the "distance" as + the absolute difference in the carbon count between the current SMILES + and the known precursor SMILES. Parameters: - current_smiles (str): The SMILES of the node being expanded. @@ -93,8 +91,8 @@ def backward_synthesis_search( max_solutions: int = 1, fast_process: bool = True, ) -> List[Dict[str, List]]: - """ - Perform a backward synthesis search from a product to a known precursor using A* search. + """Perform a backward synthesis search from a product to a known + precursor using A* search. Constrains any intermediate X to satisfy: n_C(known_precursor_smiles) <= n_C(X) <= n_C(product_smiles). diff --git a/synkit/Rule/Apply/rule_matcher.py b/synkit/Rule/Apply/rule_matcher.py index a63fbe5..4e9b88f 100644 --- a/synkit/Rule/Apply/rule_matcher.py +++ b/synkit/Rule/Apply/rule_matcher.py @@ -29,8 +29,7 @@ class RuleMatcher: - """ - Match a reaction SMILES against a transformation‑rule graph and extract + """Match a reaction SMILES against a transformation‑rule graph and extract the SMARTS pattern that reproduces the reaction. On initialization, the matcher standardizes the RSMI, builds reactant/product @@ -56,15 +55,15 @@ class RuleMatcher: """ def __init__(self, rsmi: str, rule: nx.Graph) -> None: - """ - Initialize the matcher by standardizing the RSMI, building graphs, + """Initialize the matcher by standardizing the RSMI, building graphs, checking balance, and computing the match. :param rsmi: Reaction SMILES in 'reactant>>product' format. :type rsmi: str :param rule: Transformation‑rule graph. 
:type rule: nx.Graph - :raises ValueError: If no SMARTS reproduces the RSMI under the given rule. + :raises ValueError: If no SMARTS reproduces the RSMI under the + given rule. """ self.std = Standardize() self.rsmi = self.std.fit(rsmi) @@ -85,8 +84,7 @@ def __init__(self, rsmi: str, rule: nx.Graph) -> None: self.result = match def get_result(self) -> Tuple[str, nx.Graph]: - """ - Return the SMARTS and rule graph found during initialization. + """Return the SMARTS and rule graph found during initialization. :returns: A tuple (smarts, rule_graph). :rtype: tuple[str, nx.Graph] @@ -94,10 +92,10 @@ def get_result(self) -> Tuple[str, nx.Graph]: return self.result def _match_valid(self) -> Optional[Tuple[str, nx.Graph]]: - """ - Attempt a direct (balanced) match of the rule. + """Attempt a direct (balanced) match of the rule. - :returns: (smarts, rule) if direct match succeeds; otherwise None. + :returns: (smarts, rule) if direct match succeeds; otherwise + None. :rtype: Optional[tuple[str, nx.Graph]] """ reactor = SynReactor(substrate=self.r_graph, template=self.rule) @@ -107,13 +105,13 @@ def _match_valid(self) -> Optional[Tuple[str, nx.Graph]]: return None def _match_reverse(self) -> Optional[Tuple[str, nx.Graph]]: - """ - Attempt a reverse‑balance (partial) match for unbalanced reactions. + """Attempt a reverse‑balance (partial) match for unbalanced reactions. - First tries matching on product fragments, then on reactant fragments - with the template inverted. + First tries matching on product fragments, then on reactant + fragments with the template inverted. - :returns: (smarts, rule) if a partial match is found; otherwise None. + :returns: (smarts, rule) if a partial match is found; otherwise + None. 
:rtype: Optional[tuple[str, nx.Graph]] """ # Product‑side fragments @@ -138,8 +136,7 @@ def _match_reverse(self) -> Optional[Tuple[str, nx.Graph]]: @staticmethod def all_in(a: List[str], b: List[str]) -> bool: - """ - Check if every element of list `a` appears in list `b`. + """Check if every element of list `a` appears in list `b`. :param a: List of elements to test for membership. :type a: list[str] @@ -151,8 +148,7 @@ def all_in(a: List[str], b: List[str]) -> bool: return set(a).issubset(b) def help(self) -> None: - """ - Print internal state and candidate SMARTS patterns for debugging. + """Print internal state and candidate SMARTS patterns for debugging. :returns: None :rtype: NoneType @@ -165,8 +161,7 @@ def help(self) -> None: print(" ", smarts) def __str__(self) -> str: - """ - Short string showing the RSMI and balance status. + """Short string showing the RSMI and balance status. :returns: Human‑readable summary. :rtype: str @@ -175,8 +170,7 @@ def __str__(self) -> str: return f"" def __repr__(self) -> str: - """ - Detailed representation including rule size and balance. + """Detailed representation including rule size and balance. :returns: repr string. :rtype: str diff --git a/synkit/Rule/Apply/rule_rbl.py b/synkit/Rule/Apply/rule_rbl.py index f94d18d..613b3b3 100644 --- a/synkit/Rule/Apply/rule_rbl.py +++ b/synkit/Rule/Apply/rule_rbl.py @@ -1,7 +1,7 @@ import importlib.util from typing import List from synkit.Chem.Reaction.standardize import Standardize -from synkit.Chem.Reaction.rsmi_utils import ( +from synkit.Chem.utils import ( find_longest_fragment, merge_reaction, remove_common_reagents, @@ -26,8 +26,8 @@ def __init__(self) -> None: pass def rbl(self, rsmi: str, gml_rule: str, remove_aam: bool = True) -> List[str]: - """ - Applies transformation rules to a reaction SMILES string based on GML rules. + """Applies transformation rules to a reaction SMILES string based on + GML rules. Parameters: - rsmi (str): Reaction SMILES string to process. 
diff --git a/synkit/Rule/Compose/compose_rule.py b/synkit/Rule/Compose/compose_rule.py index 57b0ee0..cbcf82d 100644 --- a/synkit/Rule/Compose/compose_rule.py +++ b/synkit/Rule/Compose/compose_rule.py @@ -4,8 +4,8 @@ from synkit.IO.chem_converter import gml_to_smart, smart_to_gml from synkit.Rule.Modify.rule_utils import _increment_gml_ids from synkit.Chem.Reaction.standardize import Standardize -from synkit.Chem.Reaction.cleanning import Cleanning -from synkit.Chem.Reaction.rsmi_utils import find_longest_fragment +from synkit.Chem.Reaction.cleaning import Cleaning +from synkit.Chem.utils import find_longest_fragment logger = setup_logging() @@ -23,8 +23,7 @@ class ComposeRule: @staticmethod def filter_smallest_vertex(combo: List[object]) -> List[object]: - """ - Filters and returns the elements from a list that have the smallest + """Filters and returns the elements from a list that have the smallest number of vertices in their context. Parameters: @@ -50,9 +49,8 @@ def filter_smallest_vertex(combo: List[object]) -> List[object]: @staticmethod def rule_cluster(graphs: List[Any]) -> List[Any]: - """ - Cluster graphs based on their isomorphic relationships and - return a representative from each cluster. + """Cluster graphs based on their isomorphic relationships and return a + representative from each cluster. Parameters: - graphs (List[Any]): A list of graph objects. @@ -84,8 +82,8 @@ def rule_cluster(graphs: List[Any]) -> List[Any]: def _compose_mapping( rule_1: str, rule_2: str, mapping: Dict[int, int], return_string: bool = True ) -> Any: - """ - Compose two rule graphs from their GML representations using a mapping between external IDs. + """Compose two rule graphs from their GML representations using a + mapping between external IDs. Parameters: - rule_1 (str): The GML representation for the first rule. 
@@ -118,8 +116,8 @@ def _compose_mapping( @staticmethod def _compose(rule_1: str, rule_2: str, return_string: bool = True) -> List[Any]: - """ - Compose two rules and return a list of modifications that pass chemical valence checks. + """Compose two rules and return a list of modifications that pass + chemical valence checks. Parameters: - rule_1 (str): The first rule (in GML format) to compose. @@ -145,8 +143,8 @@ def _compose(rule_1: str, rule_2: str, return_string: bool = True) -> List[Any]: @staticmethod def _get_valid_rule(rules: List[str], format: str = "gml") -> List[str]: - """ - Validate and convert a list of rule GML strings to either SMARTS or GML format. + """Validate and convert a list of rule GML strings to either SMARTS or + GML format. Parameters: - rules (List[str]): A list of rule GML strings. @@ -171,8 +169,8 @@ def _get_valid_rule(rules: List[str], format: str = "gml") -> List[str]: @staticmethod def _get_comp_reaction(smart_1: str, smart_2: str) -> str: - """ - Compute a representative reaction SMILES for the composed rule from two SMARTS strings. + """Compute a representative reaction SMILES for the composed rule from + two SMARTS strings. Parameters: - smart_1 (str): The first reaction in SMARTS notation. @@ -190,8 +188,8 @@ def _get_comp_reaction(smart_1: str, smart_2: str) -> str: return new_rsmi def get_rule_comp(self, smart_1: str, smart_2: str) -> Optional[str]: - """ - Compose two reaction SMARTS strings into a rule (GML format) that reproduces a reference reaction. + """Compose two reaction SMARTS strings into a rule (GML format) that + reproduces a reference reaction. Parameters: - smart_1 (str): The first reaction in SMARTS notation. 
@@ -215,7 +213,7 @@ def get_rule_comp(self, smart_1: str, smart_2: str) -> Optional[str]: for candidate in candidate_rules: reactor = MODReactor(initial_smiles, candidate).run() inferred_rsmi = reactor.get_reaction_smiles() - inferred_rsmi = Cleanning.clean_smiles(inferred_rsmi) + inferred_rsmi = Cleaning.clean_smiles(inferred_rsmi) inferred_prod = [i.split(">>")[1].split(".") for i in inferred_rsmi] if any(largest_prod in smi for smi in inferred_prod): cds.append(candidate) diff --git a/synkit/Rule/Compose/rule_compose.py b/synkit/Rule/Compose/rule_compose.py index 2a401f2..9d9936b 100644 --- a/synkit/Rule/Compose/rule_compose.py +++ b/synkit/Rule/Compose/rule_compose.py @@ -25,8 +25,7 @@ def __init__(self) -> None: @staticmethod def filter_smallest_vertex(combo: List[object]) -> List[object]: - """ - Filters and returns the elements from a list that have the smallest + """Filters and returns the elements from a list that have the smallest number of vertices in their context. Parameters: @@ -52,9 +51,8 @@ def filter_smallest_vertex(combo: List[object]) -> List[object]: @staticmethod def rule_cluster(graphs: List) -> List: - """ - Clusters graphs based on their isomorphic relationship and returns - a list of graphs, each from a different cluster. + """Clusters graphs based on their isomorphic relationship and returns a + list of graphs, each from a different cluster. Parameters: - graphs: A list of graph objects. @@ -89,8 +87,8 @@ def rule_cluster(graphs: List) -> List: @staticmethod def _compose(rule_1, rule_2): - """ - Compose two rules and filter the results based on chemical valence constraints. + """Compose two rules and filter the results based on chemical valence + constraints. Parameters: - rule_1: First rule object to compose. @@ -115,8 +113,7 @@ def _compose(rule_1, rule_2): @staticmethod def _process_compose(rule_1_id, rule_2_id, rule_path, rule_path_compose): - """ - Process and compose two rules based on their GML files. 
+ """Process and compose two rules based on their GML files. Parameters: - rule_1_id (str): Identifier for the first rule. @@ -144,8 +141,8 @@ def _process_compose(rule_1_id, rule_2_id, rule_path, rule_path_compose): @staticmethod def _auto_compose(rule_path, rule_path_compose): - """ - Automatically find all GML files in the given directory and compose them pairwise. + """Automatically find all GML files in the given directory and compose + them pairwise. Parameters: - rule_path (str): Directory path where the GML files are stored. @@ -182,11 +179,10 @@ def _auto_compose(rule_path, rule_path_compose): def save_gml_from_text( gml_content: str, gml_file_path: str, rule_id: str, parent_ids: List[str] ) -> bool: - """ - Save a text string to a GML file by modifying the 'ruleID' line to include parent - rule names. This function parses the given GML content, identifies any lines - starting with 'ruleID', and replaces these lines with a new ruleID that - incorporates identifiers from parent rules. + """Save a text string to a GML file by modifying the 'ruleID' line to + include parent rule names. This function parses the given GML content, + identifies any lines starting with 'ruleID', and replaces these lines + with a new ruleID that incorporates identifiers from parent rules. Parameters: - gml_content (str): The content to be saved to the GML file. This should be the diff --git a/synkit/Rule/Compose/rule_mapping.py b/synkit/Rule/Compose/rule_mapping.py index 6613c6c..8c031b1 100644 --- a/synkit/Rule/Compose/rule_mapping.py +++ b/synkit/Rule/Compose/rule_mapping.py @@ -11,8 +11,9 @@ class RuleMapping: def enumerate_all_unique_mappings( child: nx.Graph, parent: nx.Graph ) -> List[Dict[Any, Any]]: - """ - Generate all unique mappings (as dictionaries) from the child graph to the parent graph. + """Generate all unique mappings (as dictionaries) from the child graph + to the parent graph. 
+ A mapping is valid if: - Every node from the child graph is assigned exactly one parent node. - The parent's node has the same 'element' attribute as the child node. @@ -66,9 +67,9 @@ def backtrack( def standardize_order( order_tuple: Tuple[float, ...], ) -> Optional[Tuple[float, ...]]: - """ - Standardizes an order tuple by adding 1 to every element repeatedly until no element is negative. - If the resulting tuple becomes all zeros, returns None, which indicates that the edge should be dropped. + """Standardizes an order tuple by adding 1 to every element repeatedly + until no element is negative. If the resulting tuple becomes all zeros, + returns None, which indicates that the edge should be dropped. For example: (-1.0, 0.0) --> add 1 gives (0.0, 1.0) @@ -90,8 +91,8 @@ def standardize_order( @staticmethod def keep_largest_component(graph: nx.Graph) -> nx.Graph: - """ - Given an undirected graph, returns the subgraph corresponding to the largest connected component. + """Given an undirected graph, returns the subgraph corresponding to the + largest connected component. Parameters: - graph (nx.Graph): The input graph from which the largest component is extracted. @@ -211,9 +212,9 @@ def graph_alignment( node_label_default: List[str] = ["*"], edge_attribute: str = "standard_order", ) -> Tuple[bool, Optional[Dict[Any, Any]]]: - """ - Check whether the child and parent graphs are isomorphic using specified node and edge match criteria. - If they are isomorphic, return the mapping from child to parent. + """Check whether the child and parent graphs are isomorphic using + specified node and edge match criteria. If they are isomorphic, return + the mapping from child to parent. Parameters: - child (nx.Graph): The child graph to align. 
@@ -244,8 +245,8 @@ def get_child1_to_child2_mapping( mapping_child1_to_parent: Dict[Any, Any], mapping_child2_to_parent: Dict[Any, Any], ) -> Dict[Any, Optional[Any]]: - """ - Build a mapping from Child1 to Child2 using each child's mapping to a common Parent. + """Build a mapping from Child1 to Child2 using each child's mapping to + a common Parent. If a Parent node in Child1's mapping is not in Child2's inverted mapping, that Child1 node will map to None. @@ -277,8 +278,8 @@ def get_child1_to_child2_mapping( return mapping_child1_to_child2 def fit(self, rule_1: str, rule_2: str, comp_rule: str) -> Optional[Dict[Any, Any]]: - """ - Demonstrate an alignment-based composition workflow using the class methods. + """Demonstrate an alignment-based composition workflow using the class + methods. 1. Convert each GML-based rule into an internal graph (via gml_to_its). 2. Enumerate all unique mappings from rule_2 to comp_rule. diff --git a/synkit/Rule/Compose/seq_comp.py b/synkit/Rule/Compose/seq_comp.py index 5eec9a5..d28363c 100644 --- a/synkit/Rule/Compose/seq_comp.py +++ b/synkit/Rule/Compose/seq_comp.py @@ -5,24 +5,23 @@ class SeqComp: - """ - A class for generating pairwise mappings between sequential chemical reaction rules. + """A class for generating pairwise mappings between sequential chemical + reaction rules. - This class takes a list of reaction SMARTS strings, converts them to their corresponding - GML representations, composes candidate reaction rules for each consecutive pair, and computes - a mapping between the rules using a rule mapping algorithm. + This class takes a list of reaction SMARTS strings, converts them to + their corresponding GML representations, composes candidate reaction + rules for each consecutive pair, and computes a mapping between the + rules using a rule mapping algorithm. """ def __init__(self) -> None: - """ - Initialize an instance of the SeqComp class. 
- """ + """Initialize an instance of the SeqComp class.""" pass @staticmethod def sequence_map(smarts: List[str]) -> Dict[str, Optional[dict]]: - """ - Generate pairwise mapping dictionaries between consecutive reaction SMARTS strings. + """Generate pairwise mapping dictionaries between consecutive reaction + SMARTS strings. This function processes a list of reaction SMARTS strings by: 1. Converting each SMARTS string to its GML representation. diff --git a/synkit/Rule/Compose/valence_constrain.py b/synkit/Rule/Compose/valence_constrain.py index 84a5f93..543430d 100644 --- a/synkit/Rule/Compose/valence_constrain.py +++ b/synkit/Rule/Compose/valence_constrain.py @@ -17,9 +17,8 @@ class ValenceConstrain: def __init__(self): - """ - Initialize the ValenceConstrain class by setting up bond type orders and loading - the maximum valence data. + """Initialize the ValenceConstrain class by setting up bond type orders + and loading the maximum valence data. Parameters: - None @@ -39,8 +38,7 @@ def __init__(self): self.maxValence = load_database(maxValence_path)[0] def valence(self, vertex) -> int: - """ - Calculate the valence of a vertex based on its incident edges. + """Calculate the valence of a vertex based on its incident edges. Parameters: - vertex (Vertex): The vertex for which to calculate the valence. @@ -51,8 +49,7 @@ def valence(self, vertex) -> int: return sum(self.btToOrder[edge.bondType] for edge in vertex.incidentEdges) def check_rule(self, rule, verbose: bool = False, log_error: bool = False) -> bool: - """ - Check if the rule is chemically valid according to valence rules. + """Check if the rule is chemically valid according to valence rules. Parameters: - rule (Rule): The rule to check for chemical validity. @@ -92,8 +89,7 @@ def check_rule(self, rule, verbose: bool = False, log_error: bool = False) -> bo return False def split(self, rules: List) -> Tuple[List, List]: - """ - Split rules into 'good' and 'bad' based on their chemical validity. 
+ """Split rules into 'good' and 'bad' based on their chemical validity. Parameters: - rules (List[Rule]): A list of rules to be checked and split. diff --git a/synkit/Rule/Modify/implict_rule.py b/synkit/Rule/Modify/implict_rule.py index 44e83af..0ef9cca 100644 --- a/synkit/Rule/Modify/implict_rule.py +++ b/synkit/Rule/Modify/implict_rule.py @@ -9,9 +9,8 @@ def implicit_rule( rsmi: Union[str, List[str]], disconnected: bool = True, balance_its: bool = False ) -> Union[Any, List[Any]]: - """ - Construct reaction-center objects from reaction SMILES by applying implicit‐H rules - and ITS graph construction. + """Construct reaction-center objects from reaction SMILES by applying + implicit‐H rules and ITS graph construction. Parameters ---------- diff --git a/synkit/Rule/Modify/longest_path.py b/synkit/Rule/Modify/longest_path.py index e202ba0..af9b617 100644 --- a/synkit/Rule/Modify/longest_path.py +++ b/synkit/Rule/Modify/longest_path.py @@ -5,8 +5,7 @@ class LongestPath: def __init__(self, G: nx.Graph): - """ - Initializes the LongestPath object with a graph. + """Initializes the LongestPath object with a graph. Parameters: - G (nx.Graph): The networkx graph. @@ -15,9 +14,8 @@ def __init__(self, G: nx.Graph): self.vertices = len(G.nodes) def BFS(self, u: int) -> Tuple[int, int]: - """ - Performs a Breadth-First Search (BFS) from a given node `u` to - find the farthest node and its distance. + """Performs a Breadth-First Search (BFS) from a given node `u` to find + the farthest node and its distance. Parameters: - u (int): The starting node for the BFS. @@ -55,9 +53,8 @@ def BFS(self, u: int) -> Tuple[int, int]: return nodeIdx, maxDis def LongestPathInDisconnectedGraph(self) -> int: - """ - Finds the longest path in a potentially disconnected graph. - The graph can consist of multiple components. + """Finds the longest path in a potentially disconnected graph. The + graph can consist of multiple components. 
This method performs a BFS on every unvisited component to find the farthest node and computes the longest path across all components. diff --git a/synkit/Rule/Modify/molecule_rule.py b/synkit/Rule/Modify/molecule_rule.py index c6bd471..9d1027a 100644 --- a/synkit/Rule/Modify/molecule_rule.py +++ b/synkit/Rule/Modify/molecule_rule.py @@ -6,20 +6,17 @@ class MoleculeRule: - """ - A class for generating molecule rules, atom-mapped SMILES, and GML representations from SMILES strings. - """ + """A class for generating molecule rules, atom-mapped SMILES, and GML + representations from SMILES strings.""" def __init__(self) -> None: - """ - Initializes the MoleculeRule object. - """ + """Initializes the MoleculeRule object.""" pass @staticmethod def remove_edges_from_left_right(input_str: str) -> str: - """ - Remove all contents from the 'left' and 'right' sections of a chemical rule description. + """Remove all contents from the 'left' and 'right' sections of a + chemical rule description. Parameters: - input_str (str): The string representation of the rule. @@ -45,8 +42,8 @@ def remove_edges_from_left_right(input_str: str) -> str: @staticmethod def generate_atom_map(smiles: str) -> Optional[str]: - """ - Generate atom-mapped SMILES by assigning unique map numbers to each atom in the molecule. + """Generate atom-mapped SMILES by assigning unique map numbers to each + atom in the molecule. Parameters: - smiles (str): The SMILES string representing the molecule. @@ -67,8 +64,7 @@ def generate_atom_map(smiles: str) -> Optional[str]: @staticmethod def generate_molecule_smart(smiles: str) -> Optional[str]: - """ - Generate a SMARTS-like string from atom-mapped SMILES. + """Generate a SMARTS-like string from atom-mapped SMILES. Parameters: - smiles (str): The SMILES string representing the molecule. 
@@ -90,8 +86,7 @@ def generate_molecule_rule( explicit_hydrogen: bool = True, sanitize: bool = True, ) -> Optional[str]: - """ - Generate a GML representation of the molecule rule from SMILES. + """Generate a GML representation of the molecule rule from SMILES. Parameters: - smiles (str): The SMILES string representing the molecule. diff --git a/synkit/Rule/Modify/prune_templates.py b/synkit/Rule/Modify/prune_templates.py index 7ab6f97..f9ea216 100644 --- a/synkit/Rule/Modify/prune_templates.py +++ b/synkit/Rule/Modify/prune_templates.py @@ -6,8 +6,8 @@ class PruneTemplate: def __init__(self, templates: List[List[Dict[str, Any]]], graph_key: str) -> None: - """ - Initialize the PruneTemplate object with the provided templates and graph key. + """Initialize the PruneTemplate object with the provided templates and + graph key. Parameters: - templates (List[List[Dict[str, Any]]]): A list of lists containing dictionaries @@ -22,8 +22,7 @@ def __init__(self, templates: List[List[Dict[str, Any]]], graph_key: str) -> Non def remove_edges_by_attribute( input_graph: nx.Graph, attribute: str = "standard_order", value: Any = 0 ) -> nx.Graph: - """ - Remove edges from the input graph where a given attribute equals a + """Remove edges from the input graph where a given attribute equals a specified value. Parameters: @@ -49,9 +48,8 @@ def remove_edges_by_attribute( return graph def fit(self) -> List[List[Dict[str, Any]]]: - """ - Prune the templates by removing subgraphs where the longest path is shorter - than the radius. + """Prune the templates by removing subgraphs where the longest path is + shorter than the radius. Returns: List[List[Dict[str, Any]]]: The pruned list of templates. 
diff --git a/synkit/Rule/Modify/rule_utils.py b/synkit/Rule/Modify/rule_utils.py index 0e618f6..30e4e6e 100644 --- a/synkit/Rule/Modify/rule_utils.py +++ b/synkit/Rule/Modify/rule_utils.py @@ -6,9 +6,10 @@ def find_block(lines, keyword): - """ - Finds the start and end indices of a block (e.g., "left [", "context [", etc.) - in the given lines of GML. Returns (start_idx, end_idx) or (None, None) if not found. + """Finds the start and end indices of a block (e.g., "left [", "context [", + etc.) in the given lines of GML. + + Returns (start_idx, end_idx) or (None, None) if not found. """ start_idx = None depth = 0 @@ -29,8 +30,8 @@ def find_block(lines, keyword): def get_nodes_from_edges(block_lines): - """ - Extract node IDs from edges in the given block lines. + """Extract node IDs from edges in the given block lines. + Returns a set of node IDs found in the edges. """ node_set = set() @@ -43,8 +44,8 @@ def get_nodes_from_edges(block_lines): def parse_context(context_lines, node_regex=None, edge_regex=None): - """ - Parse the context lines to identify nodes and edges. + """Parse the context lines to identify nodes and edges. + Returns two structures: - context_nodes: {node_id: label} - context_edges: list of (source, target, label) @@ -67,9 +68,10 @@ def parse_context(context_lines, node_regex=None, edge_regex=None): def filter_context(context_lines, relevant_nodes): - """ - Given the context lines and a set of relevant nodes, remove hydrogen nodes - not in relevant_nodes and all edges connected to them. Returns filtered lines. + """Given the context lines and a set of relevant nodes, remove hydrogen + nodes not in relevant_nodes and all edges connected to them. + + Returns filtered lines. 
""" context_nodes, context_edges = parse_context(context_lines) @@ -105,11 +107,11 @@ def filter_context(context_lines, relevant_nodes): def strip_context(gml_text: str, remove_all: bool = True) -> str: - """ - Filters or clears the 'context' section of GML-like content based on the remove_all flag. - If remove_all is True, all edges in the 'context' section are removed. - If False, it removes hydrogen nodes that do not appear in both 'left' and 'right' sections, - along with their edges, while preserving the original structure and formatting of the GML. + """Filters or clears the 'context' section of GML-like content based on the + remove_all flag. If remove_all is True, all edges in the 'context' section + are removed. If False, it removes hydrogen nodes that do not appear in both + 'left' and 'right' sections, along with their edges, while preserving the + original structure and formatting of the GML. Parameters: - gml_text (str): GML-like content describing a chemical reaction rule. @@ -173,8 +175,8 @@ def strip_context(gml_text: str, remove_all: bool = True) -> str: def _increment_gml_ids(gml_content: str) -> str: - """ - Increment the numerical IDs within a GML content string if node id 0 exists. + """Increment the numerical IDs within a GML content string if node id 0 + exists. Parameters: - gml_content (str): The GML content as a string. diff --git a/synkit/Rule/Modify/strip_rule.py b/synkit/Rule/Modify/strip_rule.py index 2dfbaec..c236d7a 100644 --- a/synkit/Rule/Modify/strip_rule.py +++ b/synkit/Rule/Modify/strip_rule.py @@ -2,9 +2,10 @@ def filter_context(context_lines, left_edges): - """ - Given the context lines and a set of edges from the left graph, remove edges - from the context that are also present in the left graph (ignoring labels). + """Given the context lines and a set of edges from the left graph, remove + edges from the context that are also present in the left graph (ignoring + labels). + Returns filtered lines. 
""" # Create a set of edges from the left graph (ignoring labels) @@ -32,10 +33,12 @@ def filter_context(context_lines, left_edges): def strip_context(gml_text: str, remove_all: bool = False) -> str: - """ - Filters or clears the 'context' section of GML-like content based on the remove_all flag. - If remove_all is True, all edges in the 'context' section are removed. - If False, it removes edges in the 'context' that are also present in the 'left' section. + """Filters or clears the 'context' section of GML-like content based on the + remove_all flag. + + If remove_all is True, all edges in the 'context' section are + removed. If False, it removes edges in the 'context' that are also + present in the 'left' section. """ lines = gml_text.split("\n") diff --git a/synkit/Rule/syn_rule.py b/synkit/Rule/syn_rule.py index 6ec5625..5408625 100644 --- a/synkit/Rule/syn_rule.py +++ b/synkit/Rule/syn_rule.py @@ -160,10 +160,10 @@ def _strip_explicit_h( left: nx.Graph, right: nx.Graph, ) -> None: - """ - Remove explicit hydrogens from rc, left, right—but only when *both* - left & right agree the H should be implicit. Otherwise an H remains - explicit in all three graphs. + """Remove explicit hydrogens from rc, left, right—but only when *both* + left & right agree the H should be implicit. + + Otherwise an H remains explicit in all three graphs. 
""" def _removable_on(graph: nx.Graph, h: str) -> bool: diff --git a/synkit/Synthesis/CRN/crn.py b/synkit/Synthesis/CRN/crn.py index 0506533..646e63f 100644 --- a/synkit/Synthesis/CRN/crn.py +++ b/synkit/Synthesis/CRN/crn.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import Any, Dict, List, Sequence, Union -from synkit.Chem.Reaction.cleanning import Cleanning +from synkit.Chem.Reaction.cleaning import Cleaning from synkit.Chem.utils import ( count_carbons, get_max_fragment, @@ -18,8 +18,7 @@ class CRN: - """ - Expand an initial pool of molecules through several rounds of rule + """Expand an initial pool of molecules through several rounds of rule application using **MODReactor** under the hood. Public attributes @@ -91,8 +90,7 @@ def rule_count(self) -> int: @property def product_sets(self) -> Dict[str, List[str]]: - """ - Dict view of the per‑round reaction SMILES. + """Dict view of the per‑round reaction SMILES. Handles both shapes: @@ -136,7 +134,8 @@ def __repr__(self) -> str: # ============================================================ internals def _expand_once(self, smiles: List[str]) -> List[str]: - """Apply every rule once to the molecule pool and return reaction RSMI.""" + """Apply every rule once to the molecule pool and return reaction + RSMI.""" rxn_results: List[str] = [] smiles_for_mod = process_smiles_list(smiles) @@ -150,7 +149,7 @@ def _expand_once(self, smiles: List[str]) -> List[str]: ) reactor.run() rsmi = reactor.get_reaction_smiles() - rsmi = Cleanning().clean_smiles(rsmi) + rsmi = Cleaning().clean_smiles(rsmi) rsmi = [_remove_reagent(r) for r in rsmi] rxn_results.extend(rsmi) @@ -164,7 +163,8 @@ def _update_smiles_pool( starting: str, target: str, ) -> List[str]: - """Merge products from *reactions* into *current* with optional pruning.""" + """Merge products from *reactions* into *current* with optional + pruning.""" new: List[str] = [] for rsmi in reactions: diff --git a/synkit/Synthesis/CRN/dcrn.py 
b/synkit/Synthesis/CRN/dcrn.py index e036206..6613ad2 100644 --- a/synkit/Synthesis/CRN/dcrn.py +++ b/synkit/Synthesis/CRN/dcrn.py @@ -1,14 +1,14 @@ from collections import defaultdict import heapq from typing import List, Dict, Any -from synkit.Chem.Reaction.cleanning import Cleanning + from synkit.Chem.utils import ( count_carbons, process_smiles_list, get_max_fragment, ) from synkit.Synthesis.reactor_utils import _remove_reagent -from synkit.Synthesis.core_engine import CoreEngine +from synkit.Synthesis.Reactor.mod_reactor import MODReactor class DCRN: @@ -31,14 +31,13 @@ def __init__( @staticmethod def _get_valid_node(molecules, lower, upper): - """ - Filters molecules by their carbon count within the given range. - """ + """Filters molecules by their carbon count within the given range.""" return [mol for mol in molecules if lower <= count_carbons(mol) <= upper] def _expand(self, smiles_list: List[str]) -> List[str]: - """ - Expands molecules based on transformation rules. Uses caching to avoid redundant computation. + """Expands molecules based on transformation rules. + + Uses caching to avoid redundant computation. """ smiles_tuple = tuple(smiles_list) if smiles_tuple in self.expansion_cache: @@ -47,8 +46,8 @@ def _expand(self, smiles_list: List[str]) -> List[str]: results = [] processed_smiles = process_smiles_list(smiles_list) for rule_dict in self.rule_list: - expansions = CoreEngine()._inference(rule_dict["gml"], processed_smiles) - expansions = Cleanning().clean_smiles(expansions) + expansions = MODReactor()._inference(rule_dict["gml"], processed_smiles) + expansions = MODReactor().clean_smiles(expansions) expansions = [_remove_reagent(e) for e in expansions] for r in expansions: product = r.split(">>")[1] @@ -63,15 +62,13 @@ def _expand(self, smiles_list: List[str]) -> List[str]: return valid_nodes def _heuristic(self, a: str, b: str) -> int: - """ - Returns the heuristic estimate (absolute difference in carbon count) between two compounds. 
- """ + """Returns the heuristic estimate (absolute difference in carbon count) + between two compounds.""" return abs(count_carbons(a) - count_carbons(b)) def _dynamic_expand_node(self, node: str, smiles_list: list) -> None: - """ - Dynamically expands the given node to generate new possible compounds. - """ + """Dynamically expands the given node to generate new possible + compounds.""" if node not in self.visited: self.visited.add(node) expanded_nodes = self._expand( @@ -86,8 +83,9 @@ def build_and_search( max_solutions: int = 5, fast_process: bool = True, ) -> Dict[str, Any]: - """ - Builds the search graph and searches for paths from the starting compound to the target compound. + """Builds the search graph and searches for paths from the starting + compound to the target compound. + Ensures depth levels follow a sequential order starting from 0. """ # Initialize the heap with the starting compound at depth 0 diff --git a/synkit/Synthesis/CRN/mod_crn.py b/synkit/Synthesis/CRN/mod_crn.py index 51ad4d0..cc941dd 100644 --- a/synkit/Synthesis/CRN/mod_crn.py +++ b/synkit/Synthesis/CRN/mod_crn.py @@ -22,9 +22,7 @@ class MODCRN: - """ - MODCRN - ====== + """MODCRN ====== High-level class for constructing, inspecting, and reporting a chemical reaction network using the MØD derivation graph (DG) API. @@ -136,28 +134,21 @@ def build(self) -> None: builder.execute(strat) def print_summary(self) -> None: - """ - Print and save a concise summary of the derivation graph. - """ + """Print and save a concise summary of the derivation graph.""" out_dir = "out" os.makedirs(out_dir, exist_ok=True) self._dg.print() def export_report(self) -> None: - """ - Generate an external report via the `mod_post` CLI. 
- - """ + """Generate an external report via the `mod_post` CLI.""" try: subprocess.run(["mod_post"], check=True) except subprocess.CalledProcessError as e: logger.error(f"mod_post failed with exit code {e.returncode}") def help(self) -> None: - """ - Print usage examples and API summary for MODCRN. - """ + """Print usage examples and API summary for MODCRN.""" print( "MODCRN Usage:\n" " crn = MODCRN(rule_db_path, initial_smiles, repeats)\n" diff --git a/synkit/Synthesis/MSR/multi_steps.py b/synkit/Synthesis/MSR/multi_steps.py index 1c0ad98..2c144d9 100644 --- a/synkit/Synthesis/MSR/multi_steps.py +++ b/synkit/Synthesis/MSR/multi_steps.py @@ -9,17 +9,15 @@ class MultiSteps: def __init__(self) -> None: - """ - Initialize the MultiStep class with a Standardize instance. - """ + """Initialize the MultiStep class with a Standardize instance.""" self.std = Standardize() @staticmethod def _process( gml_list: List[str], order: List[int], rsmi: str, exclude_aam: bool = True ) -> Tuple[List[List[str]], Dict[str, List[str]]]: - """ - Process a series of chemical reactions according to given rules and order. + """Process a series of chemical reactions according to given rules and + order. Parameters: - gml_list (List[str]): List of GML format strings representing reaction rules. @@ -73,8 +71,8 @@ def _process( def _get_aam( rsmi_list: List[str], rule_list: List[str], order: List[int] ) -> List[str]: - """ - Apply atom-atom mapping to a series of reaction SMILES strings according to specified rules. + """Apply atom-atom mapping to a series of reaction SMILES strings + according to specified rules. Parameters: - rsmi_list (List[str]): List of reaction SMILES strings. @@ -111,8 +109,8 @@ def multi_step( order: List[int], cat: Union[str, List[str]], ) -> List[str]: - """ - Orchestrate a multi-step chemical reaction process using a set of rules and a starting reactant. + """Orchestrate a multi-step chemical reaction process using a set of + rules and a starting reactant. 
Parameters: - original_rsmi (str): Initial reactant SMILES string. diff --git a/synkit/Synthesis/MSR/path_finder.py b/synkit/Synthesis/MSR/path_finder.py index 9693d09..a610ab8 100644 --- a/synkit/Synthesis/MSR/path_finder.py +++ b/synkit/Synthesis/MSR/path_finder.py @@ -9,9 +9,9 @@ def __init__( self, reaction_rounds: List[Dict[str, List[str]]], ): - """ - Initialize with a list of dictionaries, each representing a reaction round, - plus an optional random state for reproducible Monte Carlo search. + """Initialize with a list of dictionaries, each representing a reaction + round, plus an optional random state for reproducible Monte Carlo + search. Parameters: - reaction_rounds (List[Dict[str, List[str]]]): A list where each dictionary @@ -56,9 +56,9 @@ def search_paths( max_solutions: Optional[int] = None, cheapest: bool = True, ) -> List[List[str]]: - """ - Search for reaction pathways from the input molecule to the target molecule - using a specified method, optionally limiting the number of solutions. + """Search for reaction pathways from the input molecule to the target + molecule using a specified method, optionally limiting the number of + solutions. Additionally, `cheapest` can be set to True or False: - If cheapest=True, BFS uses a visited set and A* prunes costlier routes (typical approach). @@ -92,12 +92,13 @@ def _bfs( max_solutions: Optional[int], cheapest: bool, ) -> List[List[str]]: - """ - Perform a BFS search. If cheapest=True, use a visited set to avoid re-processing - the same (molecule, round_index). If cheapest=False, skip that pruning and collect - *all* possible solutions (potentially large if cycles exist). + """Perform a BFS search. If cheapest=True, use a visited set to avoid + re-processing the same (molecule, round_index). If cheapest=False, skip + that pruning and collect *all* possible solutions (potentially large if + cycles exist). - Returns a list of successful reaction pathways, up to max_solutions if specified. 
+ Returns a list of successful reaction pathways, up to + max_solutions if specified. """ queue = deque([(input_smiles, [], 0)]) @@ -134,9 +135,10 @@ def _bfs( return pathways def _heuristic(self, smiles: str, target_smiles: str) -> int: - """ - Heuristic function for A* search. - Returns difference in SMILES lengths as a stand-in for "distance." + """Heuristic function for A* search. + + Returns difference in SMILES lengths as a stand-in for + "distance." """ return abs(len(smiles) - len(target_smiles)) @@ -147,12 +149,12 @@ def _astar( max_solutions: Optional[int], cheapest: bool, ) -> List[List[str]]: - """ - A* search. If cheapest=True, we track the best cost visited for each state - and prune costlier paths. If cheapest=False, we do not prune, so we collect - all solutions (but it may be large). + """A* search. If cheapest=True, we track the best cost visited for each + state and prune costlier paths. If cheapest=False, we do not prune, so + we collect all solutions (but it may be large). - Returns a list of successful reaction pathways, up to max_solutions if specified. + Returns a list of successful reaction pathways, up to + max_solutions if specified. """ start_cost = self._heuristic(input_smiles, target_smiles) # Heap stores (cost, current_smiles, current_path, round_index) diff --git a/synkit/Synthesis/Metrics/_base.py b/synkit/Synthesis/Metrics/_base.py index 5ed47ab..09657c5 100644 --- a/synkit/Synthesis/Metrics/_base.py +++ b/synkit/Synthesis/Metrics/_base.py @@ -15,8 +15,7 @@ def _compute_metrics( k: int = 5, beta: float = 1, ) -> Dict[str, float]: - """ - Computes the metrics for a list of reactions data. + """Computes the metrics for a list of reactions data. Parameters: - reactions_data (List[Dict[str, any]]): List of dictionaries containing RSMI strings. 
diff --git a/synkit/Synthesis/Metrics/_plot.py b/synkit/Synthesis/Metrics/_plot.py index ab84809..0751580 100644 --- a/synkit/Synthesis/Metrics/_plot.py +++ b/synkit/Synthesis/Metrics/_plot.py @@ -17,9 +17,9 @@ def plot_recognition_coverage_curve( show_f2=True, show_legend=True, ): - """ - Plots a Recognition-Coverage curve using provided data, including optional - F2 scores annotated. Styled with Seaborn for enhanced visual appearance. + """Plots a Recognition-Coverage curve using provided data, including + optional F2 scores annotated. Styled with Seaborn for enhanced visual + appearance. Parameters: - data (dict): Nested dictionary containing the data for each radii, @@ -66,9 +66,8 @@ def plot_recognition_coverage_curve( def plot_f2_scores_line(data, figsize=(8, 6), show_f2=True, show_legend=True): - """ - Plots F2 scores across different radii using a line plot, showing - the trend of F2 score changes, and annotated with optional F2 scores. + """Plots F2 scores across different radii using a line plot, showing the + trend of F2 score changes, and annotated with optional F2 scores. Parameters: - data (dict): Dictionary containing nested dictionaries with 'F2_score' diff --git a/synkit/Synthesis/Metrics/_ranking.py b/synkit/Synthesis/Metrics/_ranking.py index b11fd88..50b773c 100644 --- a/synkit/Synthesis/Metrics/_ranking.py +++ b/synkit/Synthesis/Metrics/_ranking.py @@ -5,9 +5,9 @@ def _coverage( reactions_data: List[Dict[str, str]], key_ground_truth: str, key_prediction: str ) -> float: - """ - Calculates the coverage percentage, which measures how many of the predicted reactions - exactly match the ground truth reactions given in a list of dictionaries. + """Calculates the coverage percentage, which measures how many of the + predicted reactions exactly match the ground truth reactions given in a + list of dictionaries. 
Parameters: - reactions_data (List[Dict[str, str]]): List of dictionaries containing @@ -29,11 +29,10 @@ def _coverage( def _novelty_rate( reactions_data: List[Dict[str, any]], key_ground_truth: str, key_prediction: str ) -> float: - """ - Calculates the False Positive Rate (FPR) for each observation and then averages - these values across all observations. The FPR represents the proportion of - predictions that do not match - the ground truth for each individual entry in the dataset. + """Calculates the False Positive Rate (FPR) for each observation and then + averages these values across all observations. The FPR represents the + proportion of predictions that do not match the ground truth for each + individual entry in the dataset. Parameters: - reactions_data (List[Dict[str, any]]): List of dictionaries containing @@ -68,10 +67,10 @@ def _novelty_rate( def _recognition_rate( reactions_data: List[Dict[str, any]], key_ground_truth: str, key_prediction: str ) -> float: - """ - Calculates the recognition rate for each observation and averages these rates - across all observations. The recognition rate measures the proportion of - the prediction list that matches the single ground truth reaction for each entry. + """Calculates the recognition rate for each observation and averages these + rates across all observations. The recognition rate measures the proportion + of the prediction list that matches the single ground truth reaction for + each entry. Parameters: - reactions_data (List[Dict[str, any]]): List of dictionaries containing @@ -110,9 +109,9 @@ def _top_k_accuracy( key_prediction: str, k: int, ) -> float: - """ - Calculates the Top-K accuracy by using the coverage function on the top K predictions. - This measures the probability that the true reaction is within the top K predictions. + """Calculates the Top-K accuracy by using the coverage function on the top + K predictions. 
This measures the probability that the true reaction is + within the top K predictions. Parameters: - reactions_data (List[Dict[str, any]]): List of dictionaries containing @@ -137,12 +136,12 @@ def _calculate_f_beta_score( coverage_rate: float, # This serves as the recall beta: float = 1.0, # Beta factor, default is 1.0 for F1 score ) -> float: - """ - Computes the F-beta Score, which is a weighted harmonic mean of recognition rate - and coverage rate. The recognition rate (precision) and coverage rate (recall) - must be expressed as percentages. A beta value of 1.0 means equal importance to - precision and recall (F1 Score), greater than 1.0 gives more importance to recall - (e.g., F2 Score), and less than 1.0 prioritizes precision (e.g., F0.5 Score). + """Computes the F-beta Score, which is a weighted harmonic mean of + recognition rate and coverage rate. The recognition rate (precision) and + coverage rate (recall) must be expressed as percentages. A beta value of + 1.0 means equal importance to precision and recall (F1 Score), greater than + 1.0 gives more importance to recall (e.g., F2 Score), and less than 1.0 + prioritizes precision (e.g., F0.5 Score). Parameters: - recognition_rate (float): The recognition rate of the predictions, diff --git a/synkit/Synthesis/Reactor/batch_reactor.py b/synkit/Synthesis/Reactor/batch_reactor.py index a87cd87..987d797 100644 --- a/synkit/Synthesis/Reactor/batch_reactor.py +++ b/synkit/Synthesis/Reactor/batch_reactor.py @@ -8,8 +8,7 @@ class BatchReactor: - """ - Apply a collection of pattern-graphs (rules) to a batch of substrates. + """Apply a collection of pattern-graphs (rules) to a batch of substrates. Each data entry can be: - a dict (expects substrate under `host_key`) @@ -53,25 +52,26 @@ def __init__( implicit_temp: bool = False, strategy: str = "bt", ) -> None: - """ - Initialize batch reactor configuration. - - :param data: Batch of substrates to process. 
- :type data: list - :param host_key: Key to extract graph/SMILES from dict entries. - :type host_key: str or None - :param react_engine: Which reactor engine to use ('syn' or 'mod'). - :type react_engine: str - :param filter_engine: RuleFilter engine (or None to skip filtering). - :type filter_engine: str or None - :param invert: Use inverted rule patterns if True. - :type invert: bool - :param explicit_h: Use explicit hydrogens in SynReactor. - :type explicit_h: bool - :param implicit_temp: Use implicit templates in SynReactor. - :type implicit_temp: bool - :param strategy: Matching strategy identifier. - :type strategy: str + """Initialize batch reactor configuration. + + :param data: Batch of substrates to process. + :type data: list + :param host_key: Key to extract graph/SMILES from dict entries. + :type host_key: str or None + :param react_engine: Which reactor engine to use ('syn' or + 'mod'). + :type react_engine: str + :param filter_engine: RuleFilter engine (or None to skip + filtering). + :type filter_engine: str or None + :param invert: Use inverted rule patterns if True. + :type invert: bool + :param explicit_h: Use explicit hydrogens in SynReactor. + :type explicit_h: bool + :param implicit_temp: Use implicit templates in SynReactor. + :type implicit_temp: bool + :param strategy: Matching strategy identifier. + :type strategy: str """ self._data = data self._host_key = host_key @@ -89,13 +89,13 @@ def __init__( def _get_substrate( self, entry: Union[Dict[str, Any], str, nx.Graph] ) -> Union[nx.Graph, str]: - """ - Normalize and validate an entry based on react_engine. + """Normalize and validate an entry based on react_engine. - :param entry: The substrate entry (dict, SMILES, or Graph). - :type entry: dict or str or nx.Graph - :returns: networkx.Graph (for 'syn') or SMILES string (for 'mod'). - :rtype: nx.Graph or str + :param entry: The substrate entry (dict, SMILES, or Graph). 
+ :type entry: dict or str or nx.Graph + :returns: networkx.Graph (for 'syn') or SMILES string (for + 'mod'). + :rtype: nx.Graph or str """ # extract from dict if needed if isinstance(entry, dict): @@ -122,15 +122,14 @@ def _get_substrate( def _filter_rules( self, substrate: Union[nx.Graph, str], rules_list: List[Any] ) -> List[Any]: - """ - Apply rule filtering if configured. - - :param substrate: Host graph or SMILES to filter against. - :type substrate: nx.Graph or str - :param rules_list: List of rule patterns. - :type rules_list: list - :returns: Filtered list of rules. - :rtype: list + """Apply rule filtering if configured. + + :param substrate: Host graph or SMILES to filter against. + :type substrate: nx.Graph or str + :param rules_list: List of rule patterns. + :type rules_list: list + :returns: Filtered list of rules. + :rtype: list """ if self._filter_engine and self._react_engine == "syn": rf = RuleFilter( @@ -143,13 +142,13 @@ def _filter_rules( return rules_list def fit(self, rules_list: List[Any]) -> List[List[str]]: - """ - Apply each rule to every substrate, returning product SMARTS or reaction SMILES. + """Apply each rule to every substrate, returning product SMARTS or + reaction SMILES. - :param rules_list: List of rules (pattern-graphs or objects). - :type rules_list: list - :returns: Nested list: outputs[i] for substrate i. - :rtype: List[List[str]] + :param rules_list: List of rules (pattern-graphs or objects). + :type rules_list: list + :returns: Nested list: outputs[i] for substrate i. + :rtype: List[List[str]] """ results: List[List[str]] = [] for entry in self._data: @@ -181,8 +180,7 @@ def fit(self, rules_list: List[Any]) -> List[List[str]]: @property def data(self) -> List[Union[Dict[str, Any], str, nx.Graph]]: - """ - Original batch input data. + """Original batch input data. :returns: The list of data entries. 
:rtype: list @@ -191,8 +189,7 @@ def data(self) -> List[Union[Dict[str, Any], str, nx.Graph]]: @property def filter_engine(self) -> Optional[str]: - """ - The engine used for rule filtering. + """The engine used for rule filtering. :returns: Name of filter engine or None. :rtype: str or None @@ -201,8 +198,7 @@ def filter_engine(self) -> Optional[str]: @property def react_engine(self) -> str: - """ - The engine used for reaction application. + """The engine used for reaction application. :returns: Name of react engine ('syn' or 'mod'). :rtype: str @@ -210,8 +206,7 @@ def react_engine(self) -> str: return self._react_engine def __repr__(self) -> str: - """ - Concise summary of BatchReactor configuration. + """Concise summary of BatchReactor configuration. :returns: Representation string. :rtype: str @@ -222,8 +217,7 @@ def __repr__(self) -> str: ) def __help__(self) -> str: - """ - Return class documentation for interactive help. + """Return class documentation for interactive help. :returns: The class docstring. :rtype: str diff --git a/synkit/Synthesis/Reactor/core_engine.py b/synkit/Synthesis/Reactor/core_engine.py deleted file mode 100644 index 2624e0e..0000000 --- a/synkit/Synthesis/Reactor/core_engine.py +++ /dev/null @@ -1,212 +0,0 @@ -# import warnings -# from rdkit import Chem -# from pathlib import Path -# from typing import List, Union -# from collections import Counter -# from synkit.IO.data_io import load_gml_as_text -# from synkit.Synthesis.reactor_utils import _deduplicateGraphs, _get_connected_subgraphs - -# import mod -# from mod import smiles, config, ruleGMLString, DG - - -# class CoreEngine: -# """ -# The MØDModeling class encapsulates functionalities for reaction modeling using the MØD -# toolkit. It provides methods for forward and backward prediction based on templates -# library. 
-# """ - -# def __init__(self) -> None: -# warnings.warn("deprecated", DeprecationWarning) -# pass - -# @staticmethod -# def generate_reaction_smiles( -# temp_results: List[str], base_smiles: str, is_forward: bool = True -# ) -> List[str]: -# """ -# Constructs reaction SMILES strings from intermediate results using a base SMILES -# string, indicating whether the process is a forward or backward reaction. This -# function iterates over a list of intermediate SMILES strings, combines them with -# the base SMILES, and formats them into complete reaction SMILES strings. - -# Parameters: -# - temp_results (List[str]): Intermediate SMILES strings resulting from partial -# reactions or combinations. -# - base_smiles (str): The SMILES string representing the starting point of the -# reaction, either as reactants or products, depending on the reaction direction. -# - is_forward (bool, optional): Flag to determine the direction of the reaction; -# 'True' for forward reactions where 'base_smiles' are reactants, and 'False' for -# backward reactions where 'base_smiles' are products. Defaults to True. - -# Returns: -# - List[str]: A list of complete reaction SMILES strings, formatted according to -# the specified reaction direction. -# """ -# results = [] -# for comb in temp_results: -# joined_smiles = ".".join(comb) -# reaction_smiles = ( -# f"{base_smiles}>>{joined_smiles}" -# if is_forward -# else f"{joined_smiles}>>{base_smiles}" -# ) -# results.append(reaction_smiles) -# return results - -# @staticmethod -# def _prediction_wo_reagent( -# initial_molecules: List[Union[str, object]], -# rule: mod.libpymod.Rule, -# print_results: bool = False, -# verbosity: int = 0, -# ) -> List[List[str]]: -# """ -# Applies the reaction rule to the given molecules without considering reagents. - -# Parameters: -# - initial_molecules (List[Union[str, object]]): List of initial molecules represented by SMILES or objects. -# - rule (mod.libpymod.Rule): The reaction rule to apply. 
-# - print_results (bool): Whether to print the results. -# - verbosity (int): Verbosity level for output. - -# Returns: -# - List[List[str]]: A list of intermediate SMILES strings for the reaction products. -# """ -# # Initialize the derivation graph and execute the strategy -# dg = DG(graphDatabase=initial_molecules) -# config.dg.doRuleIsomorphismDuringBinding = False -# dg.build().apply(initial_molecules, rule, verbosity=verbosity) -# if print_results: -# dg.print() - -# temp_results = [] -# for e in dg.edges: -# productSmiles = [v.graph.smiles for v in e.targets] -# temp_results.append(productSmiles) -# del dg -# return temp_results - -# @staticmethod -# def _prediction_with_reagent( -# initial_smiles: List[str], -# initial_molecules: List[Union[str, object]], -# rule: mod.libpymod.Rule, -# print_results: bool = False, -# verbosity: int = 0, -# ) -> List[List[str]]: -# """ -# Applies the reaction rule to the given molecules considering the reagents. - -# Parameters: -# - initial_smiles (List[str]): Initial molecules represented as SMILES strings. -# - initial_molecules (List[Union[str, object]]): List of initial molecules. -# - rule (mod.libpymod.Rule): The reaction rule to apply. -# - print_results (bool): Whether to print the results. -# - verbosity (int): Verbosity level for output. - -# Returns: -# - List[List[str]]: A list of intermediate SMILES strings with reagents included. 
-# """ -# dg = DG(graphDatabase=initial_molecules) -# config.dg.doRuleIsomorphismDuringBinding = False -# dg.build().apply(initial_molecules, rule, verbosity=verbosity, onlyProper=False) -# if print_results: -# dg.print() -# temp_results, small_educt = [], [] -# for edge in dg.edges: -# temp_results.append([vertex.graph.smiles for vertex in edge.targets]) -# small_educt.append([vertex.graph.smiles for vertex in edge.sources]) - -# for key, solution in enumerate(temp_results): -# educt = small_educt[key] -# small_educt_counts = Counter( -# Chem.CanonSmiles(smile) for smile in educt if smile is not None -# ) -# reagent_counts = Counter([Chem.CanonSmiles(s) for s in initial_smiles]) -# reagent_counts.subtract(small_educt_counts) -# reagent = [ -# smile -# for smile, count in reagent_counts.items() -# for _ in range(count) -# if count > 0 -# ] -# solution.extend(reagent) -# del dg -# return temp_results - -# @staticmethod -# def _inference( -# rule_file_path: Union[str, Path], -# initial_smiles: List[str], -# prediction_type: str = "forward", -# print_results: bool = False, -# verbosity: int = 0, -# ) -> List[str]: -# """ -# Applies a specified reaction rule to a set of initial molecules represented by SMILES strings. -# The reaction can be simulated in forward or backward direction. - -# Parameters: -# - rule_file_path (Union[str, Path]): Path to the GML file containing the reaction rule. -# - initial_smiles (List[str]): Initial molecules as SMILES strings. -# - prediction_type (str): Direction of the reaction ('forward' or 'backward'). -# - print_results (bool): Whether to print the results. -# - verbosity (int): Verbosity level for output. - -# Returns: -# - List[str]: SMILES strings of the resulting molecules or reactions. 
-# """ - -# # Determine the rule inversion based on reaction type -# invert_rule = prediction_type == "backward" -# # Convert SMILES strings to molecule objects, avoiding duplicate conversions -# initial_molecules = [smiles(smile, add=False) for smile in (initial_smiles)] - -# initial_molecules = _deduplicateGraphs(initial_molecules) - -# initial_molecules = sorted( -# initial_molecules, key=lambda molecule: molecule.numVertices, reverse=False -# ) -# # Load the reaction rule from the GML file -# rule_path = Path(rule_file_path) - -# try: -# if rule_path.is_file(): -# gml_content = load_gml_as_text(rule_file_path) -# else: -# gml_content = rule_file_path -# except Exception as e: -# # print(f"An error occurred while loading the GML file: {e}") -# gml_content = rule_file_path -# reaction_rule = ruleGMLString(gml_content, invert=invert_rule, add=False) - -# _number_subgraphs = _get_connected_subgraphs(gml_content, invert=invert_rule) -# if len(initial_molecules) <= _number_subgraphs: -# temp_results = CoreEngine._prediction_wo_reagent( -# initial_molecules, reaction_rule, print_results, verbosity -# ) -# else: -# temp_results = CoreEngine._prediction_with_reagent( -# initial_smiles, -# initial_molecules, -# reaction_rule, -# print_results, -# verbosity, -# ) - -# reaction_processing_map = { -# "forward": lambda smiles: CoreEngine.generate_reaction_smiles( -# temp_results, ".".join(initial_smiles), is_forward=True -# ), -# "backward": lambda smiles: CoreEngine.generate_reaction_smiles( -# temp_results, ".".join(initial_smiles), is_forward=False -# ), -# } - -# # Use the reaction type to select the appropriate processing function and apply it -# if prediction_type in reaction_processing_map: -# return reaction_processing_map[prediction_type](initial_smiles) -# else: -# return "" diff --git a/synkit/Synthesis/Reactor/mod_aam.py b/synkit/Synthesis/Reactor/mod_aam.py index d047cf0..951f3bb 100644 --- a/synkit/Synthesis/Reactor/mod_aam.py +++ 
b/synkit/Synthesis/Reactor/mod_aam.py @@ -29,7 +29,7 @@ from synkit.Graph.ITS.its_expand import ITSExpand from synkit.Graph.ITS.normalize_aam import NormalizeAAM from synkit.Chem.Reaction.standardize import Standardize -from synkit.Chem.Reaction.rsmi_utils import reverse_reaction +from synkit.Chem.utils import reverse_reaction from synkit.Synthesis.reactor_utils import _get_unique_aam, _get_reagent, _add_reagent from synkit.Synthesis.Reactor.strategy import Strategy @@ -39,8 +39,7 @@ class MODAAM: - """ - Runs MØD (via MODReactor) then a full AAM/ITS post-processing pipeline. + """Runs MØD (via MODReactor) then a full AAM/ITS post-processing pipeline. Parameters ---------- @@ -106,9 +105,7 @@ def _run_pipeline(self) -> None: self._aam_smiles = self._process_aam(self._dg) def run(self) -> List[str]: - """ - Re-run the entire pipeline (MØD + AAM) and return fresh results. - """ + """Re-run the entire pipeline (MØD + AAM) and return fresh results.""" self._run_pipeline() return self._aam_smiles @@ -231,8 +228,7 @@ def _deduplicate(self, smiles: List[str]) -> List[str]: def expand_aam(rsmi: str, rule: str) -> List[str]: - """ - Expand Atom–Atom Mapping (AAM) for a given reaction SMARTS/SMILES (rsmi) + """Expand Atom–Atom Mapping (AAM) for a given reaction SMARTS/SMILES (rsmi) using a pre‐sanitized GML rule string. Parameters diff --git a/synkit/Synthesis/Reactor/mod_reactor.py b/synkit/Synthesis/Reactor/mod_reactor.py index a1042c3..cc15304 100644 --- a/synkit/Synthesis/Reactor/mod_reactor.py +++ b/synkit/Synthesis/Reactor/mod_reactor.py @@ -49,8 +49,7 @@ # MODReactor # ────────────────────────────────────────────────────────────────────────────── class MODReactor: - """ - Lazy, ergonomic wrapper around the MØD toolkit’s derivation pipeline. + """Lazy, ergonomic wrapper around the MØD toolkit’s derivation pipeline. 
Workflow -------- @@ -108,8 +107,8 @@ def __init__( # Public high‑level API # ------------------------------------------------------------------ def run(self) -> "MODReactor": - """ - Execute the chosen strategy **once** and return *self* so you can chain: + """Execute the chosen strategy **once** and return *self* so you can + chain: ```python r = MODReactor(...).run() @@ -122,8 +121,7 @@ def run(self) -> "MODReactor": # helpers for outside world ------------------------------------------------ def get_reaction_smiles(self) -> List[str]: - """ - Retrieve the reaction SMILES strings (lazy). + """Retrieve the reaction SMILES strings (lazy). Returns ------- @@ -133,8 +131,7 @@ def get_reaction_smiles(self) -> List[str]: return self.reaction_smiles def get_dg(self) -> DG: - """ - Access the underlying derivation graph. + """Access the underlying derivation graph. Returns ------- @@ -163,9 +160,7 @@ def __str__(self) -> str: __repr__ = __str__ def help(self) -> None: - """ - Print a one-page summary of reactor configuration and results. - """ + """Print a one-page summary of reactor configuration and results.""" print("MODReactor".ljust(60, "─")) print(f"Rule file : {self.rule_file}") print(f"Substrate : {'.'.join(self.initial_smiles)}") @@ -182,8 +177,7 @@ def help(self) -> None: # ------------------------------------------------------------------ @property def dg(self) -> Optional[DG]: - """ - DG or None – cached derivation graph. + """DG or None – cached derivation graph. See also -------- @@ -193,23 +187,18 @@ def dg(self) -> Optional[DG]: @property def product_sets(self) -> List[List[str]]: - """ - Raw product sets (lists of SMILES) before joining into full reactions. - """ + """Raw product sets (lists of SMILES) before joining into full + reactions.""" return self.temp_results @property def product_smiles(self) -> List[str]: - """ - Flattened list of all product SMILES (may contain duplicates). 
- """ + """Flattened list of all product SMILES (may contain duplicates).""" return [s for batch in self.temp_results for s in batch] @property def prediction_count(self) -> int: - """ - Number of distinct prediction batches generated. - """ + """Number of distinct prediction batches generated.""" return len(self._temp_results or []) # ------------------------------------------------------------------ @@ -217,8 +206,7 @@ def prediction_count(self) -> int: # ------------------------------------------------------------------ @property def temp_results(self) -> List[List[str]]: - """ - Lazy-loaded raw product lists. + """Lazy-loaded raw product lists. Returns ------- @@ -230,8 +218,7 @@ def temp_results(self) -> List[List[str]]: @property def reaction_smiles(self) -> List[str]: - """ - Lazy-loaded reaction SMILES strings of form “A>>B”. + """Lazy-loaded reaction SMILES strings of form “A>>B”. Returns ------- @@ -248,8 +235,7 @@ def reaction_smiles(self) -> List[str]: # Internals – setup # ------------------------------------------------------------------ def _prepare_initial_molecules(self) -> List[Any]: - """ - Convert SMILES → MØD molecule objects, dedupe, and sort. + """Convert SMILES → MØD molecule objects, dedupe, and sort. Returns ------- @@ -262,8 +248,7 @@ def _prepare_initial_molecules(self) -> List[Any]: return mols def _parse_reaction_rule(self) -> Any: - """ - Load or parse the reaction rule from raw GML or file. + """Load or parse the reaction rule from raw GML or file. Returns ------- @@ -304,8 +289,7 @@ def _parse_reaction_rule(self) -> Any: # Internals – strategy dispatch # ------------------------------------------------------------------ def _predict(self) -> List[List[str]]: - """ - Dispatch to the appropriate application strategy. + """Dispatch to the appropriate application strategy. 
Returns ------- @@ -350,8 +334,7 @@ def _apply_components(self) -> List[List[str]]: return products def _apply_all(self) -> List[List[str]]: - """ - Classic “ALL” strategy: VF2 with reagents included. + """Classic “ALL” strategy: VF2 with reagents included. Returns ------- @@ -413,9 +396,8 @@ def generate_reaction_smiles( arrow: str = ">>", separator: str = ".", ) -> List[str]: - """ - Build reaction SMILES of the form “A>>B”, where A and B swap - roles if invert=True. + """Build reaction SMILES of the form “A>>B”, where A and B swap roles + if invert=True. Parameters ---------- diff --git a/synkit/Synthesis/Reactor/partial_engine.py b/synkit/Synthesis/Reactor/partial_engine.py index 368ed11..d1330e0 100644 --- a/synkit/Synthesis/Reactor/partial_engine.py +++ b/synkit/Synthesis/Reactor/partial_engine.py @@ -1,25 +1,25 @@ from synkit.IO import rsmi_to_its, smiles_to_graph from synkit.Chem.Reaction.radical_wildcard import RadicalWildcardAdder from synkit.Synthesis.Reactor.syn_reactor import SynReactor -from synkit.Chem.Reaction.rsmi_utils import remove_explicit_H_from_rsmi +from synkit.Chem.utils import remove_explicit_H_from_rsmi class PartialEngine: - """ - Partial Reaction Learning Engine that applies a single‐direction + """Partial Reaction Learning Engine that applies a single‐direction (forward or backward) template transformation, injects radical wildcards, and returns a list of intermediate ITS strings. - :param smi: A reaction SMARTS (rsmi) string in the form "Reactants>>Products" or - a simple SMILES string when used for one‐sided synthesis. + :param smi: A reaction SMARTS (rsmi) string in the form + "Reactants>>Products" or a simple SMILES string when used for + one‐sided synthesis. :type smi: str - :param template: A reaction template SMARTS string, which may include explicit H. + :param template: A reaction template SMARTS string, which may + include explicit H. 
:type template: str """ def __init__(self, smi: str, template: str) -> None: - """ - Initialize the PartialEngine. + """Initialize the PartialEngine. - Removes explicit hydrogens from the given template SMARTS. - Parses the cleaned template into an internal template structure (ITS). @@ -41,8 +41,7 @@ def __init__(self, smi: str, template: str) -> None: self.host = smiles_to_graph(smi) def fit(self, invert: bool = False) -> list[str]: - """ - Apply the template in one direction to generate radical‐wildcarded + """Apply the template in one direction to generate radical‐wildcarded reaction SMARTS (ITS). - Instantiates a SynReactor on the host graph and ITS. diff --git a/synkit/Synthesis/Reactor/rbl_engine.py b/synkit/Synthesis/Reactor/rbl_engine.py index 19f7ce0..f3d790a 100644 --- a/synkit/Synthesis/Reactor/rbl_engine.py +++ b/synkit/Synthesis/Reactor/rbl_engine.py @@ -1,37 +1,38 @@ from rdkit import Chem from synkit.IO import its_to_rsmi, rsmi_to_its, smiles_to_graph -from synkit.Chem.Reaction.rsmi_utils import remove_explicit_H_from_rsmi +from synkit.Chem.utils import remove_explicit_H_from_rsmi from synkit.Chem.Reaction.radical_wildcard import RadicalWildcardAdder from synkit.Synthesis.Reactor.syn_reactor import SynReactor from synkit.Graph.Wildcard.fuse_graph import fuse_wc_graphs, find_wc_graph_isomorphism class RBLEngine: - """ - Reaction-based Learning Engine that takes a reaction SMARTS (rsmi) and a + """Reaction-based Learning Engine that takes a reaction SMARTS (rsmi) and a transformation template, applies forward and backward synthesis via SynReactor, augments with radical wildcards, identifies wildcard-aware graph isomorphisms between forward and backward intermediates, and fuses matching graphs into new intermediates. - :param rsmi: Reaction SMARTS string in the form "Reactants>>Products". + :param rsmi: Reaction SMARTS string in the form + "Reactants>>Products". 
:type rsmi: str - :param template: A reaction template SMARTS string, may include explicit H. + :param template: A reaction template SMARTS string, may include + explicit H. :type template: str """ def __init__(self, rsmi: str, template: str) -> None: - """ - Initialize the RBLEngine with a reaction SMARTS and a template. + """Initialize the RBLEngine with a reaction SMARTS and a template. - Cleans explicit hydrogens in the template, parses the template into - an ITS (internal template structure), and converts the reactant and - product SMARTS into graph representations for host forward and - backward graphs. + Cleans explicit hydrogens in the template, parses the template + into an ITS (internal template structure), and converts the + reactant and product SMARTS into graph representations for host + forward and backward graphs. :param rsmi: Reaction SMARTS string "Reactants>>Products". :type rsmi: str - :param template: Reaction template SMARTS, possibly with explicit H. + :param template: Reaction template SMARTS, possibly with + explicit H. :type template: str """ self.rsmi = rsmi @@ -43,8 +44,7 @@ def __init__(self, rsmi: str, template: str) -> None: self.host_bw = smiles_to_graph(p) def _fw(self): - """ - Generate forward reaction intermediates using SynReactor, then apply + """Generate forward reaction intermediates using SynReactor, then apply RadicalWildcardAdder to each reaction SMARTS, and return their ITS representations. @@ -59,9 +59,9 @@ def _fw(self): return [rsmi_to_its(rxn) for rxn in fw] def _bw(self): - """ - Generate backward reaction intermediates by inverting the template in - SynReactor, apply RadicalWildcardAdder, and return ITS representations. + """Generate backward reaction intermediates by inverting the template + in SynReactor, apply RadicalWildcardAdder, and return ITS + representations. :returns: List of ITS objects for backward intermediates. 
:rtype: List @@ -79,8 +79,8 @@ def _bw(self): return [rsmi_to_its(rxn) for rxn in bw] def fit(self): - """ - Attempt to fuse forward and backward ITS graphs into new intermediates. + """Attempt to fuse forward and backward ITS graphs into new + intermediates. For each forward ITS and backward ITS pair: 1. Find a wildcard-aware graph isomorphism. @@ -109,8 +109,7 @@ def fit(self): @staticmethod def remove_explicith_rsmi(rsmi: str) -> str: - """ - Strip any explicit hydrogens from a reaction SMARTS string. + """Strip any explicit hydrogens from a reaction SMARTS string. :param rsmi: Reaction SMARTS "Reactants>>Products". :type rsmi: str diff --git a/synkit/Synthesis/Reactor/rule_filter.py b/synkit/Synthesis/Reactor/rule_filter.py index 2064e3d..6e174ad 100644 --- a/synkit/Synthesis/Reactor/rule_filter.py +++ b/synkit/Synthesis/Reactor/rule_filter.py @@ -8,28 +8,30 @@ class RuleFilter: - """ - Filter a host graph by a list of transformation rules (patterns), - keeping only those rules whose (decomposed) pattern appears as a - subgraph in the host. + """Filter a host graph by a list of transformation rules (patterns), + keeping only those rules whose (decomposed) pattern appears as a subgraph + in the host. - :param host_graph: The host graph to search within (will be converted to explicit H). + :param host_graph: The host graph to search within (will be + converted to explicit H). :type host_graph: nx.Graph :param rules_list: A list of rule objects to filter against. :type rules_list: list - :param invert: If True, use the "modifier" component of each decomposition; otherwise use the normal part. + :param invert: If True, use the "modifier" component of each + decomposition; otherwise use the normal part. :type invert: bool - :param engine: Matching engine to use: "turbo", "sing", "nx", or "mod". + :param engine: Matching engine to use: "turbo", "sing", "nx", or + "mod". :type engine: str :param node_label: Node attribute(s) for TurboISO to match on. 
:type node_label: str or list :param edge_label: Edge attribute(s) for TurboISO to match on. :type edge_label: str or list - :param distance_threshold: Threshold to skip distance filtering in TurboISO. + :param distance_threshold: Threshold to skip distance filtering in + TurboISO. :type distance_threshold: int :param sing_max_path: Maximum path length for SING engine. :type sing_max_path: int - :returns: An instance with only the rules that matched. :rtype: RuleFilter """ @@ -45,14 +47,14 @@ def __init__( distance_threshold: int = 5000, sing_max_path: int = 3, ) -> None: - """ - Initialize the RuleFilter and perform the filtering pass. + """Initialize the RuleFilter and perform the filtering pass. :param host_graph: The host graph to search within. :type host_graph: nx.Graph :param rules_list: A list of rule objects to filter against. :type rules_list: list - :param invert: If True, use the "modifier" component of each decomposition. + :param invert: If True, use the "modifier" component of each + decomposition. :type invert: bool :param engine: Matching engine to use. :type engine: str @@ -60,7 +62,8 @@ def __init__( :type node_label: str or list :param edge_label: Edge attribute(s) for TurboISO to match on. :type edge_label: str or list - :param distance_threshold: Threshold to skip distance filtering in TurboISO. + :param distance_threshold: Threshold to skip distance filtering + in TurboISO. :type distance_threshold: int :param sing_max_path: Maximum path length for SING engine. :type sing_max_path: int @@ -97,8 +100,7 @@ def __init__( self._new_rules = [r for r, m in zip(self._rules, self._matches) if m] def _match(self, pattern: nx.Graph) -> bool: - """ - Test whether the given pattern occurs as a subgraph in the host. + """Test whether the given pattern occurs as a subgraph in the host. :param pattern: The query graph pattern to match. 
:type pattern: nx.Graph @@ -120,8 +122,7 @@ def _match(self, pattern: nx.Graph) -> bool: @property def host(self) -> nx.Graph: - """ - The explicit host graph. + """The explicit host graph. :returns: The host graph used for matching. :rtype: nx.Graph @@ -130,8 +131,7 @@ def host(self) -> nx.Graph: @property def rules(self) -> List[Any]: - """ - Original list of rules provided. + """Original list of rules provided. :returns: The list of rules. :rtype: list @@ -140,8 +140,7 @@ def rules(self) -> List[Any]: @property def patterns(self) -> List[nx.Graph]: - """ - Decomposed subgraph queries used internally. + """Decomposed subgraph queries used internally. :returns: List of ITS-decomposed query graphs. :rtype: list of nx.Graph @@ -150,8 +149,7 @@ def patterns(self) -> List[nx.Graph]: @property def matches(self) -> List[bool]: - """ - Boolean list indicating which patterns were found. + """Boolean list indicating which patterns were found. :returns: List of booleans aligned with `patterns`. :rtype: list of bool @@ -160,8 +158,7 @@ def matches(self) -> List[bool]: @property def new_rules(self) -> List[Any]: - """ - Subset of rules for which `matches[i]` is True. + """Subset of rules for which `matches[i]` is True. :returns: Filtered list of matching rules. :rtype: list @@ -170,8 +167,7 @@ def new_rules(self) -> List[Any]: @property def engine(self) -> str: - """ - Matching engine in use. + """Matching engine in use. :returns: The name of the engine. :rtype: str @@ -179,8 +175,7 @@ def engine(self) -> str: return self._engine def __repr__(self) -> str: - """ - Concise representation of the filter. + """Concise representation of the filter. :returns: Representation string. :rtype: str @@ -192,8 +187,7 @@ def __repr__(self) -> str: ) def __help__(self) -> str: - """ - Return the class docstring for interactive help. + """Return the class docstring for interactive help. :returns: The class documentation. 
:rtype: str diff --git a/synkit/Synthesis/Reactor/single_predictor.py b/synkit/Synthesis/Reactor/single_predictor.py index 4d51fb0..b9db0e4 100644 --- a/synkit/Synthesis/Reactor/single_predictor.py +++ b/synkit/Synthesis/Reactor/single_predictor.py @@ -1,28 +1,23 @@ from typing import List, Any, Dict - -# from synkit.Synthesis.Reactor.core_engine import CoreEngine from synkit.Synthesis.Reactor.mod_reactor import MODReactor class SinglePredictor: - """ - A class designed for one-step chemical reaction predictions using transformation rules. + """A class designed for one-step chemical reaction predictions using + transformation rules. - This class utilizes transformation rules to predict the outcomes of chemical reactions based - on provided SMILES strings. + This class utilizes transformation rules to predict the outcomes of + chemical reactions based on provided SMILES strings. """ def __init__(self) -> None: - """ - Initializes the StepPredictor instance. - """ + """Initializes the StepPredictor instance.""" pass def _single_rule( self, smiles_list: List[str], rule: str, invert: bool = False ) -> List[Any]: - """ - Applies a single transformation rule to a list of SMILES strings. + """Applies a single transformation rule to a list of SMILES strings. This function applies the transformation rule to generate potential reaction outcomes from given SMILES strings. The results are returned and the memory is cleaned up immediately @@ -47,8 +42,7 @@ def _single_rule( def _multiple_rules( self, smiles_list: List[str], rules: List[str], invert: bool = False ) -> List[Any]: - """ - Applies multiple transformation rules to a list of SMILES strings. + """Applies multiple transformation rules to a list of SMILES strings. Parameters: - smiles_list (List[str]): The list of SMILES strings to process. 
@@ -72,8 +66,8 @@ def _perform( rule_key: str = "gml", invert: bool = False, ) -> List[Dict[str, Any]]: - """ - Performs prediction for each entry in the data using the specified rules. + """Performs prediction for each entry in the data using the specified + rules. Parameters: - data (List[Dict[str, Any]]): The dataset containing chemical reactions. diff --git a/synkit/Synthesis/Reactor/strategy.py b/synkit/Synthesis/Reactor/strategy.py index 79d8f00..edb0c7b 100644 --- a/synkit/Synthesis/Reactor/strategy.py +++ b/synkit/Synthesis/Reactor/strategy.py @@ -3,13 +3,12 @@ class Strategy(str, Enum): - """ - Strategy for sub-graph matching/application: + """Strategy for sub-graph matching/application: - - ALL: classic VF2 on the whole graph - - COMPONENT: component-aware only (no cross-CC backtracking) - - BACKTRACK: component-aware with backtracking across CCs - - PARTIAL: partial matching (mcs) + - ALL: classic VF2 on the whole graph + - COMPONENT: component-aware only (no cross-CC backtracking) + - BACKTRACK: component-aware with backtracking across CCs + - PARTIAL: partial matching (mcs) """ ALL = "all" @@ -19,8 +18,7 @@ class Strategy(str, Enum): @classmethod def from_string(cls, value: Union[str, "Strategy"]) -> "Strategy": - """ - Convert a string or Strategy to a Strategy enum. + """Convert a string or Strategy to a Strategy enum. 
Parameters ---------- diff --git a/synkit/Synthesis/Reactor/syn_reactor.py b/synkit/Synthesis/Reactor/syn_reactor.py index b096d1a..fb00feb 100644 --- a/synkit/Synthesis/Reactor/syn_reactor.py +++ b/synkit/Synthesis/Reactor/syn_reactor.py @@ -14,7 +14,7 @@ graph_to_smi, ) from synkit.IO import setup_logging -from synkit.Chem.Reaction.rsmi_utils import reverse_reaction +from synkit.Chem.utils import reverse_reaction from synkit.Rule import SynRule from synkit.Graph.syn_graph import SynGraph @@ -49,33 +49,33 @@ @dataclass class SynReactor: - """ - A hardened and typed re-write of the original SynReactor, preserving API compatibility - while offering safer, faster, and cleaner behavior. + """A hardened and typed re-write of the original SynReactor, preserving API + compatibility while offering safer, faster, and cleaner behavior. - :param substrate: The input reaction substrate, as a SMILES string, a raw NetworkX graph, - or a SynGraph. + :param substrate: The input reaction substrate, as a SMILES string, + a raw NetworkX graph, or a SynGraph. :type substrate: Union[str, nx.Graph, SynGraph] - :param template: Reaction template, provided as SMILES/SMARTS, a raw NetworkX graph, - or a SynRule. + :param template: Reaction template, provided as SMILES/SMARTS, a raw + NetworkX graph, or a SynRule. :type template: Union[str, nx.Graph, SynRule] - :param invert: Whether to invert the reaction (predict precursors). Defaults to False. + :param invert: Whether to invert the reaction (predict precursors). + Defaults to False. :type invert: bool - :param canonicaliser: Optional canonicaliser for intermediate graphs. If None, a default - GraphCanonicaliser is used. + :param canonicaliser: Optional canonicaliser for intermediate + graphs. If None, a default GraphCanonicaliser is used. :type canonicaliser: Optional[GraphCanonicaliser] - :param explicit_h: If True, render all hydrogens explicitly in the reaction‑center SMARTS. - Defaults to True. 
+ :param explicit_h: If True, render all hydrogens explicitly in the + reaction‑center SMARTS. Defaults to True. :type explicit_h: bool - :param implicit_temp: If True, treat the input template as implicit-H (forces explicit_h=False). - Defaults to False. + :param implicit_temp: If True, treat the input template as + implicit-H (forces explicit_h=False). Defaults to False. :type implicit_temp: bool - :param strategy: Matching strategy, one of Strategy.ALL, 'comp', or 'bt'. - Defaults to Strategy.ALL. + :param strategy: Matching strategy, one of Strategy.ALL, 'comp', or + 'bt'. Defaults to Strategy.ALL. :type strategy: Strategy or str - :param partial: If True, use a partial matching fallback. Defaults to False. + :param partial: If True, use a partial matching fallback. Defaults + to False. :type partial: bool - :ivar _graph: Cached SynGraph for the substrate. :vartype _graph: Optional[SynGraph] :ivar _rule: Cached SynRule for the template. @@ -86,7 +86,8 @@ class SynReactor: :vartype _its: Optional[List[nx.Graph]] :ivar _smarts: Cached list of SMARTS strings. :vartype _smarts: Optional[List[str]] - :ivar _flag_pattern_has_explicit_H: Internal flag indicating explicit‑H constraints. + :ivar _flag_pattern_has_explicit_H: Internal flag indicating + explicit‑H constraints. :vartype _flag_pattern_has_explicit_H: bool """ @@ -108,8 +109,8 @@ class SynReactor: _flag_pattern_has_explicit_H: bool = field(init=False, default=False, repr=False) def __post_init__(self) -> None: - """ - Validate and enforce consistency of `explicit_h` and `implicit_temp`. + """Validate and enforce consistency of `explicit_h` and + `implicit_temp`. :raises ValueError: If `explicit_h` is True while `implicit_temp` is False. """ @@ -169,8 +170,7 @@ def from_smiles( # ------------------------------------------------------------------ @property def graph(self) -> SynGraph: # noqa: D401 – read‑only property - """ - Lazily wrap the substrate into a SynGraph. 
+ """Lazily wrap the substrate into a SynGraph. :returns: The reaction substrate as a `SynGraph`. :rtype: SynGraph @@ -181,8 +181,7 @@ def graph(self) -> SynGraph: # noqa: D401 – read‑only property @property def rule(self) -> SynRule: # noqa: D401 - """ - Lazily wrap the template into a SynRule. + """Lazily wrap the template into a SynRule. :returns: The reaction template as a `SynRule`. :rtype: SynRule @@ -196,8 +195,7 @@ def rule(self) -> SynRule: # noqa: D401 # ------------------------------------------------------------------ @property def mappings(self) -> List[MappingDict]: - """ - Find subgraph mappings between substrate and template. + """Find subgraph mappings between substrate and template. :returns: A list of node-mapping dictionaries. :rtype: list of dict @@ -235,8 +233,7 @@ def mappings(self) -> List[MappingDict]: @property def its_list(self) -> List[nx.Graph]: - """ - Build ITS graphs for each subgraph mapping. + """Build ITS graphs for each subgraph mapping. :returns: A list of ITS (Internal Transition State) graphs. :rtype: list of networkx.Graph @@ -264,8 +261,7 @@ def its_list(self) -> List[nx.Graph]: @property def smarts_list(self) -> List[str]: - """ - Serialise each ITS graph to a reaction-SMARTS string. + """Serialise each ITS graph to a reaction-SMARTS string. :returns: A list of SMARTS strings (inverted if `invert=True`). :rtype: list of str @@ -576,4 +572,4 @@ def _to_smarts(its: nx.Graph) -> str: p_smi = graph_to_smi(right) if r_smi is None or p_smi is None: return None - return f"{r_smi}>>{p_smi}" + return f"{r_smi}>>{p_smi}" \ No newline at end of file diff --git a/synkit/Synthesis/reactor_utils.py b/synkit/Synthesis/reactor_utils.py index 6902f54..67cf1e2 100644 --- a/synkit/Synthesis/reactor_utils.py +++ b/synkit/Synthesis/reactor_utils.py @@ -6,8 +6,8 @@ def _get_unique_aam(list_aam: list) -> list: - """ - Retrieves the unique atom-atom mappings (AAM) by clustering a list of ITS graphs. 
+ """Retrieves the unique atom-atom mappings (AAM) by clustering a list of + ITS graphs. This function first converts each item in the provided list of AAM strings to an ITS graph using the `rsmi_to_its` function. Then, it performs iterative clustering of the ITS graphs @@ -39,8 +39,7 @@ def _get_unique_aam(list_aam: list) -> list: def _deduplicateGraphs(initial) -> list: - """ - Deduplicates a list of molecular graphs by checking for isomorphisms. + """Deduplicates a list of molecular graphs by checking for isomorphisms. This method checks each graph in the `initial` list against the others for isomorphism, and removes duplicates by keeping only one representative for each unique graph. @@ -67,9 +66,9 @@ def _deduplicateGraphs(initial) -> list: def _get_connected_subgraphs(gml: str, invert: bool = False): - """ - Given a GML string, this function returns the number of connected subgraphs based - on the 'smart' representation split or a list of subgraphs, depending on the invert flag. + """Given a GML string, this function returns the number of connected + subgraphs based on the 'smart' representation split or a list of subgraphs, + depending on the invert flag. Parameters: - gml: str, the GML string to be converted into a 'smart' format. @@ -103,8 +102,8 @@ def _get_connected_subgraphs(gml: str, invert: bool = False): def _get_reagent(original_smiles: list, output_rsmi: str, invert: bool = False): - """ - Identifies reagents present in the original SMILES list that are absent in the processed output SMILES string. + """Identifies reagents present in the original SMILES list that are absent + in the processed output SMILES string. Parameters: - original_smiles: list of SMILES strings representing the original reagents. 
@@ -132,10 +131,9 @@ def _get_reagent(original_smiles: list, output_rsmi: str, invert: bool = False): def _get_reagent_rsmi(rsmi: str) -> List[str]: - """ - Identifies reagents that appear in both the reactant and product sides of - a reaction SMILES string, suggesting these elements are unchanged - by the chemical reaction. + """Identifies reagents that appear in both the reactant and product sides + of a reaction SMILES string, suggesting these elements are unchanged by the + chemical reaction. Parameters: - rsmi (str): A reaction SMILES string formatted as "reactants>>products". @@ -165,8 +163,8 @@ def _get_reagent_rsmi(rsmi: str) -> List[str]: def _remove_reagent(rsmi: str) -> str: - """ - Removes common molecules from the reactants and products in a SMILES reaction string. + """Removes common molecules from the reactants and products in a SMILES + reaction string. This function identifies the molecules that appear on both sides of the reaction (reactants and products) and removes one occurrence of each common molecule from @@ -224,8 +222,8 @@ def _remove_reagent(rsmi: str) -> str: def _add_reagent(rsmi: str, reagents: list): - """ - Modifies the SMILES representation of a reaction by adding additional reagents. + """Modifies the SMILES representation of a reaction by adding additional + reagents. Parameters: - rsmi: str, the SMILES reaction string, expected to contain '>>' separating reactants and products. @@ -255,8 +253,7 @@ def _add_reagent(rsmi: str, reagents: list): def _calculate_max_depth(reaction_tree, current_node=None, depth=0): - """ - Calculate the maximum depth of a reaction tree. + """Calculate the maximum depth of a reaction tree. Parameters: - reaction_tree (dict): A dictionary where keys are reaction SMILES (RSMI) @@ -293,8 +290,8 @@ def _find_all_paths( current_depth=0, path=None, ): - """ - Recursively find all paths from the root to the maximum depth in the reaction tree. 
+ """Recursively find all paths from the root to the maximum depth in the + reaction tree. Parameters: - reaction_tree (dict): A dictionary of reaction SMILES with products. diff --git a/synkit/Utils/utils.py b/synkit/Utils/utils.py index 4a89977..10ee153 100644 --- a/synkit/Utils/utils.py +++ b/synkit/Utils/utils.py @@ -12,8 +12,7 @@ def stratified_random_sample( seed: Optional[int] = 42, bypass: bool = False, ) -> List[Dict[str, any]]: - """ - Stratifies and samples data from a list of dictionaries based on a + """Stratifies and samples data from a list of dictionaries based on a specified property key. Parameters: @@ -66,8 +65,7 @@ def stratified_random_sample( def calculate_processing_time(start_time_str: str, end_time_str: str) -> float: - """ - Calculates the processing time in seconds between two timestamps. + """Calculates the processing time in seconds between two timestamps. Parameters: - start_time_str (str): A string representing the start time in the format @@ -94,9 +92,9 @@ def calculate_processing_time(start_time_str: str, end_time_str: str) -> float: def remove_explicit_hydrogen( Graph: nx.Graph, excluded_indices: Iterable[int] ) -> nx.Graph: - """ - Processes a molecular graph by calculating hydrogen count ('h_count') for each node and - removing hydrogen nodes that are not specified in the excluded indices. + """Processes a molecular graph by calculating hydrogen count ('h_count') + for each node and removing hydrogen nodes that are not specified in the + excluded indices. Parameters ---------- @@ -141,11 +139,10 @@ def remove_explicit_hydrogen( def fix_implicit_hydrogen(Graph: nx.Graph, indices: Iterable[int]) -> nx.Graph: - """ - Adjusts the 'h_count' attribute of specific nodes in a molecular graph, - decreasing it based on the presence of neighboring hydrogen atoms that are also - included in the specified indices. This function works on a copy - of the provided graph and returns the modified copy. 
+ """Adjusts the 'h_count' attribute of specific nodes in a molecular graph, + decreasing it based on the presence of neighboring hydrogen atoms that are + also included in the specified indices. This function works on a copy of + the provided graph and returns the modified copy. Parameters ---------- diff --git a/synkit/Vis/embedding.py b/synkit/Vis/embedding.py index 29514a8..3121b6a 100644 --- a/synkit/Vis/embedding.py +++ b/synkit/Vis/embedding.py @@ -11,8 +11,8 @@ def __init__( verbose: int = 0, custom_tsne_params: Optional[Dict] = None, ) -> None: - """ - Initialize the Embedding class with options for caching directory, verbosity, and custom t-SNE parameters. + """Initialize the Embedding class with options for caching directory, + verbosity, and custom t-SNE parameters. Parameters: cache_dir (str): Directory where cached results are stored. @@ -32,8 +32,7 @@ def __init__( self.tsne_params = self.default_tsne_params.copy() def set_tsne_params(self, **params) -> None: - """ - Sets parameters for t-SNE computations. + """Sets parameters for t-SNE computations. Parameters: **params: Arbitrary number of parameters for t-SNE. @@ -41,14 +40,12 @@ def set_tsne_params(self, **params) -> None: self.tsne_params.update(params) def reset_tsne_params(self) -> None: - """ - Resets t-SNE parameters to default values. - """ + """Resets t-SNE parameters to default values.""" self.tsne_params = self.default_tsne_params.copy() def _compute_tsne(self, X: np.ndarray) -> np.ndarray: - """ - Direct computation of the t-SNE embedding with the current parameters. + """Direct computation of the t-SNE embedding with the current + parameters. Parameters: X (np.ndarray): High-dimensional data points. @@ -60,8 +57,7 @@ def _compute_tsne(self, X: np.ndarray) -> np.ndarray: return tsne.fit_transform(X) def compute_tsne(self, X: np.ndarray, cache: bool = True) -> np.ndarray: - """ - Computes or retrieves the t-SNE embedding from cache. 
+ """Computes or retrieves the t-SNE embedding from cache. Parameters: X (np.ndarray): High-dimensional data points. @@ -77,8 +73,7 @@ def compute_tsne(self, X: np.ndarray, cache: bool = True) -> np.ndarray: @property def cache(self) -> Any: - """ - Decorator for caching the compute_tsne function. + """Decorator for caching the compute_tsne function. Returns: Callable: Cached function. @@ -86,7 +81,5 @@ def cache(self) -> Any: return self.memory.cache(self._compute_tsne) def clear_cache(self) -> None: - """ - Clears the cache directory. - """ + """Clears the cache directory.""" self.memory.clear() diff --git a/synkit/Vis/graph_visualizer.py b/synkit/Vis/graph_visualizer.py index de64811..f197c49 100644 --- a/synkit/Vis/graph_visualizer.py +++ b/synkit/Vis/graph_visualizer.py @@ -239,7 +239,8 @@ def plot_as_mol( ) def visualize_its(self, its: nx.Graph, **kwargs) -> plt.Figure: - """Return a Matplotlib Figure plotting the ITS graph without duplicate display.""" + """Return a Matplotlib Figure plotting the ITS graph without duplicate + display.""" # Temporarily disable interactive mode to prevent auto-display was_interactive = plt.isinteractive() plt.ioff() @@ -277,10 +278,8 @@ def help(self) -> None: ) def __repr__(self) -> str: - """ - Return a detailed representation of the GraphVisualizer, showing configured - node and edge attribute keys. - """ + """Return a detailed representation of the GraphVisualizer, showing + configured node and edge attribute keys.""" na = list(self._node_attributes.keys()) ea = list(self._edge_attributes.keys()) return f"GraphVisualizer(node_attributes={na!r}, " f"edge_attributes={ea!r})" @@ -294,8 +293,7 @@ def visualize_its_grid( figsize: tuple[float, float] = (12, 6), **kwargs, ) -> tuple[plt.Figure, list[list[plt.Axes]]]: - """ - Plot multiple ITS graphs in a grid layout. + """Plot multiple ITS graphs in a grid layout. 
Parameters ---------- diff --git a/synkit/Vis/pdf_writer.py b/synkit/Vis/pdf_writer.py index 6dfc4bc..3943c96 100644 --- a/synkit/Vis/pdf_writer.py +++ b/synkit/Vis/pdf_writer.py @@ -1,5 +1,6 @@ -""" -This module comprises several functions adapted from the work of Klaus Weinbauer. +"""This module comprises several functions adapted from the work of Klaus +Weinbauer. + The original code can be found at his GitHub repository: https://github.com/klausweinbauer/FGUtils. Adaptations were made to enhance functionality and integrate with other system components. """ @@ -11,8 +12,8 @@ class PdfWriter: - """ - A utility class to create PDF reports with plots from a list of figures or dynamically generated plots. + """A utility class to create PDF reports with plots from a list of figures + or dynamically generated plots. Parameters: - file (str): The file name of the output PDF. @@ -51,8 +52,7 @@ def __init__( self.show_progress = show_progress def plot(self, data: Union[List[plt.Figure], List], **kwargs): - """ - Generate plots from data or save pre-generated figures to the PDF. + """Generate plots from data or save pre-generated figures to the PDF. Parameters: - data (Union[List[matplotlib.figure.Figure], List]): Input data or list of figures. @@ -120,8 +120,7 @@ def plot(self, data: Union[List[plt.Figure], List], **kwargs): break def save_figure(self, figure: plt.Figure): - """ - Save a pre-generated matplotlib figure directly to the PDF. + """Save a pre-generated matplotlib figure directly to the PDF. Parameters: - figure (matplotlib.figure.Figure): The figure to save. @@ -134,8 +133,7 @@ def save_figure(self, figure: plt.Figure): self.pdf_pages.savefig(figure, bbox_inches="tight", pad_inches=1) def close(self): - """ - Close the PDF file, ensuring all pages are written. + """Close the PDF file, ensuring all pages are written. 
Returns: - None diff --git a/synkit/Vis/rule_vis.py b/synkit/Vis/rule_vis.py index 16c6c19..8a630f5 100644 --- a/synkit/Vis/rule_vis.py +++ b/synkit/Vis/rule_vis.py @@ -17,8 +17,9 @@ def __init__(self, backend: str = "nx") -> None: self.vis_graph = GraphVisualizer() def vis(self, input: Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]], **kwargs): - """ - Wrapper to select between nx_vis and mod_vis based on backend and input type. + """Wrapper to select between nx_vis and mod_vis based on backend and + input type. + Converts input as needed. """ if self.backend == "nx": @@ -60,11 +61,9 @@ def nx_vis( add_gridbox: bool = False, rule: bool = False, ) -> plt.Figure: - """ - Visualize reactants, ITS, and products side-by-side or vertically, - with interactive plotting turned off to prevent double-display, - and correct handling of matplotlib axes arrays. - """ + """Visualize reactants, ITS, and products side-by-side or vertically, + with interactive plotting turned off to prevent double-display, and + correct handling of matplotlib axes arrays.""" # Disable interactive mode & clear any leftover figures was_interactive = plt.isinteractive() plt.ioff() @@ -153,9 +152,7 @@ def nx_vis( plt.ion() def mod_vis(self, gml: str, path: str = "./") -> None: - """ - Simple MOD visualization via mod_post CLI. - """ + """Simple MOD visualization via mod_post CLI.""" from mod import ruleGMLString rule = ruleGMLString(gml, add=False) @@ -165,10 +162,7 @@ def mod_vis(self, gml: str, path: str = "./") -> None: self.post() def post(self) -> None: - """ - Generate an external report via the `mod_post` CLI. 
- - """ + """Generate an external report via the `mod_post` CLI.""" try: subprocess.run(["mod_post"], check=True) except subprocess.CalledProcessError as e: diff --git a/synkit/Vis/rxn_vis.py b/synkit/Vis/rxn_vis.py index 9e1cf74..2f0357d 100644 --- a/synkit/Vis/rxn_vis.py +++ b/synkit/Vis/rxn_vis.py @@ -17,8 +17,7 @@ def __init__( atom_label_font_size: int = 12, show_atom_map: bool = False, ): - """ - Initialize the reaction/molecule visualizer. + """Initialize the reaction/molecule visualizer. Parameters ---------- @@ -49,8 +48,7 @@ def __init__( def render( self, smiles: str, return_bytes: bool = False ) -> Union[Image.Image, bytes]: - """ - Render a molecule or reaction SMILES to a cropped PNG. + """Render a molecule or reaction SMILES to a cropped PNG. Parameters ---------- @@ -116,8 +114,7 @@ def render( return png if return_bytes else img def save_png(self, smiles: str, path: str) -> None: - """ - Render and save as a PNG file. + """Render and save as a PNG file. Parameters ---------- @@ -130,8 +127,7 @@ def save_png(self, smiles: str, path: str) -> None: img.save(path, format="PNG") def save_pdf(self, smiles: str, path: str, resolution: float = 300.0) -> None: - """ - Render and save as a single‐page PDF. + """Render and save as a single‐page PDF. 
Parameters ---------- From 0eb22366726c9c3d9c78dcad86163b32c0cada83 Mon Sep 17 00:00:00 2001 From: Tieu Long Phan <125431507+TieuLongPhan@users.noreply.github.com> Date: Tue, 22 Jul 2025 11:20:16 +0200 Subject: [PATCH 2/5] Prepare release bump version (#27) * update refractor * fix reactor * fix lint * prepare bechmark * refractor MODReactor and SynReactor * update crn * refractor cluster, change to matcher, fix code synreactor, now resnow comparable to modreactor * test 3 os * test * test * fix workflow * fix lint * fix win * fix win * fix win again * update smart * add synreactor implicit hydrogen * fix mcsmatcher * refractor visualization * fix conflict rdkit, upgrade to 2025.3.1 * fix lint * move aam_validator to Chem submodule * fix lint * prepare benchmark matcher * change backend rule to mod * prepare doc * add doc * update fih * update graph module doc * update doc * prepare release * . * fix doc * clean doc * fix docstring * fix tutorial * update fig * update explicit_hydrogen for its * prepare release * build doc * fix lint * fix bug in explicit hydrogen for ITS * fix build * fix * fix * fix * fix doc * update nauty canon, rule filters, change benchmark * prepare release * fix bug in nauty alg * update doc * add features for expanding its * add rule_matcher.py * add testcase rule matcher * add wildcard for smiles * add partial engine * update new features * update document * add data * update Chem features * format docstring and refractor Chem module * add auto-test pypi * create dependabot * test run yml * test * test docker * add docker * add docker * . 
* add readme * rename * release docker * remove redundant file * fix lint * fix bug * fix lint * add partial its beta * test publising conda * test publising conda * fix workflow * fix workflow * fix workflow * fix workflow again * tes pre-release * tes pre-release * tes pre-release * tes pre-release * tes pre-release * tes pre-release * fix meta.yaml * fix meta.yaml * fix meta.yaml * fix meta.yaml * fix meta.yaml * fix meta.yaml * publish beta * publish beta * debug * debug * debug * debug * debug * debug * debug * debug * prepare release --- .github/workflows/conda-forge-publish.yml | 109 +++ pyproject.toml | 4 +- recipe/meta.yaml | 37 + synkit/Graph/Canon/nauty.py | 4 +- synkit/Graph/ITS/partial_its.py | 238 +++++++ synkit/Graph/Matcher/graph_morphism.py | 813 +++++++++------------- synkit/IO/graph_to_mol.py | 4 +- 7 files changed, 728 insertions(+), 481 deletions(-) create mode 100644 .github/workflows/conda-forge-publish.yml create mode 100644 recipe/meta.yaml create mode 100644 synkit/Graph/ITS/partial_its.py diff --git a/.github/workflows/conda-forge-publish.yml b/.github/workflows/conda-forge-publish.yml new file mode 100644 index 0000000..12fd6a2 --- /dev/null +++ b/.github/workflows/conda-forge-publish.yml @@ -0,0 +1,109 @@ +name: Publish to conda-forge + +on: + release: + types: [published] + push: + branches: [refactor] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + outputs: + pkg_paths: ${{ steps.build.outputs.paths }} + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Setup Miniconda + uses: conda-incubator/setup-miniconda@v2 + with: + channels: conda-forge + auto-update-conda: true + auto-activate-base: true + + - name: Create & activate build env + shell: bash -l {0} + run: | + conda create -n build python=3.11 'conda-build>=3.21' -c conda-forge -y + conda activate build + + - id: build + name: Build conda package + shell: bash -l {0} + env: + GITHUB_RUN_NUMBER: ${{ 
github.run_number }} + run: | + conda activate build + rm -rf conda-bld && mkdir conda-bld + + conda-build recipe --output-folder ./conda-bld + + echo "DEBUG: Built files:" + find conda-bld -type f \( -name "*.conda" -o -name "*.tar.bz2" \) -print + + files=$(find conda-bld -type f \( -name "*.conda" -o -name "*.tar.bz2" \) -print | tr '\n' ' ') + echo "paths=$files" >> $GITHUB_OUTPUT + + - name: Upload built packages as artifact + uses: actions/upload-artifact@v4 + with: + name: conda-packages + path: conda-bld/ + + publish_release: + needs: build + if: github.event_name == 'release' + runs-on: ubuntu-latest + steps: + - name: Download built packages + uses: actions/download-artifact@v4 + with: + name: conda-packages + path: conda-bld + + - name: Install Anaconda Client + run: python3 -m pip install --upgrade anaconda-client + + - name: Upload to conda-forge / main + env: + ANACONDA_TOKEN: ${{ secrets.ANACONDA_TOKEN }} + run: | + for pkg in ${{ needs.build.outputs.pkg_paths }}; do + anaconda -t "$ANACONDA_TOKEN" upload \ + --user tieulongphan \ + --label main \ + --no-progress \ + "$pkg" + done + + publish_beta: + needs: build + if: github.event_name == 'push' && github.ref == 'refs/heads/refactor' + runs-on: ubuntu-latest + steps: + - name: Download built packages + uses: actions/download-artifact@v4 + with: + name: conda-packages + path: conda-bld + + - name: Install Anaconda Client + run: python3 -m pip install --upgrade anaconda-client + + - name: Upload to conda-forge / beta + env: + ANACONDA_TOKEN: ${{ secrets.ANACONDA_TOKEN }} + run: | + for pkg in ${{ needs.build.outputs.pkg_paths }}; do + anaconda -t "$ANACONDA_TOKEN" upload \ + --user tieulongphan \ + --label beta \ + --no-progress \ + "$pkg" + done diff --git a/pyproject.toml b/pyproject.toml index 688cc3d..38f4bfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,9 @@ build-backend = "hatchling.build" [project] name = "synkit" -version = "0.0.11" +version = "0.0.12" +license = { text = "MIT" 
} +license-files = ["LICENSE"] authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] diff --git a/recipe/meta.yaml b/recipe/meta.yaml new file mode 100644 index 0000000..5b6d4fb --- /dev/null +++ b/recipe/meta.yaml @@ -0,0 +1,37 @@ +package: + name: synkit + version: 0.0.12 + +source: + path: .. + +build: + noarch: python + number: 0 + script: "{{ PYTHON }} -m pip install . --no-deps -vv" + +requirements: + host: + - python >=3.11,<3.12 # force a 3.11 build env + - pip + - hatchling # your PEP 517 build backend + run: + - python >=3.11,<3.12 + - scikit-learn >=1.4.0 + - pandas >=1.5.3 + - rdkit >=2025.3.1 + - networkx >=3.3 + - seaborn >=0.13.2 + - requests >=2.32.3 + - regex >=2024.11.6 + - numpy >=2.2.0 + +about: + home: https://github.com/TieuLongPhan/SynKit + license: MIT + license_file: LICENSE + summary: Utility for reaction modeling using graph grammar + +extra: + recipe-maintainers: + - TieuLongPhan diff --git a/synkit/Graph/Canon/nauty.py b/synkit/Graph/Canon/nauty.py index 6a0333c..f2d3405 100644 --- a/synkit/Graph/Canon/nauty.py +++ b/synkit/Graph/Canon/nauty.py @@ -7,7 +7,7 @@ class NautyCanonicalizer: - """Perform Nauty‑style canonicalization of a NetworkX graph, optionally + """Perform Nauty-style canonicalization of a NetworkX graph, optionally refining and distinguishing nodes and edges by specified attributes, and extracting automorphisms, orbits, and canonical permutations. 
@@ -317,4 +317,4 @@ def union_orbits(i, j): def graph_signature(self, G): G_canon = self.canonical_form(G) label = self._build_label(G_canon, sorted(G_canon.nodes())) - return hashlib.sha256(label.encode("utf-8")).hexdigest() \ No newline at end of file + return hashlib.sha256(label.encode("utf-8")).hexdigest() diff --git a/synkit/Graph/ITS/partial_its.py b/synkit/Graph/ITS/partial_its.py new file mode 100644 index 0000000..ee73d5a --- /dev/null +++ b/synkit/Graph/ITS/partial_its.py @@ -0,0 +1,238 @@ +import networkx as nx +from typing import Dict, Any, Optional, Tuple, Hashable + + +class PartialITS: + """Utility class for building **partial** Imaginary‑Transition‑State (ITS) + graphs from a pair of reactant/product `networkx` graphs. + + The resulting ITS graph contains + + * a **union** of nodes from *G* (reactant) and *H* (product), + * a per‑node attribute ``typesGH`` – a 2‑tuple ``(attrs_from_G, attrs_from_H)`` – + where missing sides are filled by the present one, + * edges categorised as **unchanged**, **broken** or **formed** and stored as + an ``order`` tuple ``(o_G, o_H)``, and + * a convenience edge attribute ``standard_order = o_G - o_H`` (optionally + zeroed when |Δ| < 1 to ignore aromaticity changes). + """ + + # --------------------------------------------------------------------- + # Helper – node‐attribute retrieval + # --------------------------------------------------------------------- + @staticmethod + def _get_node_attr_tuple( + graph: nx.Graph, + node: Hashable, + defaults: Dict[str, Any], + ) -> Tuple[Any, ...]: + """Return a tuple containing *all* attributes in *defaults* order. + + :param graph: graph to query. + :param node: node identifier. + :param defaults: mapping of attribute → default value. + :returns: tuple in the order of *defaults.keys()*. 
+ """ + return tuple( + graph.nodes[node].get(attr, default) for attr, default in defaults.items() + ) + + # ------------------------------------------------------------------ + # Helper – standard_order + # ------------------------------------------------------------------ + @staticmethod + def _attach_standard_order( + graph: nx.Graph, + ignore_aromaticity: bool = False, + ) -> nx.Graph: + """Attach ``standard_order`` edge attribute in‑place. + + :param graph: ITS graph with ``order`` tuples. + :param ignore_aromaticity: if *True*, set Δ=0 when |Δ|<1. + :returns: *graph* (for chaining). + """ + for u, v, data in graph.edges(data=True): + o_g, o_h = data.get("order", (0, 0)) + delta = o_g - o_h + if ignore_aromaticity and abs(delta) < 1: + delta = 0 + graph[u][v]["standard_order"] = delta + return graph + + # ------------------------------------------------------------------ + # Helper – edge insertion logic + # ------------------------------------------------------------------ + @staticmethod + def _populate_edges( + its: nx.Graph, + G: nx.Graph, + H: nx.Graph, + ) -> None: + """Populate *its* with ``order`` tuples following the rules. 
+ + Unchanged (present in both): ``(o, o)`` + Broken (present only in G): ``(o, 0)`` *when one end survives* + Formed (present only in H): ``(0, o)`` *when one end survives* + Unchanged-external (only in one, *both* ends external): ``(o, o)`` + """ + common = set(G.nodes()) & set(H.nodes()) + seen: set[Tuple[Hashable, Hashable]] = set() + + def add(u: Hashable, v: Hashable, order: Tuple[float, float]): + if (u, v) in seen or (v, u) in seen: + return + its.add_edge(u, v, order=order) + seen.add((u, v)) + + # Pass 1 – edges from G + for u, v, d in G.edges(data=True): + o_g = d.get("order", 0) + if H.has_edge(u, v): # unchanged (core) + add(u, v, (o_g, o_g)) + else: + if (u in common) ^ (v in common): # broken + add(u, v, (o_g, 0)) + else: # unchanged non‑core (G only) + add(u, v, (o_g, o_g)) + + # Pass 2 – edges unique to H + for u, v, d in H.edges(data=True): + if G.has_edge(u, v): + continue # already handled + o_h = d.get("order", 0) + if (u in common) ^ (v in common): # formed + add(u, v, (0, o_h)) + else: # unchanged non‑core (H only) + add(u, v, (o_h, o_h)) + + @staticmethod + def balance_valences(graph: nx.Graph) -> nx.Graph: + """ + Balances valences in a NetworkX graph by adding wildcard '*' nodes for atoms + that have missing bonds according to their broken bonds and hydrogen counts. 
+ + :param graph: NetworkX Graph with node attributes: + - element: str, chemical symbol + - charge: int, formal charge + - typesGH: tuple of descriptors (element, aromatic, hcount, h_change, connections) + - atom_map: int, unique identifier (node key) + :type graph: nx.Graph + :return: Modified graph with wildcard nodes added + :rtype: nx.Graph + """ + # Copy to avoid modifying the original + G = graph.copy() + # Determine next wildcard index (integer keys only) + existing_ids = [n for n in G.nodes if isinstance(n, int)] + next_id = max(existing_ids, default=0) + 1 + + # Iterate over original nodes + for atom in list(G.nodes): + data = G.nodes[atom] + # Skip wildcards + if data.get("element") == "*": + continue + # Sum of positive standard_order values (broken bonds) + broken = sum( + d.get("standard_order", 0) + for _, _, d in G.edges(atom, data=True) + if d.get("standard_order", 0) > 0 + ) + if broken <= 0: + continue + # Available hydrogen counts from typesGH descriptors (index 2) + h_counts = [desc[2] for desc in data["typesGH"]] + # If any descriptor has hydrogen >= broken, no wildcard needed + if max(h_counts, default=0) >= broken: + continue + # Need wildcard for remaining broken bonds + wc_id = next_id + next_id += 1 + # Add wildcard node with two GH types: one providing valence and one default + G.add_node( + wc_id, + element="*", + charge=0, + typesGH=(("*", False, broken, 0, []), ("*", False, 0, 0, [])), + atom_map=wc_id, + ) + # Forming bond with wildcard: dynamic order=broken, negative standard_order + G.add_edge( + atom, wc_id, order=(0.0, float(broken)), standard_order=-float(broken) + ) + return G + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + @staticmethod + def construct( + G: nx.Graph, + H: nx.Graph, + *, + ignore_aromaticity: bool = False, + attributes_defaults: Optional[Dict[str, Any]] = None, + balance: bool = True, + ) -> 
nx.Graph: + """Return a partial ITS graph for *G* → *H*. + + :param G: reactant graph. + :param H: product graph. + :keyword ignore_aromaticity: if *True*, set ``standard_order`` to 0 when + |Δ|<1. + :keyword attributes_defaults: mapping of attribute → default value used + for the ``typesGH`` tuples. If *None*, a + small sensible default set is used. + :returns: an ITS graph with nodes, ``typesGH`` tuples and annotated + edges. + """ + # ------------------------------------------------------------------ + # Set defaults + # ------------------------------------------------------------------ + if attributes_defaults is None: + attributes_defaults = { + "element": "*", + "aromatic": False, + "hcount": 0, + "charge": 0, + "neighbors": [], + } + + # ------------------------------------------------------------------ + # Build node union + # ------------------------------------------------------------------ + its = nx.Graph() + its.add_nodes_from(G.nodes(data=True)) + its.add_nodes_from((n, d) for n, d in H.nodes(data=True) if n not in its) + + # ------------------------------------------------------------------ + # typesGH per node + # ------------------------------------------------------------------ + types: Dict[Hashable, Tuple[Tuple[Any, ...], Tuple[Any, ...]]] = {} + for n in its.nodes(): + in_g, in_h = n in G.nodes(), n in H.nodes() + attrs_g = PartialITS._get_node_attr_tuple( + G if in_g else H, n, attributes_defaults + ) + attrs_h = ( + PartialITS._get_node_attr_tuple(H, n, attributes_defaults) + if in_h + else attrs_g + ) + if not in_h: + attrs_h = attrs_g + types[n] = (attrs_g, attrs_h) + nx.set_node_attributes(its, types, "typesGH") + + # ------------------------------------------------------------------ + # Edges with order tuples + # ------------------------------------------------------------------ + PartialITS._populate_edges(its, G, H) + + # ------------------------------------------------------------------ + # Attach standard_order and return + # 
------------------------------------------------------------------ + its = PartialITS._attach_standard_order(its, ignore_aromaticity) + if balance: + its = PartialITS.balance_valences(its) + return its diff --git a/synkit/Graph/Matcher/graph_morphism.py b/synkit/Graph/Matcher/graph_morphism.py index c608acb..4a76ddf 100644 --- a/synkit/Graph/Matcher/graph_morphism.py +++ b/synkit/Graph/Matcher/graph_morphism.py @@ -1,516 +1,377 @@ -import re +import logging +import itertools +from operator import eq +from typing import Callable, Optional, Union, List, Any, Dict import networkx as nx -from typing import Optional, List - -from rdkit import Chem -from rdkit.Chem.MolStandardize import rdMolStandardize - -__all__ = ["get_rc", "its_decompose"] - - -def get_rc( - ITS: nx.Graph, - element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], - bond_key: str = "order", - standard_key: str = "standard_order", - disconnected: bool = False, -) -> nx.Graph: - """Extract the reaction-center (RC) subgraph from an ITS graph. - - This function identifies: - 1. All bonds whose standard order (difference between ITS orders) is non-zero. - 2. All H–H bonds, ensuring they are included even if no order change is detected. - 3. (Optional) Additional nodes with charge changes and reconnection of edges - if `disconnected=True`. - - :param ITS: The integrated transition-state graph with composite node/edge attributes. - :type ITS: nx.Graph - :param element_key: List of node‐attribute keys to copy into the RC graph. - :type element_key: List[str] - :param bond_key: Edge attribute key representing the tuple of bond orders. - :type bond_key: str - :param standard_key: Edge attribute key for the computed standard_order. - :type standard_key: str - :param disconnected: If True, also include nodes with charge changes and - reconnect any ITS edges between RC nodes. - :type disconnected: bool - :returns: A new graph containing only the reaction-center nodes and edges. 
- :rtype: nx.Graph - - :example: - >>> ITS = nx.Graph() - >>> # ... populate ITS with 'order', 'standard_order', 'typesGH', etc. ... - >>> RC = get_rc(ITS, disconnected=True) - >>> isinstance(RC, nx.Graph) - True +from networkx.algorithms import isomorphism +from networkx.algorithms.isomorphism import GraphMatcher +from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match + + +# Alias for any NetworkX graph type +graph_types = Union[nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph] + + +def find_graph_isomorphism( + G1: graph_types, + G2: graph_types, + node_match: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]] = None, + edge_match: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]] = None, + use_defaults: bool = True, + fast_invariant_check: bool = True, + logger: Optional[logging.Logger] = None, +) -> Optional[Dict[Any, Any]]: + """Check whether two graphs are isomorphic and return the node-mapping. + + :param G1: The first NetworkX graph to compare. + :type G1: nx.Graph | nx.DiGraph | nx.MultiGraph | nx.MultiDiGraph + :param G2: The second NetworkX graph to compare. + :type G2: nx.Graph | nx.DiGraph | nx.MultiGraph | nx.MultiDiGraph + :param node_match: Optional function taking two node attribute dicts + and returning True if they match. + :type node_match: callable or None + :param edge_match: Optional function taking two edge attribute dicts + and returning True if they match. + :type edge_match: callable or None + :param use_defaults: Whether to use default matchers when None. + :type use_defaults: bool + :param fast_invariant_check: Perform quick node/edge count and + degree sequence checks prior to matcher. + :type fast_invariant_check: bool + :param logger: Logger for debug messages. Defaults to root logger. + :type logger: logging.Logger or None + :returns: A dict mapping nodes in G1 to nodes in G2 if isomorphic; + otherwise None. 
+ :rtype: dict[Any, Any] or None """ - rc = nx.Graph() - _add_bond_order_changes(ITS, rc, element_key, bond_key, standard_key) - - # 1.5) H-H bonds (force inclusion, with fallback typesGH) - for u, v, data in ITS.edges(data=True): - elem_u = ITS.nodes[u].get("element") - elem_v = ITS.nodes[v].get("element") - if elem_u == "H" and elem_v == "H": - for n in (u, v): - node_data = dict(ITS.nodes[n]) - if "typesGH" not in node_data: - node_data["typesGH"] = ( - ("H", False, 0, 0, []), - ("*", False, 0, 0, []), - ) - # Ensure typesGH is available even if not in original element_key - final_attrs = {k: node_data[k] for k in element_key if k in node_data} - final_attrs["typesGH"] = node_data["typesGH"] - rc.add_node(n, **final_attrs) - - rc.add_edge( - u, - v, - **{ - bond_key: data.get(bond_key), - standard_key: data.get(standard_key), - }, + log = logger or logging.getLogger(__name__) + + # 1) Ensure same graph type + if type(G1) is not type(G2): + log.debug("Graph types differ: %r vs %r", type(G1), type(G2)) + return None + + # 2) Quick invariants + if fast_invariant_check: + if G1.number_of_nodes() != G2.number_of_nodes(): + log.debug( + "Node counts differ: %d vs %d", + G1.number_of_nodes(), + G2.number_of_nodes(), ) - if disconnected: - _add_charge_change_nodes(ITS, rc, element_key) - _reconnect_rc_edges(ITS, rc, bond_key, standard_key) - - return rc - - -def _carry_node_attrs(src: nx.Graph, dst: nx.Graph, n: int, keys: List[str]) -> None: - """Copy node *n* from *src* to *dst* with only *keys* attributes.""" - if dst.has_node(n): - return - attrs = {k: src.nodes[n][k] for k in keys if k in src.nodes[n]} - dst.add_node(n, **attrs) - - -def _add_charge_change_nodes( - ITS: nx.Graph, - rc: nx.Graph, - keys: List[str], -) -> None: - """Step 3a – add nodes whose *typesGH* shows a charge change.""" - for n, data in ITS.nodes(data=True): - gh = data.get("typesGH") - if ( - isinstance(gh, (list, tuple)) - and len(gh) >= 2 - and gh[0][3] != gh[1][3] - and not rc.has_node(n) 
- ): - _carry_node_attrs(ITS, rc, n, keys) - - -def _reconnect_rc_edges( - ITS: nx.Graph, - rc: nx.Graph, - bond_key: str, - standard_key: str, -) -> None: - """Step 3b – re-add any original ITS edge between nodes already in RC.""" - for u, v, data in ITS.edges(data=True): - if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): - rc.add_edge( - u, - v, - **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, + return None + if G1.number_of_edges() != G2.number_of_edges(): + log.debug( + "Edge counts differ: %d vs %d", + G1.number_of_edges(), + G2.number_of_edges(), ) + return None + degs1 = sorted(d for _, d in G1.degree()) + degs2 = sorted(d for _, d in G2.degree()) + if degs1 != degs2: + log.debug("Degree sequences differ") + return None + + # 3) Default matchers + if use_defaults: + if node_match is None: + node_match = isomorphism.categorical_node_match( + ["element", "atom_map", "hcount"], ["*", 0, 0] + ) + if edge_match is None: + edge_match = isomorphism.categorical_edge_match("order", 1) + # 4) Select the correct matcher + if isinstance(G1, (nx.MultiGraph, nx.MultiDiGraph)): + if isinstance(G1, nx.MultiGraph): + Matcher = nx.algorithms.isomorphism.MultiGraphMatcher + else: + Matcher = nx.algorithms.isomorphism.MultiDiGraphMatcher + else: + if isinstance(G1, nx.Graph): + Matcher = nx.algorithms.isomorphism.GraphMatcher + else: + Matcher = nx.algorithms.isomorphism.DiGraphMatcher + + matcher = Matcher(G1, G2, node_match=node_match, edge_match=edge_match) + if matcher.is_isomorphic(): + log.debug("Graphs are isomorphic; mapping found") + return matcher.mapping + else: + log.debug("Graphs are not isomorphic") + return None + + +def graph_isomorphism( + graph_1: nx.Graph, + graph_2: nx.Graph, + node_match: Optional[Callable] = None, + edge_match: Optional[Callable] = None, + use_defaults: bool = False, +) -> bool: + """Determines if two graphs are isomorphic, considering provided node and + edge matching functions. 
Uses default matching settings if none are + provided. -def _add_bond_order_changes( - ITS: nx.Graph, - rc: nx.Graph, - keys: List[str], - bond_key: str, - standard_key: str, -) -> None: - """Step 1 – bond-order-change edges and their nodes.""" - for u, v, data in ITS.edges(data=True): - old, new = data.get(bond_key, (None, None)) - if old == new: - continue - for n in (u, v): - _carry_node_attrs(ITS, rc, n, keys) - rc.add_edge( - u, v, **{bond_key: data[bond_key], standard_key: data.get(standard_key)} - ) - + Parameters: + - graph_1 (nx.Graph): The first graph to compare. + - graph_2 (nx.Graph): The second graph to compare. + - node_match (Optional[Callable]): The function used to match nodes. + Uses default if None. + - edge_match (Optional[Callable]): The function used to match edges. + Uses default if None. -# def get_rc( -# ITS: nx.Graph, -# element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], -# bond_key: str = "order", -# standard_key: str = "standard_order", -# disconnected: bool = False, -# ) -> nx.Graph: -# """ -# Extract the reaction center (RC) from ITS graph. - -# Enhancements: -# - Adds nodes and edges where bond order changes (core logic). -# - If disconnected=True: -# - Adds nodes with charge change based on typesGH. -# - Reconnects any ITS edge between two RC nodes. -# - NEW: Always includes H-H bonds in RC. Adds default typesGH if missing. 
-# """ -# rc = nx.Graph() - -# # 1) edges with bond-order change -# for u, v, data in ITS.edges(data=True): -# old, new = data.get(bond_key, [None, None]) -# if old != new: -# for n in (u, v): -# if not rc.has_node(n): -# rc.add_node( -# n, -# **{ -# k: ITS.nodes[n][k] for k in element_key if k in ITS.nodes[n] -# }, -# ) -# rc.add_edge( -# u, -# v, -# **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, -# ) - -# # 1.5) H-H bonds (force inclusion, with fallback typesGH) -# for u, v, data in ITS.edges(data=True): -# elem_u = ITS.nodes[u].get("element") -# elem_v = ITS.nodes[v].get("element") -# if elem_u == "H" and elem_v == "H": -# for n in (u, v): -# node_data = dict(ITS.nodes[n]) -# if "typesGH" not in node_data: -# node_data["typesGH"] = ( -# ("H", False, 0, 0, []), -# ("*", False, 0, 0, []), -# ) -# # Ensure typesGH is available even if not in original element_key -# final_attrs = {k: node_data[k] for k in element_key if k in node_data} -# final_attrs["typesGH"] = node_data["typesGH"] -# rc.add_node(n, **final_attrs) - -# rc.add_edge( -# u, -# v, -# **{ -# bond_key: data.get(bond_key), -# standard_key: data.get(standard_key), -# }, -# ) - -# if disconnected: -# # 2) nodes with typesGH-based charge change -# for n, data in ITS.nodes(data=True): -# gh = data.get("typesGH") -# if ( -# isinstance(gh, (list, tuple)) -# and len(gh) >= 2 -# and len(gh[0]) > 3 -# and len(gh[1]) > 3 -# and gh[0][3] != gh[1][3] -# ): -# if not rc.has_node(n): -# rc.add_node(n, **{k: data[k] for k in element_key if k in data}) - -# # 3) reconnect RC nodes -# for u, v, data in ITS.edges(data=True): -# if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): -# rc.add_edge( -# u, -# v, -# **{ -# bond_key: data.get(bond_key), -# standard_key: data.get(standard_key), -# }, -# ) - -# return rc - - -# def get_rc( -# ITS: nx.Graph, -# element_key: List[str] = ["element", "charge", "typesGH", "atom_map"], -# bond_key: str = "order", -# standard_key: str = 
"standard_order", -# disconnected: bool = False, -# ) -> nx.Graph: -# """ -# Extract the reaction center (RC) from ITS by: - -# 1. Always adding any edge whose bond order changes -# (bond_key[0] != bond_key[1]), plus its two end-nodes. -# 2. [if disconnected=True] Adding any node whose 'typesGH' record shows a charge change -# (typesGH[0][3] != typesGH[1][3]), even if isolated. -# 3. [if disconnected=True] Re-adding any ITS edge between two nodes already in RC -# (to preserve connectivity), carrying over bond_key & standard_key. - -# Parameters: -# - ITS (nx.Graph): input ITS graph. -# - element_key (List[str]): node attrs to carry over. -# - bond_key (str): edge attr key for bond order. -# - standard_key (str): edge attr key for standard order. -# - disconnected (bool): if True, include “charge-change” nodes (step 2) and -# reconnect any edges among RC nodes (step 3). If False, only performs step 1. -# """ -# rc = nx.Graph() - -# # 1) edges with bond-order change -# for u, v, data in ITS.edges(data=True): -# old, new = data.get(bond_key, [None, None]) -# if old != new: -# for n in (u, v): -# if not rc.has_node(n): -# rc.add_node( -# n, -# **{ -# k: ITS.nodes[n][k] for k in element_key if k in ITS.nodes[n] -# }, -# ) -# rc.add_edge( -# u, -# v, -# **{bond_key: data.get(bond_key), standard_key: data.get(standard_key)}, -# ) - -# if disconnected: -# # 2) nodes with a typesGH-based charge change -# for n, data in ITS.nodes(data=True): -# gh = data.get("typesGH") -# if ( -# isinstance(gh, (list, tuple)) -# and len(gh) >= 2 -# and len(gh[0]) > 3 -# and len(gh[1]) > 3 -# and gh[0][3] != gh[1][3] -# ): -# if not rc.has_node(n): -# rc.add_node(n, **{k: data[k] for k in element_key if k in data}) - -# # 3) re-add any ITS edge between RC nodes to preserve connectivity -# for u, v, data in ITS.edges(data=True): -# if rc.has_node(u) and rc.has_node(v) and not rc.has_edge(u, v): -# rc.add_edge( -# u, -# v, -# **{ -# bond_key: data.get(bond_key), -# standard_key: 
data.get(standard_key), -# }, -# ) - -# return rc - - -def its_decompose(its_graph: nx.Graph, nodes_share="typesGH", edges_share="order"): - """Decompose an ITS graph into two separate reactant (G) and product (H) - graphs. - - Nodes and edges in `its_graph` carry composite attributes: - - Each node has `its_graph.nodes[nodes_share] = (node_attrs_G, node_attrs_H)`. - - Each edge has `its_graph.edges[edges_share] = (order_G, order_H)`. - - This function splits those tuples to reconstruct the original G and H graphs. - - :param its_graph: The ITS graph with composite node/edge attributes. - :type its_graph: nx.Graph - :param nodes_share: Node attribute key storing (G_attrs, H_attrs) tuples. - :type nodes_share: str - :param edges_share: Edge attribute key storing (order_G, order_H) tuples. - :type edges_share: str - :returns: A tuple of two graphs (G, H) reconstructed from the ITS. - :rtype: Tuple[nx.Graph, nx.Graph] - - :example: - >>> its = nx.Graph() - >>> # ... set its.nodes[n]['typesGH'] and its.edges[e]['order'] ... - >>> G, H = its_decompose(its) - >>> isinstance(G, nx.Graph) and isinstance(H, nx.Graph) - True + Returns: + - bool: True if the graphs are isomorphic, False otherwise. 
""" - G = nx.Graph() - H = nx.Graph() - - # Decompose nodes - for node, data in its_graph.nodes(data=True): - if nodes_share in data: - node_attr_g, node_attr_h = data[nodes_share] - # Unpack node attributes for G - G.add_node( - node, - element=node_attr_g[0], - aromatic=node_attr_g[1], - hcount=node_attr_g[2], - charge=node_attr_g[3], - neighbors=node_attr_g[4], - atom_map=node, + # Define default node and edge attributes and match settings + if use_defaults: + node_label_names = ["element", "charge"] + node_label_default = ["*", 0] + edge_attribute = "order" + + # Default node and edge match functions if not provided + if node_match is None: + node_match = generic_node_match( + node_label_names, node_label_default, [eq] * len(node_label_names) ) - if len(node_attr_h) > 0: - # Unpack node attributes for H - H.add_node( - node, - element=node_attr_h[0], - aromatic=node_attr_h[1], - hcount=node_attr_h[2], - charge=node_attr_h[3], - neighbors=node_attr_h[4], - atom_map=node, - ) - - # Decompose edges - for u, v, data in its_graph.edges(data=True): - if edges_share in data: - order_g, order_h = data[edges_share] - if order_g > 0: # Assuming 0 means no edge in G - G.add_edge(u, v, order=order_g) - if order_h > 0: # Assuming 0 means no edge in H - H.add_edge(u, v, order=order_h) - - return G, H - - -def compare_graphs( - graph1: nx.Graph, - graph2: nx.Graph, - node_attrs: list = ["element", "aromatic", "hcount", "charge", "neighbors"], - edge_attrs: list = ["order"], + if edge_match is None: + edge_match = generic_edge_match(edge_attribute, 1, eq) + + # Perform the isomorphism check using NetworkX + return nx.is_isomorphic( + graph_1, graph_2, node_match=node_match, edge_match=edge_match + ) + + +def subgraph_isomorphism( + child_graph: nx.Graph, + parent_graph: nx.Graph, + node_label_names: List[str] = ["element", "charge"], + node_label_default: List[Any] = ["*", 0], + edge_attribute: str = "order", + use_filter: bool = False, + check_type: str = "induced", # 
"induced" or "monomorphism" + node_comparator: Optional[Callable[[Any, Any], bool]] = None, + edge_comparator: Optional[Callable[[Any, Any], bool]] = None, ) -> bool: - """Compare two graphs based on specified node and edge attributes. + """Enhanced checks if the child graph is a subgraph isomorphic to the + parent graph based on customizable node and edge attributes. Parameters: - - graph1 (nx.Graph): The first graph to compare. - - graph2 (nx.Graph): The second graph to compare. - - node_attrs (list): A list of node attribute names to include in the comparison. - - edge_attrs (list): A list of edge attribute names to include in the comparison. + - child_graph (nx.Graph): The child graph. + - parent_graph (nx.Graph): The parent graph. + - node_label_names (List[str]): Labels to compare. + - node_label_default (List[Any]): Defaults for missing node labels. + - edge_attribute (str): The edge attribute to compare. + - use_filter (bool): Whether to use pre-filters based on node and edge count. + - check_type (str): "induced" (default) or "monomorphism" for the type of subgraph matching. + - node_comparator (Callable[[Any, Any], bool]): Custom comparator for node attributes. + - edge_comparator (Callable[[Any, Any], bool]): Custom comparator for edge attributes. Returns: - - bool: True if both graphs are identical with respect to the specified attributes, - otherwise False. + - bool: True if subgraph isomorphism is found, False otherwise. 
""" - # Compare node sets - if set(graph1.nodes()) != set(graph2.nodes()): - return False - - # Compare nodes based on attributes - for node in graph1.nodes(): - if node not in graph2: - return False - node_data1 = {attr: graph1.nodes[node].get(attr, None) for attr in node_attrs} - node_data2 = {attr: graph2.nodes[node].get(attr, None) for attr in node_attrs} - if node_data1 != node_data2: + if use_filter: + # Initial quick filters based on node and edge counts + if len(child_graph) > len(parent_graph) or len(child_graph.edges) > len( + parent_graph.edges + ): return False - # Compare edge sets with sorted tuples - if set(tuple(sorted(edge)) for edge in graph1.edges()) != set( - tuple(sorted(edge)) for edge in graph2.edges() - ): - return False - - # Compare edges based on attributes - for edge in graph1.edges(): - # Sort the edge for consistent comparison - sorted_edge = tuple(sorted(edge)) - if sorted_edge not in graph2.edges(): - return False - edge_data1 = {attr: graph1.edges[edge].get(attr, None) for attr in edge_attrs} - edge_data2 = { - attr: graph2.edges[sorted_edge].get(attr, None) for attr in edge_attrs - } - if edge_data1 != edge_data2: - return False + # Step 2: Node label filter - Only consider 'element' and 'charge' attributes + for _, child_data in child_graph.nodes(data=True): + found_match = False + for _, parent_data in parent_graph.nodes(data=True): + match = True + # Compare only the 'element' and 'charge' attributes + for label, default in zip(node_label_names, node_label_default): + child_value = child_data.get(label, default) + parent_value = parent_data.get(label, default) + if child_value != parent_value: + match = False + break + if match: + found_match = True + break + if not found_match: + return False + + # Step 3: Edge label filter - Ensure that the edge attribute 'order' matches if provided + if edge_attribute: + for child_edge in child_graph.edges(data=True): + child_node1, child_node2, child_data = child_edge + if child_node1 in 
parent_graph and child_node2 in parent_graph: + # Ensure the edge exists in the parent graph + if not parent_graph.has_edge(child_node1, child_node2): + return False + # Check if the 'order' attribute matches + parent_edge_data = parent_graph[child_node1][child_node2] + child_order = child_data.get(edge_attribute) + parent_order = parent_edge_data.get(edge_attribute) + + # Handle comparison of tuple values for 'order' attribute + if isinstance(child_order, tuple) and isinstance( + parent_order, tuple + ): + if child_order != parent_order: + return False + elif child_order != parent_order: + return False + else: + return False + + # Setting up attribute comparison functions + node_comparator = node_comparator if node_comparator else eq + edge_comparator = edge_comparator if edge_comparator else eq + + # Creating match conditions for nodes and edges based on custom or default comparators + node_match = generic_node_match( + node_label_names, node_label_default, [node_comparator] * len(node_label_names) + ) + edge_match = ( + generic_edge_match(edge_attribute, None, edge_comparator) + if edge_attribute + else None + ) - return True + # Graph matching setup + matcher = GraphMatcher( + parent_graph, child_graph, node_match=node_match, edge_match=edge_match + ) + + # Executing the matching based on specified type + if check_type == "induced": + return matcher.subgraph_is_isomorphic() + else: + return matcher.subgraph_is_monomorphic() -def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: - """Enumerates possible tautomers for reactants while canonicalizing the - products in a reaction SMILES string. This function first splits the - reaction SMILES string into reactants and products. It then generates all - possible tautomers for the reactants and canonicalizes the product - molecule. The function returns a list of reaction SMILES strings for each - tautomer of the reactants combined with the canonical product. 
+def maximum_connected_common_subgraph( + graph_1: nx.Graph, + graph_2: nx.Graph, + node_label_names: List[str] = ["element", "charge"], + node_label_default: List[Any] = ["*", 0], + edge_attribute: str = "standard_order", +) -> nx.Graph: + """Computes the largest connected common subgraph (MCS) between two graphs + using subgraph isomorphism based on customizable node and edge attributes. + + The function iterates over subsets of nodes from the smaller graph—starting from the largest + possible subgraph size down to 1—and returns the first (largest) candidate that is connected + and is isomorphic to a subgraph of the larger graph. Parameters: - - reaction_smiles (str): A SMILES string of the reaction formatted as - 'reactants>>products'. + - graph_1 (nx.Graph): The first graph for comparison. + - graph_2 (nx.Graph): The second graph for comparison. + - node_label_names (List[str]): List of node attribute names used for matching. + - node_label_default (List[Any]): Default values for missing node attributes. + - edge_attribute (str): The edge attribute to compare. Returns: - - List[str] | None: A list of SMILES strings for the reaction, with each string - representing a different - - tautomer of the reactants combined with the canonicalized products. Returns None if - an error occurs or if invalid SMILES strings are provided. - - Raises: - - ValueError: If the provided SMILES strings cannot be converted to molecule objects, - indicating invalid input. + - nx.Graph: A graph representing the largest connected common subgraph found; if none exists, + returns an empty graph. 
""" - try: - # Split the input reaction SMILES string into reactants and products - reactants_smiles, products_smiles = reaction_smiles.split(">>") - - # Convert SMILES strings to molecule objects - reactants_mol = Chem.MolFromSmiles(reactants_smiles) - products_mol = Chem.MolFromSmiles(products_smiles) - - if reactants_mol is None or products_mol is None: - raise ValueError( - "Invalid SMILES string provided for reactants or products." + node_match = generic_node_match( + node_label_names, node_label_default, [eq] * len(node_label_names) + ) + edge_match = generic_edge_match(edge_attribute, 1, eq) + + # Determine which graph is smaller for efficiency. + if graph_1.number_of_nodes() <= graph_2.number_of_nodes(): + smaller_graph, larger_graph = graph_1, graph_2 + else: + smaller_graph, larger_graph = graph_2, graph_1 + + num_nodes_smaller = smaller_graph.number_of_nodes() + # Iterate over possible subgraph sizes from the largest to 1. + for subgraph_size in range(num_nodes_smaller, 0, -1): + for nodes_subset in itertools.combinations( + smaller_graph.nodes(), subgraph_size + ): + candidate_subgraph = smaller_graph.subgraph(nodes_subset) + # If the subgraph has more than one node, check it is connected. + if candidate_subgraph.number_of_nodes() > 1 and not nx.is_connected( + candidate_subgraph + ): + continue + + # Check for subgraph isomorphism in the larger graph. 
+ matcher = GraphMatcher( + larger_graph, + candidate_subgraph, + node_match=node_match, + edge_match=edge_match, ) + if matcher.subgraph_is_isomorphic(): + return candidate_subgraph.copy() - # Initialize tautomer enumerator - - enumerator = rdMolStandardize.TautomerEnumerator() + return nx.Graph() - # Enumerate tautomers for the reactants and canonicalize the products - try: - reactants_can = enumerator.Enumerate(reactants_mol) - except Exception as e: - print(f"An error occurred: {e}") - reactants_can = [reactants_mol] - products_can = products_mol - - # Convert molecule objects back to SMILES strings - reactants_can_smiles = [Chem.MolToSmiles(i) for i in reactants_can] - products_can_smiles = Chem.MolToSmiles(products_can) - - # Combine each reactant tautomer with the canonical product in SMILES format - rsmi_list = [i + ">>" + products_can_smiles for i in reactants_can_smiles] - if len(rsmi_list) == 0: - return [reaction_smiles] - else: - # rsmi_list.remove(reaction_smiles) - rsmi_list.insert(0, reaction_smiles) - return rsmi_list - - except Exception as e: - print(f"An error occurred: {e}") - return [reaction_smiles] +def heuristics_MCCS( + graphs: List[nx.Graph], + node_label_names: List[str] = ["element", "charge"], + node_label_default: List[Any] = ["*", 0], + edge_attribute: str = "standard_order", +) -> nx.Graph: + """Computes the Maximum Connected Common Subgraph (MCCS) over a list of + graphs using a heuristic approach. -def mapping_success_rate(list_mapping_data): - """Calculate the success rate of entries containing atom mappings in a list - of data strings. + This function computes the MCCS between the first two graphs using the + `maximum_connected_common_subgraph` function based on customizable node and edge attributes. + For more than two graphs, it iteratively updates the common subgraph by calculating the MCCS + between the current common subgraph and each subsequent graph. 
An early exit occurs if the + intermediate common subgraph becomes empty. Parameters: - - list_mapping_in_data (list of str): List containing strings to be searched for atom - mappings. + - graphs (List[nx.Graph]): A list of networkx graphs for which the common subgraph is to be computed. + - node_label_names (List[str]): List of node attribute names used for matching. + - node_label_default (List[Any]): Default values for missing node attributes. + - edge_attribute (str): The edge attribute to compare. Returns: - - float: The success rate of finding atom mappings in the list as a percentage. + - nx.Graph: The maximum connected common subgraph common to all provided graphs. If no common + subgraph exists, an empty graph is returned. Raises: - - ValueError: If the input list is empty. + - ValueError: If the input list of graphs is empty. """ - atom_map_pattern = re.compile(r":\d+") - if not list_mapping_data: - raise ValueError("The input list is empty, cannot calculate success rate.") + if not graphs: + raise ValueError("Input list of graphs is empty.") + + if len(graphs) == 1: + return graphs[0].copy() + + # Handle the two-graph case explicitly. + if len(graphs) == 2: + return maximum_connected_common_subgraph( + graphs[0], + graphs[1], + node_label_names=node_label_names, + node_label_default=node_label_default, + edge_attribute=edge_attribute, + ) - success = sum( - 1 for entry in list_mapping_data if re.search(atom_map_pattern, entry) + # Iteratively compute the MCCS for more than two graphs. + current_mcs = maximum_connected_common_subgraph( + graphs[0], + graphs[1], + node_label_names=node_label_names, + node_label_default=node_label_default, + edge_attribute=edge_attribute, ) - rate = 100 * (success / len(list_mapping_data)) - return round(rate, 2) \ No newline at end of file + for graph in graphs[2:]: + if current_mcs.number_of_nodes() == 0: + break # Early exit if no common subgraph remains. 
+ current_mcs = maximum_connected_common_subgraph( + current_mcs, + graph, + node_label_names=node_label_names, + node_label_default=node_label_default, + edge_attribute=edge_attribute, + ) + + return current_mcs diff --git a/synkit/IO/graph_to_mol.py b/synkit/IO/graph_to_mol.py index 8049f0d..ddcf048 100644 --- a/synkit/IO/graph_to_mol.py +++ b/synkit/IO/graph_to_mol.py @@ -74,7 +74,7 @@ def graph_to_mol( node_to_idx: Dict[int, int] = {} for node, data in graph.nodes(data=True): - element = data.get(self.node_attributes["element"], "C") + element = data.get(self.node_attributes["element"], "*") charge = data.get(self.node_attributes["charge"], 0) atom_map = ( data.get(self.node_attributes["atom_map"], 0) @@ -93,7 +93,7 @@ def graph_to_mol( atom.SetAtomMapNum(atom_map) if hcount is not None: atom.SetNoImplicit(True) - atom.SetNumExplicitHs(hcount) + atom.SetNumExplicitHs(int(hcount)) idx = mol.AddAtom(atom) node_to_idx[node] = idx From 691f9c08b6a3bfc706f099e774491f31f39feab5 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Tue, 22 Jul 2025 11:27:50 +0200 Subject: [PATCH 3/5] pass test staging, prepare release --- .github/workflows/test-and-lint.yml | 2 +- doc/conf.py | 2 +- synkit/Graph/ITS/its_builder.py | 2 +- synkit/Graph/ITS/its_construction.py | 2 +- synkit/Graph/ITS/its_decompose.py | 2 +- synkit/Graph/ITS/its_expand.py | 2 +- synkit/Synthesis/Reactor/syn_reactor.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 8222b16..f62fa21 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -2,7 +2,7 @@ name: Test & Lint on: push: - branches: [ "main", "dev", "maintain", "refractor" ] + branches: [ "main", "dev", "staging", "refractor" ] pull_request: branches: [ "main" ] diff --git a/doc/conf.py b/doc/conf.py index 509d687..16d2f98 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -45,4 +45,4 @@ # -- Options for HTML output 
------------------------------------------------- html_theme = "sphinx_rtd_theme" -html_static_path = ["_static"] \ No newline at end of file +html_static_path = ["_static"] diff --git a/synkit/Graph/ITS/its_builder.py b/synkit/Graph/ITS/its_builder.py index 2ab5591..e6999b2 100644 --- a/synkit/Graph/ITS/its_builder.py +++ b/synkit/Graph/ITS/its_builder.py @@ -111,4 +111,4 @@ def ITSGraph(G: nx.Graph, RC: nx.Graph) -> nx.Graph: # 7) Renumber atom_map to node indices ITSBuilder.update_atom_map(its) - return its \ No newline at end of file + return its diff --git a/synkit/Graph/ITS/its_construction.py b/synkit/Graph/ITS/its_construction.py index bfa51f7..c5bd596 100644 --- a/synkit/Graph/ITS/its_construction.py +++ b/synkit/Graph/ITS/its_construction.py @@ -322,4 +322,4 @@ def typesGH(self) -> Dict[str, Dict[str, Tuple[Any, Any]]]: } node_defaults = {k: (tp, 0) for k, tp in sel_nodes.items()} edge_defaults = {k: (tp, 0) for k, tp in sel_edges.items()} - return {"node": node_defaults, "edge": edge_defaults} \ No newline at end of file + return {"node": node_defaults, "edge": edge_defaults} diff --git a/synkit/Graph/ITS/its_decompose.py b/synkit/Graph/ITS/its_decompose.py index c608acb..df9a1fd 100644 --- a/synkit/Graph/ITS/its_decompose.py +++ b/synkit/Graph/ITS/its_decompose.py @@ -513,4 +513,4 @@ def mapping_success_rate(list_mapping_data): ) rate = 100 * (success / len(list_mapping_data)) - return round(rate, 2) \ No newline at end of file + return round(rate, 2) diff --git a/synkit/Graph/ITS/its_expand.py b/synkit/Graph/ITS/its_expand.py index bb3cf2f..0a23438 100644 --- a/synkit/Graph/ITS/its_expand.py +++ b/synkit/Graph/ITS/its_expand.py @@ -83,4 +83,4 @@ def expand_aam_with_its( # Convert graphs back to RSMI and standardize atom mappings expanded_rsmi = graph_to_rsmi(new_react, new_prod, its_graph, True, False) - return std.fit(expanded_rsmi, remove_aam=False) \ No newline at end of file + return std.fit(expanded_rsmi, remove_aam=False) diff --git 
a/synkit/Synthesis/Reactor/syn_reactor.py b/synkit/Synthesis/Reactor/syn_reactor.py index fb00feb..7ba82f8 100644 --- a/synkit/Synthesis/Reactor/syn_reactor.py +++ b/synkit/Synthesis/Reactor/syn_reactor.py @@ -572,4 +572,4 @@ def _to_smarts(its: nx.Graph) -> str: p_smi = graph_to_smi(right) if r_smi is None or p_smi is None: return None - return f"{r_smi}>>{p_smi}" \ No newline at end of file + return f"{r_smi}>>{p_smi}" From 72461ceee717749b9747bc8d8d33750031982f5a Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Fri, 25 Jul 2025 15:20:30 +0200 Subject: [PATCH 4/5] quick fix hydrogen --- synkit/Graph/Hyrogen/hcomplete.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synkit/Graph/Hyrogen/hcomplete.py b/synkit/Graph/Hyrogen/hcomplete.py index 32e6c1b..c85216d 100644 --- a/synkit/Graph/Hyrogen/hcomplete.py +++ b/synkit/Graph/Hyrogen/hcomplete.py @@ -350,4 +350,5 @@ def add_hydrogen_nodes_multiple_utils( # conjugated=False, # in_ring=False, ) + new_graph.nodes[node_id]["hcount"] -= 1 return new_graph From 54ddf1cb41f519e95aa68db6dee8e71bd0b495bf Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Fri, 25 Jul 2025 15:21:17 +0200 Subject: [PATCH 5/5] prepare release --- pyproject.toml | 2 +- recipe/meta.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38f4bfe..657e6cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synkit" -version = "0.0.12" +version = "0.0.13" license = { text = "MIT" } license-files = ["LICENSE"] authors = [ diff --git a/recipe/meta.yaml b/recipe/meta.yaml index 5b6d4fb..4705fba 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -1,6 +1,6 @@ package: name: synkit - version: 0.0.12 + version: 0.0.13 source: path: ..