diff --git a/.github/workflows/lint_and_test.yaml b/.github/workflows/lint_and_test.yaml index 4fa9adc3..bd52a90c 100644 --- a/.github/workflows/lint_and_test.yaml +++ b/.github/workflows/lint_and_test.yaml @@ -55,10 +55,9 @@ jobs: test_digs: name: pytest (jojo) - runs-on: [jojo] + runs-on: [self-hosted] timeout-minutes: 30 needs: lint - if: github.event_name == 'workflow_dispatch' steps: - uses: actions/checkout@v4 - name: Run tests diff --git a/.github/workflows/release_and_docs.yaml b/.github/workflows/release_and_docs.yaml index ceccfec2..1aaeba3c 100644 --- a/.github/workflows/release_and_docs.yaml +++ b/.github/workflows/release_and_docs.yaml @@ -4,6 +4,7 @@ on: push: branches: - production + - doc_release jobs: release_and_docs: @@ -156,4 +157,4 @@ jobs: keep_files: true # Keep existing versions force_orphan: false # Don't force orphan, preserve history user_name: 'github-actions[bot]' - user_email: 'github-actions[bot]@users.noreply.github.com' \ No newline at end of file + user_email: 'github-actions[bot]@users.noreply.github.com' diff --git a/README.md b/README.md index 49824bfe..2fb26f7f 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,22 @@ [![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://baker-laboratory.github.io/atomworks-dev/latest/index.html) [![License: BSD 3-Clause](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) -atomworks logo +
+ atomworks logo +
-**atomworks** is an open-source platform that maximizes research velocity for biomolecular modeling tasks. Much like how [Torchdata](https://docs.pytorch.org/data/beta/index.html) enables rapid prototyping within the vision and language domains, AtomWorks aims to accelerate development and experimentation within biomolecular modeling.
+**atomworks** is an open-source platform that maximizes research velocity for biomolecular modeling tasks. Much like how [Torchvision](https://docs.pytorch.org/vision/stable/index.html) enables rapid prototyping within the vision domain, and [Torchaudio](https://docs.pytorch.org/audio/main/) within the audio domain, AtomWorks aims to accelerate development and experimentation within biomolecular modeling.

-> **⚠️ Notice:** We are currently finalizing some cleanup work within our repositories. Please expect the APIs (e.g., function and class names, inputs and outputs) to stabilize within the next two weeks. Thank you for your patience!
+> **⚠️ Notice:** We are currently finalizing some cleanup work within our repositories. Please expect the APIs (e.g., function and class names, inputs and outputs) to stabilize within the next week. Thank you for your patience!

If you're looking for the models themselves (e.g., RF3, MPNN) that integrate with AtomWorks rather than the underlying framework, check out [ModelForge](https://github.com/RosettaCommons/modelforge).

+> **💡 Note:** Not sure where to start? We've made some [examples in the AtomWorks documentation](https://baker-laboratory.github.io/atomworks-dev/latest/auto_examples/index.html) that work through several helpful scenarios; a full tutorial is under construction!
+
AtomWorks is composed of two symbiotic libraries:

-- **atomworks.io:** A universal Python toolkit for parsing, cleaning, manipulating, and converting biological data (structures, sequences, small molecules). Built on the [biotite](https://www.biotite-python.org/) API, it seamlessly loads and exports between standard formats like mmCIF, PDB, FASTA, SMILES, MOL, and more.
-- **atomworks.ml:** Advanced dataset featurization and sampling for deep learning workflows that uses `atomworks.io` as its structural backbone. We provide a comprehensive, pre-built and well-tested set of `Transforms` for common tasks that can be easily composed into full deep-learning pipelines; users may also create their own `Transforms` for custom operations.
+- `atomworks.io`: A universal Python toolkit for parsing, cleaning, manipulating, and converting biological data (structures, sequences, small molecules). Built on the [biotite](https://www.biotite-python.org/) API, it seamlessly loads and exports between standard formats like mmCIF, PDB, FASTA, SMILES, MOL, and more. Broadly useful for anyone who works with structural data for biomolecules.
+- `atomworks.ml`: Advanced dataset featurization and sampling for deep learning workflows that uses `atomworks.io` as its structural backbone. We provide a comprehensive, pre-built and well-tested set of `Transforms` for common tasks that can be easily composed into full deep-learning pipelines; users may also create their own `Transforms` for custom operations.

For more detail on the motivation for and applications of AtomWorks, please see the [preprint](https://doi.org/10.1101/2025.08.14.670328).
@@ -25,7 +29,7 @@ AtomWorks is built atop [biotite](https://www.biotite-python.org/): We are grate

## atomworks.io

-> *A general-purpose Python toolkit for cleaning up, standardizing, and working with biomolecular files - based on biotite*
+> *A general-purpose Python toolkit for cleaning, standardizing, and manipulating biomolecular structure files - built atop [biotite](https://www.biotite-python.org/)*

**atomworks.io** lets you:

@@ -33,7 +37,7 @@
- Transform all data to a consistent `AtomArray` representation for further analysis or machine learning applications, regardless of initial source
- Model missing atoms (those implied by the sequence but not represented in the coordinates) and initialize entity- and instance-level annotations (see the [glossary]() for more detail on our composable naming conventions)

-We have found `atomworks.io` to be useful to a general bioinformatics and protein design audience; in many cases, `atomworks.io` can replace bespoke scripts and manual curation, enabling researchers to spend more time testing hypothesis and less time juggling dozens of tools and dependencies.
+We have found `atomworks.io` to be generally useful to a broad bioinformatics and protein design audience; in many cases, `atomworks.io` can replace bespoke scripts and manual curation, enabling researchers to spend more time testing hypotheses and less time juggling dozens of tools and dependencies.

---

@@ -45,11 +49,12 @@ We have found `atomworks.io` to be useful to a general bioinformatics and protei

- A library of pre-built, well-tested `Transforms` that can be slotted into novel pipelines
- An extensible framework, integrated with `atomworks.io`, to write `Transforms` for arbitrary use cases
-- Scripts to pre-process the PDB or other databases into dataframes appropriate for network training
-- Efficient sampling and batching utilities for training machine learning models
+- Pre-built datasets and samplers suitable for most model training scenarios

Within the AtomWorks paradigm, the output of each `Transform` is not an opaque dictionary with model-specific tensors but instead an updated version of our atom-level structural representation (Biotite's `AtomArray`). Operations within – and between – pipelines thus maintain a common vocabulary of inputs and outputs.

+We have found that `atomworks.ml` **dramatically** reduces the overhead of starting and completing many ML projects; research topics that once took months now achieve signal within weeks if not days, accelerating the pace of innovation.
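To make the paradigm concrete, here is a minimal sketch. The path and crop size are placeholders, and we assume a composed pipeline can be invoked directly on the data dictionary, as the example datasets in this PR do via their `transform` argument; all imports below appear in the PR's `dataset_exploration` example:

```python
# A minimal, hypothetical sketch of the AtomWorks paradigm: each Transform
# consumes and produces the same AtomArray-centric dictionary.
from atomworks.io import parse
from atomworks.constants import STANDARD_AA
from atomworks.ml.transforms.base import Compose
from atomworks.ml.transforms.atom_array import AddGlobalAtomIdAnnotation
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.crop import CropSpatialLikeAF3

# Parse a structure (placeholder path) into biotite's AtomArray representation
atom_array = parse("example.cif")["assemblies"]["1"][0]

pipe = Compose(
    [
        AddGlobalAtomIdAnnotation(),
        AtomizeByCCDName(atomize_by_default=True, res_names_to_ignore=STANDARD_AA),
        CropSpatialLikeAF3(crop_size=256),  # illustrative crop size
    ],
    track_rng_state=False,
)

# The output is not an opaque tensor dictionary: it still carries the AtomArray
out = pipe({"atom_array": atom_array})
print(out["atom_array"])
```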
+ --- ## Installation diff --git a/docs/_static/examples/dataset_exploration_01.png b/docs/_static/examples/dataset_exploration_01.png new file mode 100644 index 00000000..a0316e52 Binary files /dev/null and b/docs/_static/examples/dataset_exploration_01.png differ diff --git a/docs/_static/examples/simple_transform_example.pdf b/docs/_static/examples/simple_transform_example.pdf new file mode 100644 index 00000000..491675ec Binary files /dev/null and b/docs/_static/examples/simple_transform_example.pdf differ diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 20ea14c7..44f008b4 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -5,5 +5,6 @@ API :maxdepth: 2 :caption: API Modules + core io ml \ No newline at end of file diff --git a/docs/core.rst b/docs/core.rst new file mode 100644 index 00000000..72d1d2f4 --- /dev/null +++ b/docs/core.rst @@ -0,0 +1,12 @@ +Core Modules +============ + +The core modules provide fundamental utilities, constants, and enumerations used throughout the atomworks library. + +.. toctree:: + :maxdepth: 2 + + core/common + core/constants + core/enums + diff --git a/docs/core/common.rst b/docs/core/common.rst new file mode 100644 index 00000000..797187ab --- /dev/null +++ b/docs/core/common.rst @@ -0,0 +1,8 @@ +Common Utilities +================ + +.. automodule:: atomworks.common + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/core/constants.rst b/docs/core/constants.rst new file mode 100644 index 00000000..0afca45b --- /dev/null +++ b/docs/core/constants.rst @@ -0,0 +1,8 @@ +Constants +========= + +.. automodule:: atomworks.constants + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/core/enums.rst b/docs/core/enums.rst new file mode 100644 index 00000000..64b59f57 --- /dev/null +++ b/docs/core/enums.rst @@ -0,0 +1,8 @@ +Enumerations +============ + +.. automodule:: atomworks.enums + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/examples/annotate_and_save_structures.py b/docs/examples/annotate_and_save_structures.py index 4ebc9a5e..7fb13648 100644 --- a/docs/examples/annotate_and_save_structures.py +++ b/docs/examples/annotate_and_save_structures.py @@ -220,6 +220,6 @@ def fix_boolean_annotation(atom_array: struc.AtomArray, annotation_name: str) -> ######################################################################## # Related Examples -# ---------- +# --------------- # # - :doc:`pocket_conditioning_transform` - Create custom transforms for ligand pocket identification and ML feature generation diff --git a/docs/examples/dataset_exploration.py b/docs/examples/dataset_exploration.py new file mode 100644 index 00000000..28102cc3 --- /dev/null +++ b/docs/examples/dataset_exploration.py @@ -0,0 +1,238 @@ +""" +Dataset Exploration and Management in AtomWorks +=============================================== + +This example demonstrates how to work with datasets in AtomWorks, from simple file-based datasets to complex tabular datasets with custom loaders and transform pipelines. + +**Prerequisites**: Familiarity with :doc:`load_and_visualize_structures` for basic structure loading and :doc:`pocket_conditioning_transform` for understanding transform pipelines. + +.. figure:: /_static/examples/dataset_exploration_01.png + :alt: Cropped structure visualization + :width: 400px + + Visualization of a cropped structure after applying transform pipelines to a dataset. 
+""" + +######################################################################## +# Overview +# ========= +# +# `Transform` pipelines can be used with any data loader and any dataset. They are simply functions that take as input an `AtomArray` (which is often the output of `AtomWorks.io`) and output `PyTorch` tensors ready for ingestion by a model. +# +# However, most users will not want to build datasets from scratch. For convenience, we provide pre-built datasets and dataloaders that play well with `Transform` pipelines as well, roughly adhering to `Torchvision `_ conventions. +# +# We demonstrate below a couple of different ways to connect a `Transform` pipeline with arbitrary datasets and connect them with trivial `Transform` pipelines. + +######################################################################## +# Datasets in AtomWorks +# ====================== + +######################################################################## +# Using a Folder of CIF/PDB Files as a Dataset +# --------------------------------------------- +# +# The simplest way to use AtomWorks with a Dataset is to create a `Dataset` and `Sampler` pointed to a directory of structural files (e.g., PDB, CIF). +# +# **NOTE**: All AtomWorks Datasets require a `name` attribute to support many of the logging/debugging features that are supplied out-of-the-box. + +from atomworks.ml.datasets.datasets import FileDataset + +# To setup the test pack, if not already, run `atomworks setup tests` +dataset = FileDataset.from_directory( + directory="../../tests/data/ml/af2_distillation/cif", name="example_directory_dataset" +) + +######################################################################## +# Let's explore the dataset a tiny bit. + +# Count the number of examples in the dataset +print(f"Dataset has {len(dataset)} examples.") + +# Print the raw data of the first 5 examples +for i, example in enumerate(dataset): + if i >= 5: + break + print(f"Example {i + 1}: {example}") + +######################################################################## +# Understanding Dataset Requirements +# ---------------------------------- +# +# At a high level, to train models with AtomWorks, we need typically need a Dataset that: +# +# (1) Takes as input an item index and returns the corresponding example information; typically includes: +# a. Path to a structural file saved on disk (`/path/to/dataset/my_dataset_0.cif`) +# b. Additional item-specific metadata (e.g., class labels) +# +# (2) Pre-loads structural information from the returned example into an `AtomArray` and assembles inputs for the Transform pipeline +# +# (3) Feed the input dictionary through a Transform pipeline and returns the result +# +# So far, the `FileDataset` we initialized only accomplishes (1) from above - returning the raw data. +# +# To accomplish (2), we can additionally pass a loading function at dataset initialization that takes the raw example data as input and returns a pre-processed ready for a Transform pipeline. +# +# In most cases, this will involve using `parse` or `load_any` from `AtomWorks.io` to build an `AtomArray`, which is the common language of our `Transform` library. 
+
+from atomworks.io import parse
+
+
+def simple_loading_fn(raw_data) -> dict:
    """Simple loading function that parses structural data and returns an AtomArray."""
    parse_output = parse(raw_data)
    return {"atom_array": parse_output["assemblies"]["1"][0]}


dataset_with_loading_fn = FileDataset.from_directory(
    directory="../../tests/data/pdb", name="example_pdb_dataset", loader=simple_loading_fn
)
output = dataset_with_loading_fn[1]
print(f"Output AtomArray has {len(output['atom_array'])} atoms!")

########################################################################
# Adding Transform Pipelines
# ---------------------------
#
# Next up is adding a pipeline. Let's create a simple one with a dramatic crop.

from atomworks.ml.transforms.base import Compose
from atomworks.ml.transforms.crop import (
    CropSpatialLikeAF3,
)
from atomworks.ml.transforms.atom_array import (
    AddGlobalAtomIdAnnotation,
)
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.constants import STANDARD_AA

pipe = Compose(
    [
        # (We need to add these transforms before we can crop)
        AddGlobalAtomIdAnnotation(),
        AtomizeByCCDName(atomize_by_default=True, res_names_to_ignore=STANDARD_AA),
        # Crop to 20 tokens (here, the number of amino acids/nucleic acid bases plus the number of small-molecule atoms)
        CropSpatialLikeAF3(crop_size=20),
    ],
    track_rng_state=False,
)

########################################################################
# Just like with the loading function, we can also pass a composed `Transform` pipeline to our datasets.

dataset_with_loading_fn_and_transforms = FileDataset.from_directory(
    directory="../../tests/data/pdb", name="example_pdb_dataset", loader=simple_loading_fn, transform=pipe
)

########################################################################
# Visualizing the Results
# ------------------------
#
# Let's visualize the result of our transform pipeline:

from atomworks.io.utils.visualize import view

pipeline_output = dataset_with_loading_fn_and_transforms[
    0
]  # This will trigger the loading function and print the row information

view(pipeline_output["atom_array"])

########################################################################
# .. figure:: /_static/examples/dataset_exploration_01.png
#    :alt: Cropped structure visualization

########################################################################
# And indeed, we have a cropped example!
#
# We will then sample uniformly (with or without replacement) from this dataset during training. Such a simple application may be appropriate for many fine-tuning cases such as distillation.
#
# The only "gotcha" relative to normal PyTorch sampling is that you'll need to supply a custom collate function (which could simply be the identity) whenever your output dictionary contains an `AtomArray`, since PyTorch's default collate cannot batch `AtomArray` objects.
+
+from torch.utils.data import RandomSampler, DataLoader

sampler = RandomSampler(dataset_with_loading_fn_and_transforms)
loader = DataLoader(
    dataset=dataset_with_loading_fn_and_transforms,
    sampler=sampler,
    collate_fn=lambda x: x,  # Identity collate: returns the batch as-is
)

for i, example in enumerate(loader):
    # (Since we now have a batch dimension, we need the extra indexing dimension)
    print(f"Example: {i}, Length of AtomArray: {len(example[0]['atom_array'])}")
    if i > 2:
        break

########################################################################
# For more complicated sampling strategies, including distributed sampling for multi-GPU training, see the API documentation for `samplers.py` and the tests in `test_samplers.py`.

########################################################################
# Tabular Datasets
# =================
#
# So far, we have seen how to make and use simple datasets with just paths. In many applications, however, we may want more nuanced dataset schemes. For example, when training on the PDB, we typically want to sample at the chain or interface level rather than the entry level (since we are cropping, the two are distinct). We may also want to provide additional information other than the raw CIF file (e.g., class labels) to be used by the model during training.
#
# We thus support instantiating datasets from tabular sources stored on disk.
#
# We have implemented a `PandasDataset` class for this purpose; however, any tabular format (e.g., `PolarsDataset`) could be similarly implemented without difficulty should the need arise (PRs welcome!).

########################################################################
# PandasDataset
# --------------
#
# The `PandasDataset` class requires a couple of arguments:
# - `data`: Either a pandas DataFrame or path to a CSV/Parquet file containing the tabular data. Each row represents one example.
# - `name`: Descriptive name for this dataset, just as in `FileDataset` and all AtomWorks `Dataset` classes. Used for debugging and some downstream functions when using nested datasets.
#
# Again, we can also pass a `transform` pipeline and `loader`:
# - `transform`: Transform pipeline to apply to loaded data.
# - `loader`: Optional function to process raw DataFrame rows into Transform-ready format.
#
# There are also a few other `PandasDataset`-specific arguments to note:
# - `filters`: Optional list of pandas query strings to filter the data. Applied in order during initialization.
# - `columns_to_load`: Optional list of column names to load when reading from a file. If None, all columns are loaded. Can dramatically reduce memory usage and load time if loading from a columnar format like Parquet.

########################################################################
# We will start by exploring an example metadata dataframe, then load it into a `PandasDataset`.
+ +from atomworks.ml.utils.io import read_parquet_with_metadata + +interfaces_metadata_parquet_path = "../../tests/data/ml/pdb_interfaces/metadata.parquet" +interfaces_df = read_parquet_with_metadata(interfaces_metadata_parquet_path) +print("DataFrame shape:", interfaces_df.shape) +print("Columns:", list(interfaces_df.columns)) +print("\nFirst few rows:") +print(interfaces_df.head()) + +######################################################################## +# Understanding the Metadata +# --------------------------- +# +# This dataframe includes a row for every interface between two `pn_units` (essentially, chains) in the Protein Data Bank. For illustration purposes, however, we're loading the test dataframe, which only includes information for a small subset of the full PDB. +# +# The complete dataframes can be downloaded with `atomworks setup metadata` and will be described in greater detail elsewhere in the documentation. +# +# For our purposes, note that we have a `path` column that points to a `.cif` file stored on disk, an `example_id` column which is unique across every row in the dataset, and two columns `pn_unit_1_iid` and `pn_unit_2_iid` that specify the interface of interest for this particular row. +# +# **NOTE**: Because a given PDB ID may contain many interfaces and thus may appear multiple times in our dataset, we must also incorporate the `assembly_id` and the `pn_unit_iids` of the two interacting chains within the `example_id`. + +from atomworks.ml.datasets.datasets import PandasDataset +from atomworks.ml.datasets.loaders import create_loader_with_query_pn_units + +dataset = PandasDataset( + data=interfaces_df, + name="interfaces_dataset", + # We use a pre-built loader that takes in a list of column names and returns a loader function + loader=create_loader_with_query_pn_units(pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"]), + transform=pipe, +) + +print(f"Created PandasDataset with {len(dataset)} examples") + +######################################################################## +# Related Examples +# --------------- +# +# - :doc:`load_and_visualize_structures` - Learn how to load and explore protein structures +# - :doc:`pocket_conditioning_transform` - Create custom transforms for ligand pocket identification and ML feature generation +# - :doc:`annotate_and_save_structures` - Learn how to add custom annotations to structures and save them for later use diff --git a/docs/examples/load_and_visualize_structures.py b/docs/examples/load_and_visualize_structures.py index 09aa358e..5841e48a 100644 --- a/docs/examples/load_and_visualize_structures.py +++ b/docs/examples/load_and_visualize_structures.py @@ -37,7 +37,7 @@ ######################################################################## # Using ``parse()`` for Full Processing -# ------------------------------------ +# ------------------------------------- # # For RCSB structures, we typically load structures with ``parse()`` to get clean data suitable for most downstream tasks. 
# @@ -58,7 +58,7 @@ ######################################################################## # Using ``load_any()`` for Lightweight Loading -# ------------------------------------------- +# -------------------------------------------- # For comparison: load_any() for lightweight loading (no extensive processing) # Useful when you have clean data (e.g., from distillation) and/or want to preserve all annotations @@ -155,9 +155,22 @@ else: print(f" {key}: {value}") +######################################################################## +# Accessing the Original mmCIF Data +# ----------------------------------- +# +# If there is information contained in the mmCIF file that is *not* extracted by `parse`, we can still gain access +# to the original Biotite CIF block using the ``keep_cif_block=True`` argument to `parse`. +# We can then use the Biotite API to explore any additional data we might need. +# (E.g., we could write a simple `Transform` that extracts the necessary information) + +# Load with original CIF block retained +parse_output_with_cif = parse(pdb_path, keep_cif_block=True) +cif_block = parse_output_with_cif.get("cif_block", None) + ######################################################################## # Related Examples -# ---------- +# --------------- # # - :doc:`annotate_and_save_structures` - Learn how to add custom annotations to structures and save them for later use # - :doc:`pocket_conditioning_transform` - Create custom transforms for ligand pocket identification and ML feature generation diff --git a/docs/examples/pocket_conditioning_transform.py b/docs/examples/pocket_conditioning_transform.py index 668edcf7..01b683cf 100644 --- a/docs/examples/pocket_conditioning_transform.py +++ b/docs/examples/pocket_conditioning_transform.py @@ -29,7 +29,7 @@ # Conventions # ----------- # **A.** Store information in ``AtomArray`` annotations, not in the state dictionary. -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # This ensures robustness when atoms are added/removed downstream. # @@ -39,7 +39,7 @@ # - ❌ Store ``pocket_atom_indices`` in dictionary (which creates significant dependencies with operations that delete or re-order atoms) # # **B.** Within ``forward()``, call a stand-alone function with the same name as the transform class. -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # We thus maintain an object-oriented and a functional API, making our core logic re-usable and testable outside of the ``Transform`` framework. # @@ -51,7 +51,7 @@ # Additionally, this function should preserve the input (e.g., not modify the underlying ``AtomArray``) and take as arguments any necessary parameters. # # **C.** Each ``Transform`` should follow the single-responsibility-principle; in particular separate Annotation from Featurization ``Transforms`` -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # To ensure our ``Transform`` code is maximally forward-compatible and re-usable across disparate pipelines, we adhere to the single responsibility principle - that is, each transform should do *exactly one* action. 
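########################################################################
# To make conventions **A** and **B** above concrete, here is a minimal sketch.
# The distance logic is purely illustrative, the ``Transform`` base class is
# assumed to live alongside ``Compose`` in ``atomworks.ml.transforms.base``,
# and the names are hypothetical; the real transforms built below are more careful.

import numpy as np

from atomworks.ml.transforms.base import Transform  # assumed import path


def annotate_near_ligands(atom_array, cutoff: float = 5.0):
    """Functional API (convention B): a pure function returning an annotated copy."""
    atom_array = atom_array.copy()  # preserve the input
    ligand_mask = atom_array.hetero  # illustrative stand-in for "ligand atoms"
    is_pocket = np.zeros(len(atom_array), dtype=bool)
    if ligand_mask.any():
        dists = np.linalg.norm(
            atom_array.coord[:, None, :] - atom_array.coord[ligand_mask][None, :, :], axis=-1
        )
        is_pocket = (dists.min(axis=1) < cutoff) & ~ligand_mask
    # Convention A: persist the result as an AtomArray annotation, not as indices
    atom_array.set_annotation("is_pocket", is_pocket)
    return atom_array


class AnnotateNearLigands(Transform):
    """Object-oriented API whose ``forward`` delegates to the function above."""

    def __init__(self, cutoff: float = 5.0):
        self.cutoff = cutoff

    def forward(self, data: dict) -> dict:
        data["atom_array"] = annotate_near_ligands(data["atom_array"], self.cutoff)
        return data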
# @@ -88,7 +88,7 @@ ######################################################################## # Building ``AnnotateLigandPockets`` -# =============================== +# ================================== # # Let's create a ``Transform`` that identifies atoms near ligands (non-polymer molecules) of sufficient size. # @@ -209,7 +209,7 @@ def forward(self, data: dict) -> dict: ######################################################################## # Building ``FeaturizePocketAtoms`` -# ============================== +# ================================= # # Now let's create a model-specific transform that converts derived pocket annotations into numeric features. # diff --git a/docs/io.rst b/docs/io.rst index 9f94cdec..322558e5 100644 --- a/docs/io.rst +++ b/docs/io.rst @@ -4,9 +4,6 @@ IO .. toctree:: :maxdepth: 2 - io/common - io/constants - io/enums io/parser io/tools io/transforms diff --git a/docs/io/common.rst b/docs/io/common.rst deleted file mode 100644 index 9b963d94..00000000 --- a/docs/io/common.rst +++ /dev/null @@ -1,7 +0,0 @@ -common -====== - -.. automodule:: atomworks.io.common - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/io/constants.rst b/docs/io/constants.rst deleted file mode 100644 index bb3abb5e..00000000 --- a/docs/io/constants.rst +++ /dev/null @@ -1,7 +0,0 @@ -constants -====== - -.. automodule:: atomworks.io.constants - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/io/enums.rst b/docs/io/enums.rst deleted file mode 100644 index 35e8c960..00000000 --- a/docs/io/enums.rst +++ /dev/null @@ -1,7 +0,0 @@ -enums -===== - -.. automodule:: atomworks.io.enums - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/io/parser.rst b/docs/io/parser.rst index adcb68a8..901ffde1 100644 --- a/docs/io/parser.rst +++ b/docs/io/parser.rst @@ -1,5 +1,5 @@ parser -==== +====== .. automodule:: atomworks.io.parser :members: diff --git a/docs/io/tools/fasta.rst b/docs/io/tools/fasta.rst index 1ccf7145..b106de84 100644 --- a/docs/io/tools/fasta.rst +++ b/docs/io/tools/fasta.rst @@ -1,5 +1,5 @@ FASTA Tools -========== +=========== .. automodule:: atomworks.io.tools.fasta :members: diff --git a/docs/io/tools/inference.rst b/docs/io/tools/inference.rst index bca36db2..4b7ec663 100644 --- a/docs/io/tools/inference.rst +++ b/docs/io/tools/inference.rst @@ -1,5 +1,5 @@ Inference Tools -============== +=============== .. automodule:: atomworks.io.tools.inference :members: diff --git a/docs/io/transforms/atom_array.rst b/docs/io/transforms/atom_array.rst index 28045f71..e39345fd 100644 --- a/docs/io/transforms/atom_array.rst +++ b/docs/io/transforms/atom_array.rst @@ -1,5 +1,5 @@ Atom Array Transforms -=================== +==================== .. automodule:: atomworks.io.transforms.atom_array :members: diff --git a/docs/io/transforms/categories.rst b/docs/io/transforms/categories.rst index f069af86..65789f8d 100644 --- a/docs/io/transforms/categories.rst +++ b/docs/io/transforms/categories.rst @@ -1,5 +1,5 @@ Category Transforms -================= +================== .. automodule:: atomworks.io.transforms.categories :members: diff --git a/docs/io/utils/ccd.rst b/docs/io/utils/ccd.rst index a3758a4a..947dcd4b 100644 --- a/docs/io/utils/ccd.rst +++ b/docs/io/utils/ccd.rst @@ -1,5 +1,5 @@ CCD Utilities -============ +============= .. 
automodule:: atomworks.io.utils.ccd :members: diff --git a/docs/io/utils/io_utils.rst b/docs/io/utils/io_utils.rst index 576dd2f0..1d453546 100644 --- a/docs/io/utils/io_utils.rst +++ b/docs/io/utils/io_utils.rst @@ -1,5 +1,5 @@ I/O Utilities -============ +============= .. automodule:: atomworks.io.utils.io_utils :members: diff --git a/docs/io/utils/sequence.rst b/docs/io/utils/sequence.rst index aae48e65..f8782b63 100644 --- a/docs/io/utils/sequence.rst +++ b/docs/io/utils/sequence.rst @@ -1,5 +1,5 @@ Sequence Utilities -================ +================== .. automodule:: atomworks.io.utils.sequence :members: diff --git a/docs/ml.rst b/docs/ml.rst index d4476fc2..4618eb95 100644 --- a/docs/ml.rst +++ b/docs/ml.rst @@ -7,8 +7,6 @@ Core Modules .. toctree:: :maxdepth: 2 - ml/common - ml/enums ml/encoding_definitions ml/samplers @@ -19,14 +17,10 @@ Data Processing Modules :maxdepth: 2 ml/datasets - ml/datasets/parsers - ml/preprocessing - ml/preprocessing/utils ml/pipelines ml/transforms ml/transforms/diffusion ml/transforms/dna - ml/transforms/esm ml/transforms/feature_aggregation ml/transforms/msa ml/utils \ No newline at end of file diff --git a/docs/ml/common.rst b/docs/ml/common.rst deleted file mode 100644 index ebfb0df2..00000000 --- a/docs/ml/common.rst +++ /dev/null @@ -1,9 +0,0 @@ -Common Utilities -=============== - -This module contains common utilities and functions used throughout the atomworks.ml package. - -.. automodule:: atomworks.ml.common - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/ml/datasets.rst b/docs/ml/datasets.rst index 910d96a6..0dc8424e 100644 --- a/docs/ml/datasets.rst +++ b/docs/ml/datasets.rst @@ -1,7 +1,7 @@ Datasets ======== -This module contains dataset classes and utilities for loading and processing molecular data. +This module contains dataset classes and utilities for loading and processing molecular data using a modern, composable architecture. Core Dataset Classes -------------------- @@ -11,10 +11,19 @@ Core Dataset Classes :undoc-members: :show-inheritance: -Parsers -------- +Functional Loaders +------------------ -.. automodule:: atomworks.ml.datasets.parsers +.. automodule:: atomworks.ml.datasets.loaders :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + +Dataset Architecture and Migration Guide +---------------------------------------- + +.. toctree:: + :maxdepth: 2 + + datasets/datasets + datasets/parsers \ No newline at end of file diff --git a/docs/ml/datasets/datasets.rst b/docs/ml/datasets/datasets.rst new file mode 100644 index 00000000..e12134ec --- /dev/null +++ b/docs/ml/datasets/datasets.rst @@ -0,0 +1,187 @@ +Dataset Architecture +==================== + +AtomWorks provides a modern, composable dataset architecture that separates data loading, processing, and transformation concerns. This approach replaces the legacy parser-based system with functional loaders and transform pipelines. + +.. warning:: + The metadata parser system (``atomworks.ml.datasets.parsers``) is **deprecated** and will be removed in a future version. + Use the new loader-based approach with ``FileDataset`` and ``PandasDataset`` instead. + +Modern Dataset Architecture +--------------------------- + +The current AtomWorks dataset system consists of three main components: + +1. **Datasets**: Container classes that manage data access and indexing +2. **Loaders**: Functions that process raw data into transform-ready format +3. 
**Transforms**: Pipelines that convert loaded data into model inputs

Dataset Classes
---------------

.. automodule:: atomworks.ml.datasets.datasets
    :members:
    :undoc-members:
    :show-inheritance:

Functional Loaders
------------------

Loaders are functions that process raw dataset output (e.g., pandas Series) into a Transform-ready format.
They replace the legacy parser classes with a more flexible, functional approach.

.. automodule:: atomworks.ml.datasets.loaders
    :members:
    :undoc-members:
    :show-inheritance:

Basic Usage Examples
~~~~~~~~~~~~~~~~~~~~

**File-based datasets** (replacing simple file parsers):

.. code-block:: python

    from atomworks.ml.datasets.datasets import FileDataset
    from atomworks.io import parse

    def simple_loading_fn(raw_data) -> dict:
        """Simple loading function that parses structural data."""
        parse_output = parse(raw_data)
        return {"atom_array": parse_output["assemblies"]["1"][0]}

    dataset = FileDataset.from_directory(
        directory="/path/to/structures",
        name="my_dataset",
        loader=simple_loading_fn
    )

**Tabular datasets** (replacing metadata parsers):

.. code-block:: python

    from atomworks.ml.datasets.datasets import PandasDataset
    from atomworks.ml.datasets.loaders import loader_with_query_pn_units

    dataset = PandasDataset(
        data="metadata.parquet",
        name="interfaces_dataset",
        loader=loader_with_query_pn_units(
            pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"]
        )
    )

**Custom loaders** for specialized use cases:

.. code-block:: python

    from pathlib import Path

    import pandas as pd

    from atomworks.io import parse

    def custom_loader(row: pd.Series) -> dict:
        """Custom loader with specific processing logic."""
        # Load structure
        structure_path = Path(row["path"])
        parse_output = parse(structure_path)

        # Extract specific metadata
        metadata = {
            "resolution": row.get("resolution", None),
            "method": row.get("method", "unknown"),
            "custom_field": row.get("custom_field", "default_value")
        }

        return {
            "atom_array": parse_output["assemblies"]["1"][0],
            "extra_info": metadata,
            "example_id": row["example_id"]
        }

    dataset = PandasDataset(
        data=my_dataframe,
        name="custom_dataset",
        loader=custom_loader
    )

Common Loader Patterns
~~~~~~~~~~~~~~~~~~~~~~

**Base loader** for standard structure loading:

.. code-block:: python

    from atomworks.ml.datasets.loaders import loader_base

    loader = loader_base(
        example_id_colname="example_id",
        path_colname="path",
        assembly_id_colname="assembly_id",
        base_path="/data/structures",
        extension=".cif"
    )

**Interface loader** for protein-protein interfaces:

.. code-block:: python

    from atomworks.ml.datasets.loaders import loader_with_query_pn_units

    loader = loader_with_query_pn_units(
        pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"],
        base_path="/data/pdb",
        extension=".cif.gz"
    )

**Validation loader** with scoring targets:

.. code-block:: python

    from atomworks.ml.datasets.loaders import loader_with_interfaces_and_pn_units_to_score

    loader = loader_with_interfaces_and_pn_units_to_score(
        interfaces_to_score_colname="interfaces_to_score",
        pn_units_to_score_colname="pn_units_to_score"
    )

Integration with Transform Pipelines
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Loaders work seamlessly with AtomWorks transform pipelines. The loader output becomes the input to the transform pipeline:

..
code-block:: python + + from atomworks.ml.transforms.base import Compose + from atomworks.ml.transforms.crop import CropSpatialLikeAF3 + from atomworks.ml.transforms.atom_array import AddGlobalAtomIdAnnotation + + # Create a transform pipeline + transform_pipeline = Compose([ + AddGlobalAtomIdAnnotation(), + CropSpatialLikeAF3(crop_size=256), + ]) + + # Create dataset with both loader and transforms + dataset = PandasDataset( + data="metadata.parquet", + name="my_dataset", + loader=loader_with_query_pn_units( + pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"] + ), + transform=transform_pipeline + ) + + # Access processed data + example = dataset[0] # Returns transformed data ready for model input + +Data Flow +~~~~~~~~~ + +The complete data flow in the new architecture is: + +1. **Raw Data**: File paths or DataFrame rows +2. **Loader**: Processes raw data into standardized format with ``AtomArray`` +3. **Transform Pipeline**: Converts loaded data into model-ready tensors +4. **Model Input**: Final processed data ready for training/inference + +This separation allows for: +- **Reusable loaders** across different datasets +- **Composable transforms** that can be mixed and matched +- **Easy testing** of individual components +- **Clear debugging** when issues arise diff --git a/docs/ml/datasets/parsers.rst b/docs/ml/datasets/parsers.rst index 4556ffc5..53a1edf0 100644 --- a/docs/ml/datasets/parsers.rst +++ b/docs/ml/datasets/parsers.rst @@ -1,28 +1,22 @@ Dataset Parsers =============== -This module contains parsers for different types of dataset metadata and structures. - -Base Parser ------------ +.. automodule:: atomworks.ml.datasets.parsers + :members: + :undoc-members: + :show-inheritance: .. automodule:: atomworks.ml.datasets.parsers.base :members: :undoc-members: :show-inheritance: -Custom Metadata Parsers ----------------------- - .. automodule:: atomworks.ml.datasets.parsers.custom_metadata_row_parsers :members: :undoc-members: :show-inheritance: -Default Metadata Parsers ------------------------ - .. automodule:: atomworks.ml.datasets.parsers.default_metadata_row_parsers :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/docs/ml/encoding_definitions.rst b/docs/ml/encoding_definitions.rst index 334c6cd8..0e74107e 100644 --- a/docs/ml/encoding_definitions.rst +++ b/docs/ml/encoding_definitions.rst @@ -1,5 +1,5 @@ Encoding Definitions -=================== +==================== This module contains definitions for various encoding schemes used in the atomworks.ml package. diff --git a/docs/ml/enums.rst b/docs/ml/enums.rst deleted file mode 100644 index 83e5d8e4..00000000 --- a/docs/ml/enums.rst +++ /dev/null @@ -1,9 +0,0 @@ -Enums -===== - -This module contains enumeration classes used throughout the atomworks.ml package. - -.. automodule:: atomworks.ml.enums - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/ml/preprocessing.rst b/docs/ml/preprocessing.rst index a51043d1..1fe05690 100644 --- a/docs/ml/preprocessing.rst +++ b/docs/ml/preprocessing.rst @@ -4,7 +4,7 @@ Preprocessing This module contains utilities for preprocessing molecular structures and data. Core Preprocessing Functions ---------------------------- +---------------------------- .. 
automodule:: atomworks.ml.preprocessing.get_pn_unit_data_from_structure :members: diff --git a/docs/ml/preprocessing/utils.rst b/docs/ml/preprocessing/utils.rst deleted file mode 100644 index 4a794a2a..00000000 --- a/docs/ml/preprocessing/utils.rst +++ /dev/null @@ -1,28 +0,0 @@ -Preprocessing Utilities -====================== - -This module contains utility functions for preprocessing tasks. - -Clustering Utilities -------------------- - -.. automodule:: atomworks.ml.preprocessing.utils.clustering - :members: - :undoc-members: - :show-inheritance: - -FASTA Utilities ---------------- - -.. automodule:: atomworks.ml.preprocessing.utils.fasta - :members: - :undoc-members: - :show-inheritance: - -Structure Utilities ------------------- - -.. automodule:: atomworks.ml.preprocessing.utils.structure_utils - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/ml/transforms.rst b/docs/ml/transforms.rst index c13d5804..74ccc7ad 100644 --- a/docs/ml/transforms.rst +++ b/docs/ml/transforms.rst @@ -1,10 +1,10 @@ Transforms -========= +========== This module contains various transformation classes and utilities for processing molecular data. Core Transform Classes --------------------- +----------------------- .. automodule:: atomworks.ml.transforms.base :members: @@ -87,7 +87,7 @@ Core Transform Classes :show-inheritance: Utility Modules -------------- +--------------- .. automodule:: atomworks.ml.transforms._checks :members: @@ -115,10 +115,10 @@ Utility Modules :show-inheritance: Submodules ---------- +----------- DNA Transforms -~~~~~~~~~~~~ +~~~~~~~~~~~~~~ .. automodule:: atomworks.ml.transforms.dna :members: @@ -126,23 +126,15 @@ DNA Transforms :show-inheritance: Diffusion Transforms -~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~ .. automodule:: atomworks.ml.transforms.diffusion :members: :undoc-members: :show-inheritance: -ESM Transforms -~~~~~~~~~~~~ - -.. automodule:: atomworks.ml.transforms.esm - :members: - :undoc-members: - :show-inheritance: - Feature Aggregation Transforms -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: atomworks.ml.transforms.feature_aggregation :members: @@ -150,7 +142,7 @@ Feature Aggregation Transforms :show-inheritance: MSA Transforms -~~~~~~~~~~~~ +~~~~~~~~~~~~~~ .. automodule:: atomworks.ml.transforms.msa :members: diff --git a/docs/ml/transforms/diffusion.rst b/docs/ml/transforms/diffusion.rst index a7b5322b..d58e0ff5 100644 --- a/docs/ml/transforms/diffusion.rst +++ b/docs/ml/transforms/diffusion.rst @@ -1,5 +1,5 @@ Diffusion Transforms -================== +==================== This module contains transformations for diffusion-based structure processing. diff --git a/docs/ml/transforms/dna.rst b/docs/ml/transforms/dna.rst index 754f3b05..08565a6f 100644 --- a/docs/ml/transforms/dna.rst +++ b/docs/ml/transforms/dna.rst @@ -1,5 +1,5 @@ DNA Transforms -============= +============== This module contains transformations specific to DNA processing. diff --git a/docs/ml/transforms/esm.rst b/docs/ml/transforms/esm.rst deleted file mode 100644 index 1bf44857..00000000 --- a/docs/ml/transforms/esm.rst +++ /dev/null @@ -1,9 +0,0 @@ -ESM Transforms -============= - -This module contains transformations specific to ESM models and features. - -.. 
automodule:: atomworks.ml.transforms.esm - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/docs/ml/utils.rst b/docs/ml/utils.rst index ebd4970a..2d5db923 100644 --- a/docs/ml/utils.rst +++ b/docs/ml/utils.rst @@ -36,7 +36,7 @@ I/O Utilities :show-inheritance: Miscellaneous Utilities ----------------------- +----------------------- .. automodule:: atomworks.ml.utils.misc :members: @@ -44,7 +44,7 @@ Miscellaneous Utilities :show-inheritance: Nested Dictionary Utilities --------------------------- +--------------------------- .. automodule:: atomworks.ml.utils.nested_dict :members: @@ -60,7 +60,7 @@ NumPy Utilities :show-inheritance: Random Number Generation ------------------------ +------------------------ .. automodule:: atomworks.ml.utils.rng :members: diff --git a/pyproject.toml b/pyproject.toml index 22181121..927107a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "atomworks" version = "1.0.0" description = "A research-oriented data toolkit for training biomolecular deep-learning foundation models" readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.11" authors = [ { name = "Institute for Protein Design", email = "contact@ipd.uw.edu" } ] diff --git a/src/atomworks/__init__.py b/src/atomworks/__init__.py index bb4e2848..25956688 100644 --- a/src/atomworks/__init__.py +++ b/src/atomworks/__init__.py @@ -1,8 +1,8 @@ -""" -atomworks - Unified package for biological data I/O and machine learning. +"""Unified package for biological data I/O and machine learning. -This package combines functionality from atomworks.io (I/O operations) and atomworks.ml (ML utilities) -into a unified interface for biological data processing and machine learning. +This package combines functionality from :mod:`atomworks.io` (I/O operations) and +:mod:`atomworks.ml` (ML utilities) into a unified interface for biological data +processing and machine learning. """ import importlib diff --git a/src/atomworks/biotite_patch.py b/src/atomworks/biotite_patch.py index a1af4cc6..66743ff2 100644 --- a/src/atomworks/biotite_patch.py +++ b/src/atomworks/biotite_patch.py @@ -1,4 +1,12 @@ -"""Collection of monkey patches for biotite.""" +"""Collection of monkey patches for biotite. + +This module provides patches and extensions to the Biotite library to enhance +functionality and fix version-specific issues. + +References: + `Biotite Documentation `_ + `Biotite Structure Module `_ +""" from typing import Callable @@ -16,8 +24,7 @@ def apply_if_version_lt(version: str, min_version: str) -> Callable: - """ - Decorator to apply a function only if the given version is less than the given minimal version. + """Decorator to apply a function only if the given version is less than the given minimal version. Args: version: Version to check. diff --git a/src/atomworks/common.py b/src/atomworks/common.py index 89a8d94b..6559b88e 100644 --- a/src/atomworks/common.py +++ b/src/atomworks/common.py @@ -1,4 +1,4 @@ -"""Common functions used throughout the project.""" +"""Common utility functions used throughout the project.""" import copy import hashlib @@ -11,41 +11,94 @@ def exists(obj: Any) -> bool: - """Check that `obj` is not `None`.""" + """Check that obj is not None. + + Args: + obj: The object to check. + + Returns: + True if obj is not None, False otherwise. + """ return obj is not None def default(obj: Any, default: Any) -> Any: - """Return `obj` if not `None`, otherwise return `default`.""" + """Return obj if not None, otherwise return default. 
+ + Args: + obj: The primary object to return. + default: The fallback value if obj is None. + + Returns: + obj if it is not None, otherwise default. + """ return obj if exists(obj) else default def to_hashable(element: Any) -> Any: - """Convert an element to a hashable type.""" + """Convert an element to a hashable type. + + Args: + element: The element to convert. + + Returns: + The element if already hashable, otherwise converted to a tuple. + """ return element if isinstance(element, int | str | np.integer | np.str_) else tuple(element) def string_to_md5_hash(s: str, truncate: int = 32) -> str: - """Generate an MD5 hash of a string and return the first `truncate` characters.""" + """Generate an MD5 hash of a string and return the first truncate characters. + + Args: + s: The string to hash. + truncate: Number of characters to return from the hash. + + Returns: + The truncated MD5 hash as a string. + """ full_hash = hashlib.md5(s.encode("utf-8")).hexdigest() return full_hash[:truncate] def sum_string_arrays(*objs: np.ndarray | str) -> np.ndarray: - """ - Sum a list of string arrays or strings into a single string array by concatenating them and - determining the shortest string length to set as dtype. + """Sum a list of string arrays or strings into a single string array. + + Concatenates the arrays and determines the shortest string length to set as dtype. + + Args: + *objs: Variable number of string arrays or strings to sum. + + Returns: + A single concatenated string array. """ return reduce(np.char.add, objs).astype(object).astype(str) def not_isin(element: np.ndarray, array: np.ndarray, **isin_kwargs) -> np.ndarray: - """Like `~np.isin`, but more efficient.""" + """Like ~np.isin, but more efficient. + + Args: + element: The array to test. + array: The array of values to test against. + **isin_kwargs: Additional keyword arguments for np.isin. + + Returns: + Boolean array indicating which elements are not in the array. + """ return np.isin(element, array, invert=True, **isin_kwargs) def listmap(func: Callable, *iterables) -> list: - """Like `map`, but returns a list instead of an iterator.""" + """Like map, but returns a list instead of an iterator. + + Args: + func: The function to apply. + *iterables: Variable number of iterables to map over. + + Returns: + A list containing the results of applying func to the iterables. + """ return compose(list, map)(func, *iterables) @@ -55,6 +108,12 @@ def as_list(value: Any) -> list: Handles various types using duck typing: - Iterable objects (lists, tuples, strings, etc.): converted to list - Single values: wrapped in a list + + Args: + value: The value to convert to a list. + + Returns: + A list containing the value(s). """ try: # Try to iterate over the value (duck typing approach) @@ -68,7 +127,16 @@ def as_list(value: Any) -> list: def immutable_lru_cache(maxsize: int = 128, typed: bool = False, deepcopy: bool = True) -> Callable: - """An immutable version of `lru_cache` for caching functions that return mutable objects.""" + """An immutable version of lru_cache for caching functions that return mutable objects. + + Args: + maxsize: Maximum number of items to cache. + typed: Whether to treat different types as separate cache entries. + deepcopy: Whether to use deep copy for immutable caching. + + Returns: + A decorator that provides immutable caching functionality. 
+ """ copy_func = copy.deepcopy if deepcopy else copy.copy def decorator(func: Callable) -> Callable: @@ -84,27 +152,33 @@ def wrapper(*args, **kwargs) -> Any: class KeyToIntMapper: - """ - Maps keys to unique integers based on the order of the first appearance of the key. + """Maps keys to unique integers based on the order of the first appearance of the key. - This is useful for mapping id's such as `chain_id`, `chain_entity`, `molecule_iid`, etc. + This is useful for mapping id's such as chain_id, chain_entity, molecule_iid, etc. to integers. Example: - ```python - chain_id_to_int = KeyToIntMapper() - chain_id_to_int("A") # 0 - chain_id_to_int("C") # 1 - chain_id_to_int("A") # 0 - chain_id_to_int("B") # 2 - ``` + >>> chain_id_to_int = KeyToIntMapper() + >>> chain_id_to_int("A") # 0 + >>> chain_id_to_int("C") # 1 + >>> chain_id_to_int("A") # 0 + >>> chain_id_to_int("B") # 2 """ def __init__(self): + """Initialize KeyToIntMapper with empty mapping.""" self.key_to_id = {} self.next_id = 0 def __call__(self, value: Any) -> int: + """Map a key to a unique integer. + + Args: + value: The key to map. + + Returns: + The unique integer assigned to the key. + """ if value not in self.key_to_id: self.key_to_id[value] = self.next_id self.next_id += 1 diff --git a/src/atomworks/constants.py b/src/atomworks/constants.py index 06f907a5..100d71e2 100644 --- a/src/atomworks/constants.py +++ b/src/atomworks/constants.py @@ -1,4 +1,4 @@ -"""Constants used in the `atomworks.io` package.""" +"""Constants used in the AtomWorks library.""" import logging import os @@ -16,7 +16,14 @@ def _load_env_var(var_name: str) -> str | None: - """Load an environment variable, returning None if it is not set.""" + """Load an environment variable, returning None if it is not set. + + Args: + var_name: The name of the environment variable to load. + + Returns: + The value of the environment variable, or None if not set. + """ try: return os.environ[var_name] except KeyError: @@ -32,10 +39,18 @@ def _load_env_var(var_name: str) -> str | None: CCD_MIRROR_PATH: Final[str] = _load_env_var("CCD_MIRROR_PATH") -"""A path to a carbon-copy mirror of the CCD ligands in the RCSB CCD.""" +"""A path to a carbon-copy mirror of the CCD ligands in the RCSB CCD. + +Reference: + `RCSB Chemical Component Dictionary `_ +""" PDB_MIRROR_PATH: Final[str] = _load_env_var("PDB_MIRROR_PATH") -"""A path to a mirror of the PDB.""" +"""A path to a mirror of the PDB. + +Reference: + `Protein Data Bank `_ +""" UNKNOWN_ELEMENT: Final[str] = "X" """The element name for an unknown element.""" @@ -59,13 +74,27 @@ def _load_env_var(var_name: str) -> str | None: "Rg": 111, "Cn": 112, "Nh": 113, "Fl": 114, "Mc": 115, "Lv": 116, "Ts": 117, "Og": 118, UNKNOWN_ELEMENT: UNKNOWN_ATOMIC_NUMBER })) -"""Map canonical *UPPERCASE* 2 letter element names to their atomic numbers. WARNING: Case-sensitive.""" +"""Map canonical *UPPERCASE* 2 letter element names to their atomic numbers. + +Warning: + Case-sensitive. + +Reference: + `IUPAC Periodic Table `_ +""" ATOMIC_NUMBER_TO_ELEMENT: Final[MappingProxyType[int | str, str]] = MappingProxyType( {v: k for k, v in ELEMENT_NAME_TO_ATOMIC_NUMBER.items()} | {str(v): k for k, v in ELEMENT_NAME_TO_ATOMIC_NUMBER.items()} ) -"""Map atomic numbers (int/str) to their canonical *UPPERCASE* 2 letter element names. WARNING: Case-sensitive.""" +"""Map atomic numbers (int/str) to their canonical *UPPERCASE* 2 letter element names. + +Warning: + Case-sensitive. 
+
+Reference:
+    `IUPAC Periodic Table `_
+"""

METAL_ELEMENTS: Final[frozenset[str]] = frozenset(map(str.upper, [
    "Li", "Na", "K", "Rb", "Cs", "Be", "Mg", "Ca", "Sr", "Ba",
    "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd",
    "La", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
    "Al", "Ga", "In", "Sn", "Tl", "Pb", "Bi",
]))
-"""A set of all metal elements, all *UPPERCASE*. WARNING: Case-sensitive."""
+"""A set of all metal elements, all *UPPERCASE*.
+
+Warning:
+    Case-sensitive.
+
+Reference:
+    `IUPAC Periodic Table - Metals `_
+"""
# fmt: on

CHEM_COMP_TYPES: Final[tuple[str, ...]] = tuple(
@@ -114,10 +150,11 @@ def _load_env_var(var_name: str) -> str | None:
    ]
)
"""Allowed Chemical Component Types for residues in the PDB + `mask`.
+
All uppercase.

Reference:
-    - http://mmcif.rcsb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_chem_comp.type.html
+    `RCSB mmCIF Dictionary - chem_comp.type <http://mmcif.rcsb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_chem_comp.type.html>`_
"""

AA_LIKE_CHEM_TYPES: Final[frozenset[str]] = frozenset(
@@ -273,7 +310,7 @@ def _load_env_var(var_name: str) -> str | None:
"""A set of bond types that are considered when adding bonds to the atom array.

Reference:
-    - https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_struct_conn.conn_type_id.html
+    `struct_conn.conn_type_id <https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_struct_conn.conn_type_id.html>`_
"""

STRUCT_CONN_BOND_ORDER_TO_INT: Final[MappingProxyType[str, int]] = MappingProxyType(
@@ -287,8 +324,8 @@ def _load_env_var(var_name: str) -> str | None:
"""
Mapping from `struct_conn.pdbx_value_order` to integer bond orders.

-References:
-    - https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_struct_conn.pdbx_value_order.html
+Reference:
+    `struct_conn.pdbx_value_order <https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_struct_conn.pdbx_value_order.html>`_
"""

BIOTITE_BOND_TYPE_TO_BOND_ORDER: Final[MappingProxyType[BondType, int]] = MappingProxyType(
@@ -320,7 +357,7 @@ def _load_env_var(var_name: str) -> str | None:
Only elements that have unambiguous valences are included.

Reference:
-    - https://www.rdkit.org/docs/RDKit_Book.html#valence-calculation-and-allowed-valences
+    `RDKit Book - Valence Calculation <https://www.rdkit.org/docs/RDKit_Book.html#valence-calculation-and-allowed-valences>`_
"""

CRYSTALLIZATION_AIDS: Final[list[str]] = [
@@ -344,7 +381,7 @@ def _load_env_var(var_name: str) -> str | None:
"""A list of CCD codes of common crystallization aids used in the crystallization of proteins.

Reference:
-    - AF3 (Supp. Table 9) https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
+    `AF3 (Supp. Table 9) <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
"""

AF3_EXCLUDED_LIGANDS: Final[list[str]] = [
@@ -483,7 +520,7 @@ def _load_env_var(var_name: str) -> str | None:
"""A list of CCD codes of ligands that were excluded in AF3.

Reference:
-    - AF3 (Supp. Table 10) https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
+    `AF3 (Supp. Table 10) <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
"""

AF3_EXCLUDED_LIGANDS_REGEX: Final[str] = r"(?:^|,)\s*(?:" + "|".join(AF3_EXCLUDED_LIGANDS) + r")\s*(?:,|$)"
@@ -519,21 +556,21 @@ def _load_env_var(var_name: str) -> str | None:
"""A dictionary that maps three-letter amino acid codes to one-letter codes.

Reference:
-    - Biotite: https://github.com/biotite-dev/biotite/blob/v0.41.0/src/biotite/sequence/seqtypes.py#L348-L556
+    `Biotite seqtypes.py <https://github.com/biotite-dev/biotite/blob/v0.41.0/src/biotite/sequence/seqtypes.py#L348-L556>`_
"""

UNKNOWN_LIGAND: Final[str] = sys.intern("UNL")
"""The CCD code for unknown ligands (`UNL`).

Reference:
-    - https://www.wwpdb.org/documentation/procedure
+    `wwPDB Documentation <https://www.wwpdb.org/documentation/procedure>`_
"""

UNKNOWN_AA: Final[str] = sys.intern("UNK")
"""The CCD code for unknown amino acids (`UNK`).
 Reference:
-    - https://www.wwpdb.org/documentation/procedure
+    `wwPDB Documentation <https://www.wwpdb.org/documentation/procedure>`_
 """

 # TODO: Change these to something unique.
@@ -541,21 +578,21 @@ def _load_env_var(var_name: str) -> str | None:
 """The CCD code for unknown RNA nucleotides (`N`).

 Reference:
-    - https://www.wwpdb.org/documentation/procedure
+    `wwPDB Documentation <https://www.wwpdb.org/documentation/procedure>`_
 """

 UNKNOWN_DNA: Final[str] = sys.intern("DN")
 """The CCD code for unknown DNA nucleotides (`DN`).

 Reference:
-    - https://www.wwpdb.org/documentation/procedure
+    `wwPDB Documentation <https://www.wwpdb.org/documentation/procedure>`_
 """

 UNKNOWN_ATOM: Final[str] = sys.intern("UNX")
 """The CCD code for unknown atoms (`UNX`).

 Reference:
-    - https://www.wwpdb.org/documentation/procedure
+    `wwPDB Documentation <https://www.wwpdb.org/documentation/procedure>`_
 """

 GAP: Final[str] = sys.intern("")
diff --git a/src/atomworks/enums.py b/src/atomworks/enums.py
index 11ba5bc9..b7d6a1a8 100644
--- a/src/atomworks/enums.py
+++ b/src/atomworks/enums.py
@@ -1,4 +1,4 @@
-"""Enums used accross `atomworks`."""
+"""Enums used across atomworks."""

 from enum import IntEnum, StrEnum, auto
 from types import MappingProxyType
@@ -19,13 +19,15 @@ class ChainType(IntEnum):
     """IntEnum representing the type of chain in a RCSB mmCIF file from the Protein Data Bank (PDB).

-    Useful constants relating to ChainType are defined in ChainTypeInfo.
+    Useful constants relating to ChainType are defined in :class:`ChainTypeInfo`.

-    Sources:
-        - https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity.type.html
-        - https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity_poly.type.html
+    Note:
+        The chain type fields in the PDB are not stable; note the specific versions
+        of the dictionaries used (updated November, 2024).

-    NOTE: The chain type fields in the PDB are not stable; note the specific versions of the dictionaries used (updated November, 2024)
+    References:
+        `RCSB mmCIF Dictionary - entity.type <https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity.type.html>`_
+        `RCSB mmCIF Dictionary - entity_poly.type <https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity_poly.type.html>`_
     """

     # Polymers
@@ -46,7 +48,17 @@ class ChainType(IntEnum):

     @classmethod
     def from_string(cls, str_value: str) -> "ChainType":
-        """Convert a string to a ChainType enum."""
+        """Convert a string to a ChainType enum.
+
+        Args:
+            str_value: The string value to convert.
+
+        Returns:
+            The corresponding ChainType enum.
+
+        Raises:
+            ValueError: If the string value is not a valid chain type.
+        """
         try:
             return ChainTypeInfo.STRING_TO_ENUM[str_value.upper()]
         except KeyError:
@@ -56,36 +68,67 @@ def from_string(cls, str_value: str) -> "ChainType":

     @staticmethod
     def get_chain_type_strings() -> list[str]:
-        """Get a list of all chain type strings."""
+        """Get a list of all chain type strings.
+
+        Returns:
+            List of all valid chain type strings.
+        """
         return list(ChainTypeInfo.STRING_TO_ENUM.keys())

     @staticmethod
     def get_polymers() -> list["ChainType"]:
-        """Get a list of all polymer chain types."""
+        """Get a list of all polymer chain types.
+
+        Returns:
+            List of polymer chain types.
+        """
         return ChainTypeInfo.POLYMERS

     @staticmethod
     def get_non_polymers() -> list["ChainType"]:
-        """Get a list of all non-polymer chain types."""
+        """Get a list of all non-polymer chain types.
+
+        Returns:
+            List of non-polymer chain types.
+        """
         return ChainTypeInfo.NON_POLYMERS

     @staticmethod
     def get_proteins() -> list["ChainType"]:
-        """Get a list of all protein chain types."""
+        """Get a list of all protein chain types.
+
+        Returns:
+            List of protein chain types.
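+
+        Example (illustrative; follows directly from ``is_protein`` checking membership in the same list):
+            >>> all(ct.is_protein() for ct in ChainType.get_proteins())
+            True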
+ """ return ChainTypeInfo.PROTEINS @staticmethod def get_nucleic_acids() -> list["ChainType"]: - """Get a list of all nucleic acid chain types.""" + """Get a list of all nucleic acid chain types. + + Returns: + List of nucleic acid chain types. + """ return ChainTypeInfo.NUCLEIC_ACIDS @staticmethod def get_all_types() -> list["ChainType"]: - """Get a list of all chain types.""" + """Get a list of all chain types. + + Returns: + List of all chain types. + """ return list(ChainType) def __eq__(self, other: Union["ChainType", int, str]) -> bool: - """Check if two ChainType enums are equal.""" + """Check if two ChainType enums are equal. + + Args: + other: Another ChainType, int, or string to compare with. + + Returns: + True if the chain types are equal, False otherwise. + """ if isinstance(other, ChainType): return self.value == other.value elif isinstance(other, int): @@ -101,44 +144,85 @@ def __eq__(self, other: Union["ChainType", int, str]) -> bool: return NotImplemented def __hash__(self): - """Hash a ChainType enum.""" + """Hash a ChainType enum. + + Returns: + Hash value of the enum. + """ return hash(self.value) def __str__(self) -> str: - """Convert a ChainType enum to a string.""" + """Convert a ChainType enum to a string. + + Returns: + String representation of the chain type. + """ return self.to_string() def get_valid_chem_comp_types(self) -> set[str]: - """Get the set of valid chemical component types for a ChainType.""" + """Get the set of valid chemical component types for a ChainType. + + Returns: + Set of valid chemical component types for this chain type. + """ return ChainTypeInfo.VALID_CHEM_COMP_TYPES[self] def is_protein(self) -> bool: - """Check if a ChainType is a protein.""" + """Check if a ChainType is a protein. + + Returns: + True if this chain type represents a protein, False otherwise. + """ return self in ChainTypeInfo.PROTEINS def is_nucleic_acid(self) -> bool: - """Check if a ChainType is a nucleic acid.""" + """Check if a ChainType is a nucleic acid. + + Returns: + True if this chain type represents a nucleic acid, False otherwise. + """ return self in ChainTypeInfo.NUCLEIC_ACIDS def is_polymer(self) -> bool: - """Check if a ChainType is a polymer.""" + """Check if a ChainType is a polymer. + + Returns: + True if this chain type represents a polymer, False otherwise. + """ return self in ChainTypeInfo.POLYMERS def is_non_polymer(self) -> bool: - """Check if a ChainType is a non-polymer.""" + """Check if a ChainType is a non-polymer. + + Returns: + True if this chain type represents a non-polymer, False otherwise. + """ return self in ChainTypeInfo.NON_POLYMERS def to_string(self) -> str: - """ - Convert a ChainType enum to a string. + """Convert a ChainType enum to a string. + + Note: + Returns UPPERCASE string (e.g., "POLYPEPTIDE(D)" instead of "polypeptide(D)") - NOTE: Returns UPPERCASE string (e.g., "POLYPEPTIDE(D)" instead of "polypeptide(D)") + Returns: + Uppercase string representation of the chain type. """ return ChainTypeInfo.ENUM_TO_STRING[self] @staticmethod def as_enum(value: Union[str, int, "ChainType"]) -> "ChainType": - """Convert a string, int, or ChainType to a ChainType enum.""" + """Convert a string, int, or ChainType to a ChainType enum. + + Args: + value: The value to convert to a ChainType enum. + + Returns: + The corresponding ChainType enum. + + Raises: + ValueError: If the value cannot be converted to a ChainType. 
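+
+        Example (illustrative; string conversion is case-insensitive):
+            >>> ChainType.as_enum("polypeptide(D)") == ChainType.as_enum("POLYPEPTIDE(D)")
+            True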
+ """ if isinstance(value, ChainType): return value elif isinstance(value, str): @@ -150,10 +234,10 @@ def as_enum(value: Union[str, int, "ChainType"]) -> "ChainType": class ChainTypeInfo: - """ - Companion class containing metadata and helper methods for ChainType enum. + """Companion class containing metadata and helper methods for ChainType enum. - This class should not be instantiated - it serves as a namespace for ChainType-related constants and utilities. + This class should not be instantiated - it serves as a namespace for + ChainType-related constants and utilities. """ POLYMERS: Final[tuple[ChainType, ...]] = ( @@ -253,11 +337,17 @@ class GroundTruthConformerPolicy(IntEnum): """Enum for ground truth conformer policy. Possible values are: - - REPLACE: Use the ground-truth coordinates as the reference conformer, replacing the coordinated generated by RDKit in-place (and add a flag to indicate that the coordinates were replaced) - - ADD: Return an additional feature (with the same shape as `ref_pos`) containing the ground-truth coordinates - - FALLBACK: Use the ground-truth coordinates only if our standard conformer generation pipeline fails (e.g., we cannot generate a conformer with RDKit, - and the molecule is either not in the CCD or the CCD entry is invalid) - - IGNORE: Do not use the ground-truth coordinates as the reference conformer under any circumstances + - REPLACE: Use the ground-truth coordinates as the reference conformer, + replacing the coordinates generated by RDKit in-place (and add a flag + to indicate that the coordinates were replaced) + - ADD: Return an additional feature (with the same shape as ref_pos) + containing the ground-truth coordinates + - FALLBACK: Use the ground-truth coordinates only if our standard + conformer generation pipeline fails (e.g., we cannot generate a + conformer with RDKit, and the molecule is either not in the CCD or + the CCD entry is invalid) + - IGNORE: Do not use the ground-truth coordinates as the reference + conformer under any circumstances """ REPLACE = 1 @@ -270,9 +360,9 @@ class HydrogenPolicy(StrEnum): """Enum for hydrogen policy. Possible values are: - - KEEP: Keep the hydrogens as they are - - REMOVE: Remove the hydrogens - - INFER: Infer the hydrogens from the atom array + - KEEP: Keep the hydrogens as they are + - REMOVE: Remove the hydrogens + - INFER: Infer the hydrogens from the atom array """ KEEP = auto() diff --git a/src/atomworks/io/__init__.py b/src/atomworks/io/__init__.py index d545e9ca..504d52bb 100644 --- a/src/atomworks/io/__init__.py +++ b/src/atomworks/io/__init__.py @@ -1,5 +1,4 @@ -""" -atomworks.io - Input/Output operations for biological data structures. +"""Input/Output operations for biological data structures. This subpackage provides functionality for parsing, converting, and manipulating biological data formats, originally from the atomworks.io package. diff --git a/src/atomworks/io/parser.py b/src/atomworks/io/parser.py index 7d8dc3f5..888e9ba0 100644 --- a/src/atomworks/io/parser.py +++ b/src/atomworks/io/parser.py @@ -1,4 +1,12 @@ -"""Entrypoint for parsing atomic-level structure files (e.g., PDB, CIF) into Biotite-compatible data structures.""" +"""Entrypoint for parsing atomic-level structure files into Biotite-compatible data structures. + +This module provides functionality for parsing PDB, CIF, and other structure files +into Biotite-compatible data structures with various processing options. 
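+
+Example (a minimal sketch; the file name is a placeholder):
+    >>> from atomworks.io.parser import parse
+    >>> result = parse("example.cif")  # doctest: +SKIP
+    >>> sorted(result)  # doctest: +SKIP
+    ['assemblies', 'asym_unit', 'chain_info', 'extra_info', 'ligand_info', 'metadata']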
+
+References:
+    `Biotite Structure I/O <https://www.biotite-python.org/apidoc/biotite.structure.io.html>`_
+    `mmCIF Format Specification <https://mmcif.wwpdb.org/>`_
+"""

 from __future__ import annotations

@@ -102,22 +110,22 @@ def parse(
         - Perform analogous cleaning/processing steps on an existing AtomArray or AtomArrayStack.

     We categorize arguments into two groups:
-        - Wrapper arguments: Arguments that are used within the wrapping `parse` method (e.g., caching)
+        - Wrapper arguments: Arguments that are used within the wrapping parse method (e.g., caching)
         - CIF parsing arguments: Arguments that control structure parsing and are ultimately passed
-            to the `_parse_from_atom_array` method (regardless of file type, we convert to an AtomArray before parsing)
+            to the _parse_from_atom_array method (regardless of file type, we convert to an AtomArray before parsing)

     Args:
         filename (PathLike | io.StringIO | io.BytesIO): Either a Path or buffer to the file. This may be any format of
-            atomic-level structure (e.g. .cif, .bcif, .cif.gz, .pdb), although .cif files are *strongly* recommended.
+            atomic-level structure (e.g. .cif, .bcif, .cif.gz, .pdb), although .cif files are strongly recommended.

-        *** Wrapper arguments ***
+        **Wrapper arguments:**
         file_type (Literal["cif", "pdb"] | None, optional): The file type of the structure file. If not provided,
             the file type will be inferred automatically.
         load_from_cache (bool, optional): Whether to load pre-compiled results from cache. Defaults to False.
         cache_dir (PathLike, optional): Directory path to save pre-compiled results. Defaults to None.
         save_to_cache (bool, optional): Whether to save the results to cache when building the structure. Defaults to False.

-        *** Parsing arguments ***
+        **Parsing arguments:**
         ccd_mirror_path (str, optional): Path to the local mirror of the Chemical Component Dictionary (recommended).
             If not provided, Biotite's built-in CCD will be used.
         add_missing_atoms (bool, optional): Whether to add missing atoms to the
@@ -150,19 +158,26 @@ def parse(
         build_assembly (string, list, or tuple, optional): Specifies which assembly to build, if any. Options are None
             (e.g., asymmetric unit), "first", "all", or a list or tuple of assembly IDs. Defaults to "all".
         extra_fields (list, optional): A list of extra fields to include in the AtomArrayStack. Defaults to None. "all" includes all fields.
-            Only support mmCIF files.
+            Only supports mmCIF files.
         keep_cif_block (bool, optional): Whether to keep the CIF block in the result. Defaults to False.

     Returns:
         dict: A dictionary containing the following keys:
-            chain_info: A dictionary mapping chain ID to sequence, type (as an IntEnum), RCSB entity,
+
+            chain_info
+                A dictionary mapping chain ID to sequence, type (as an IntEnum), RCSB entity,
                 EC number, and other information.
-            ligand_info: A dictionary containing ligand of interest information.
-            asym_unit: An AtomArrayStack instance representing the asymmetric unit.
-            assemblies: A dictionary mapping assembly IDs to AtomArrayStack instances.
-            metadata: A dictionary containing metadata about the structure
+            ligand_info
+                A dictionary containing ligand of interest information.
+            asym_unit
+                An AtomArrayStack instance representing the asymmetric unit.
+            assemblies
+                A dictionary mapping assembly IDs to AtomArrayStack instances.
+            metadata
+                A dictionary containing metadata about the structure
                 (e.g., resolution, deposition date, etc.).
-            extra_info: A dictionary with information for cross-compatibility and caching.
+            extra_info
+                A dictionary with information for cross-compatibility and caching.
                Should typically not be used directly.
     """

@@ -536,7 +551,7 @@ def parse_atom_array(

         if msa_path != "":
             data_dict["chain_info"][chain]["msa_path"] = Path(msa_path)

-    # ... optionally, build assemblies and add assembly-specifc annotation (instance IDs)
+    # ... optionally, build assemblies and add assembly-specific annotation (instance IDs like `chain_iid`, `pn_unit_iid`, `molecule_iid`)
     if exists(build_assembly):
         assert (
             build_assembly in ["first", "all", "_spoof"] or isinstance(build_assembly, list | tuple)
@@ -720,7 +735,7 @@ def _parse_from_pdb(filename: os.PathLike, **parse_from_cif_kwargs) -> dict[str,
             updated_chain_hetero_annotations = atom_array_stack.hetero[atom_array_stack.chain_id == chain_id]
             assert np.all(updated_chain_hetero_annotations) or np.all(~updated_chain_hetero_annotations)

-    # ...parse the CIF block into a dictionary
+    # ... parse the CIF block into a dictionary
     parse_from_cif_kwargs["file_type"] = "pdb"
     parse_from_cif_kwargs["extra_fields"] = None
     parse_from_cif_kwargs["build_assembly"] = "_spoof"
diff --git a/src/atomworks/io/template.py b/src/atomworks/io/template.py
index e0465033..cc62ed54 100644
--- a/src/atomworks/io/template.py
+++ b/src/atomworks/io/template.py
@@ -150,6 +150,42 @@ def match_residue_to_template(
     return template


+def _find_residue_mask_fast(
+    residue_keys: np.ndarray,
+    sorted_keys: np.ndarray,
+    sort_idx: np.ndarray,
+    chain_id: str,
+    res_name: str,
+    res_id: int,
+) -> np.ndarray:
+    """Efficient method of getting a residue mask from a sorted list of residue keys.
+
+    Args:
+        residue_keys: Structured np array of residue keys to search through
+        sorted_keys: Sorted list of residue keys
+        sort_idx: Index of the sorted list (get from doing np.argsort(residue_keys))
+        chain_id: Chain ID of the residue
+        res_name: Residue name of the residue
+        res_id: Residue ID of the residue
+
+    Returns:
+        mask: Boolean mask of the residue keys
+    """
+    key = np.array([(chain_id, res_name, res_id)], dtype=residue_keys.dtype)
+
+    # Find start and end indices using binary search
+    start_idx = np.searchsorted(sorted_keys, key)[0]
+    end_idx = np.searchsorted(sorted_keys, key, side="right")[0]
+
+    # Create mask
+    mask = np.zeros(len(residue_keys), dtype=bool)
+    if start_idx < end_idx:
+        mask[sort_idx[start_idx:end_idx]] = True
+
+    return mask
+
+
 def build_template_atom_array(
     chain_info_dict: dict[str, dict[str, Any]],
     atom_array: AtomArray | None = None,
@@ -252,6 +288,15 @@ def build_template_atom_array(
     # ... create a list of atoms based on the reference CCD entries
     template_residues = []
     chain_identifiers = chain_iids if use_chain_iids else chain_ids
+
+    # ... get the sorted list of residue keys. This will make the residue mask lookup much faster.
+    residue_keys = np.array(
+        list(zip(chain_identifiers, res_names, res_ids, strict=True)),
+        dtype=np.dtype([("chain_id", "object"), ("res_name", "object"), ("res_id", "<i8")]),
+    )
+    sort_idx = np.argsort(residue_keys)
+    sorted_keys = residue_keys[sort_idx]

+    def _is_ccd_style_cif(self) -> bool:
+        """Check if we are given a CCD CIF file, which by convention includes the _chem_comp_atom field but not the atom_site field"""
+        cif = read_any(self.path)
+        keys = list(cif.block.keys())
+
+        has_atom_site = "atom_site" in keys
+        has_chem_comp_atom = "chem_comp_atom" in keys
+
+        return has_chem_comp_atom and not has_atom_site
+
+    def _parse_ccd_style_cif(self) -> None:
+        """Parse a CCD-style CIF file."""
+
+        if self.custom_parse_kwargs is not None:
+            raise ValueError("Custom parse kwargs are not supported for CCD CIF files.")
+
+        logger.warning(
+            f"CCD CIF file detected: {self.path}.
" + "This file will be parsed as a CCD CIF file rather than a regular CIF file " + "(e.g., with an `atom_site` category)." + ) + + self.atom_array = parse_ccd_cif(read_any(self.path)) + self.atom_array.set_annotation("is_polymer", np.full(len(self.atom_array), False)) + self.chain_ids = np.unique(self.atom_array.chain_id) + + # Set occupancy to all 1s since we presumably want to predict everything + self.atom_array.occupancy = np.full(len(self.atom_array), 1.0) + + def _parse_standard_pdb_or_cif(self) -> None: + """Parse a standard PDB or CIF structure file.""" if self.custom_parse_kwargs is None: self.custom_parse_kwargs = {} + # We add missing atoms later to the fully-concatenated inference AtomArray parse_kwargs = {**DEFAULT_PARSE_KWARGS, "add_missing_atoms": False} | self.custom_parse_kwargs if parse_kwargs["add_missing_atoms"]: logger.warning( "Missing atoms will be added later to the fully-concatenated inference AtomArray. " - "It is recommended to set this argument to False in initial CIFOrPDBFileComponent parsing. " + "It is recommended to set this argument to False in initial CIFOrPDBFileComponent parsing." ) - # Parse using atomworks.io parsing_results = parse(self.path, **parse_kwargs) if "assemblies" in parsing_results: assemblies = parsing_results["assemblies"] - # We will keep only the first assembly that was parsed first_assembly_id = next(iter(assemblies.keys())) - # Give warning if multiple assemblies were parsed if len(assemblies) > 1: logger.warning( f"Multiple biological assemblies found in {self.path} and none were specified. " f"Only the first assembly (assembly_id={first_assembly_id}) will be used for inference. " - f"If you would like to use a different assembly, please specify this in the `parse_kwargs`." + "If you would like to use a different assembly, please specify this in the `parse_kwargs`." ) - # Get the atom array stack corresponding to this assembly atom_array_stack = assemblies[first_assembly_id] - - # Use the asymmetric unit if no assemblies were returned else: atom_array_stack = parsing_results["asym_unit"] - # We will keep only the first model of the parsed structure if atom_array_stack.stack_depth() > 1: logger.warning( f"Multiple models found in {self.path}. Only the first model will be used for inference. " - f"If you would like to use a different model, please specify this in the `parse_kwargs`." + "If you would like to use a different model, please specify this in the `parse_kwargs`." ) - structure_file_atom_array = atom_array_stack[0] - # Record chain ids and AtomArray + structure_file_atom_array = atom_array_stack[0] self.chain_ids = np.unique(structure_file_atom_array.chain_id) self.atom_array = structure_file_atom_array @@ -663,6 +691,13 @@ def components_to_atom_array( for component in components: # CIFOrPDBFileComponents already have parsed AtomArrays if isinstance(component, CIFOrPDBFileComponent): + atom_array = component.atom_array + if np.any(atom_array.chain_id == ""): + atom_array.chain_id = np.full(atom_array.array_length(), next(chain_id_generator)) + logger.warning( + f"Chain ID was not set for {component.path}. 
" + f"The next available chain ID was assigned, assuming that this is a single-chain structure: {atom_array.chain_id[0]}" + ) atom_arrays.append(component.atom_array) continue @@ -672,12 +707,16 @@ def components_to_atom_array( atom_arrays.append(sequence_to_annotated_atom_array(**component.as_dict(), custom_residues=custom_residues)) elif isinstance(component, SmilesComponent): ligand_array = smiles_to_annotated_atom_array(**component.as_dict()) - atom_arrays.append(assign_res_name_from_atom_array_hash(ligand_array, ligand_hash_to_id)) + if component.res_name == UNKNOWN_LIGAND: + ligand_array = assign_res_name_from_atom_array_hash(ligand_array, ligand_hash_to_id) + atom_arrays.append(ligand_array) elif isinstance(component, CCDComponent): atom_arrays.append(ccd_code_to_annotated_atom_array(**component.as_dict())) elif isinstance(component, SDFComponent): ligand_array = sdf_to_annotated_atom_array(**component.as_dict()) - atom_arrays.append(assign_res_name_from_atom_array_hash(ligand_array, ligand_hash_to_id)) + if component.res_name == UNKNOWN_LIGAND: + ligand_array = assign_res_name_from_atom_array_hash(ligand_array, ligand_hash_to_id) + atom_arrays.append(ligand_array) else: raise ValueError(f"Unknown chemical component type: {type(component)}") diff --git a/src/atomworks/io/tools/rdkit.py b/src/atomworks/io/tools/rdkit.py index c7b25d3d..c260b96a 100644 --- a/src/atomworks/io/tools/rdkit.py +++ b/src/atomworks/io/tools/rdkit.py @@ -54,7 +54,8 @@ """ Mapping from RDKit hybridization types to integers. -Reference: https://www.rdkit.org/docs/cppapi/classRDKit_1_1Atom.html#a58e40e30db6b42826243163175cac976 +Reference: + `RDKit Atom Documentation `_ """ RDKIT_BOND_TYPE_TO_BIOTITE: Final[dict[tuple[Chem.BondType, bool], struc.bonds.BondType]] = { @@ -64,6 +65,7 @@ (Chem.BondType.DOUBLE, False): struc.bonds.BondType.DOUBLE, (Chem.BondType.TRIPLE, False): struc.bonds.BondType.TRIPLE, (Chem.BondType.QUADRUPLE, False): struc.bonds.BondType.QUADRUPLE, + (Chem.BondType.DATIVE, False): struc.bonds.BondType.COORDINATION, (Chem.BondType.SINGLE, True): struc.bonds.BondType.AROMATIC_SINGLE, (Chem.BondType.DOUBLE, True): struc.bonds.BondType.AROMATIC_DOUBLE, (Chem.BondType.TRIPLE, True): struc.bonds.BondType.AROMATIC_TRIPLE, @@ -82,6 +84,7 @@ struc.bonds.BondType.DOUBLE: (Chem.BondType.DOUBLE, False), struc.bonds.BondType.TRIPLE: (Chem.BondType.TRIPLE, False), struc.bonds.BondType.QUADRUPLE: (Chem.BondType.QUADRUPLE, False), + struc.bonds.BondType.COORDINATION: (Chem.BondType.DATIVE, False), # NOTE: We map aromatics to single/double/triple instead of Chem.BondType.AROMATIC # because the PDB specified bond-order (from a kekulized form of the molecule) # is lost when we map to aromatic, which can lead to incorrect bond-order @@ -108,8 +111,8 @@ class ChEMBLNormalizer: This is useful for `rescuing` molecules that failed to be sanitized by RDKit alone. 
-    References:
-        - https://github.com/chembl/ChEMBL_Structure_Pipeline/blob/master/chembl_structure_pipeline/standardizer.py#L33C1-L73C15
+    Reference:
+        `ChEMBL Structure Pipeline <https://github.com/chembl/ChEMBL_Structure_Pipeline/blob/master/chembl_structure_pipeline/standardizer.py#L33C1-L73C15>`_
     """

     def __init__(self):
@@ -290,9 +293,9 @@ def fix_mol(

     References:
-        - https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
-        - https://github.com/chembl/ChEMBL_Structure_Pipeline/blob/master/chembl_structure_pipeline/standardizer.py
-        - https://github.com/datamol-io/datamol/blob/0312388b956e2b4eeb72d791167cfdb873c7beab/datamol/mol.py
+        `RDKit Molecular Sanitization <https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization>`_
+        `ChEMBL Structure Pipeline <https://github.com/chembl/ChEMBL_Structure_Pipeline/blob/master/chembl_structure_pipeline/standardizer.py>`_
+        `datamol mol.py <https://github.com/datamol-io/datamol/blob/0312388b956e2b4eeb72d791167cfdb873c7beab/datamol/mol.py>`_
     """

     if not in_place:
@@ -379,8 +382,8 @@ def get_morgan_fingerprint_from_rdkit_mol(mol: Chem.Mol, *, radius: int = 2, n_b
         - ExplicitBitVect: The Morgan fingerprint for the input molecule.

     References:
-        - AF-3 Supplement
-        - https://greglandrum.github.io/rdkit-blog/posts/2023-01-18-fingerprint-generator-tutorial.html
+        AF-3 Supplement
+        `RDKit Fingerprint Generator Tutorial <https://greglandrum.github.io/rdkit-blog/posts/2023-01-18-fingerprint-generator-tutorial.html>`_
     """
     morgan_fingerprint_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits)
     fingerprint = morgan_fingerprint_generator.GetFingerprint(mol)
diff --git a/src/atomworks/io/transforms/atom_array.py b/src/atomworks/io/transforms/atom_array.py
index 7f9c2e5c..d788e4f1 100644
--- a/src/atomworks/io/transforms/atom_array.py
+++ b/src/atomworks/io/transforms/atom_array.py
@@ -1,6 +1,6 @@
-"""
-Transforms operating predominantly on Biotite's `AtomArray` objects.
-These operations should take as input, and return, `AtomArray` objects.
+"""Transforms operating predominantly on Biotite's AtomArray objects.
+
+These operations should take as input, and return, AtomArray objects.
 """

 import logging
@@ -32,7 +32,15 @@


 def subset_atom_array(atom_array: AtomArray | AtomArrayStack, keep: np.ndarray) -> AtomArray | AtomArrayStack:
-    """Subsets an AtomArray or AtomArrayStack by a boolean mask."""
+    """Subsets an AtomArray or AtomArrayStack by a boolean mask.
+
+    Args:
+        atom_array: The AtomArray or AtomArrayStack to subset.
+        keep: Boolean mask indicating which atoms to keep.
+
+    Returns:
+        The subsetted AtomArray or AtomArrayStack.
+    """
     if isinstance(atom_array, AtomArrayStack):
         return atom_array[:, keep]
     else:
@@ -40,7 +48,14 @@ def subset_atom_array(atom_array: AtomArray | AtomArrayStack, keep: np.ndarray)


 def is_any_coord_nan(atom_array: AtomArray | AtomArrayStack) -> np.ndarray:
-    """Returns a boolean mask of shape [n_atoms] indicating whether any coordinate is NaN for each atom in the AtomArray or AtomArrayStack."""
+    """Returns a boolean mask indicating whether any coordinate is NaN for each atom.
+
+    Args:
+        atom_array: The AtomArray or AtomArrayStack to check.
+
+    Returns:
+        Boolean mask of shape [n_atoms] indicating NaN coordinates.
+    """
     if isinstance(atom_array, AtomArrayStack):
         return np.isnan(atom_array.coord).any(axis=(0, -1))
     else:
diff --git a/src/atomworks/io/transforms/categories.py b/src/atomworks/io/transforms/categories.py
index f8002088..eabea63a 100644
--- a/src/atomworks/io/transforms/categories.py
+++ b/src/atomworks/io/transforms/categories.py
@@ -1,5 +1,4 @@
-"""
-Transforms operating on Biotite's CIFBlock and CIFCategory objects.
+"""Transforms operating on Biotite's CIFBlock and CIFCategory objects.

 These transforms are used to extract information from the CIFBlock and
 return a dictionary containing processed information.
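+
+Example (illustrative; assumes ``cif_block`` is an already-parsed ``CIFBlock``):
+    >>> df = category_to_df(cif_block, "entity")  # doctest: +SKIP
+    >>> data = category_to_dict(cif_block, "entity")  # doctest: +SKIP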
""" @@ -27,12 +26,28 @@ def category_to_df(cif_block: CIFBlock, category: str) -> pd.DataFrame | None: - """Convert a CIF block to a pandas DataFrame.""" + """Convert a CIF block to a pandas DataFrame. + + Args: + cif_block: The CIF block to convert. + category: The category name to extract. + + Returns: + DataFrame containing the category data, or None if category doesn't exist. + """ return pd.DataFrame(category_to_dict(cif_block, category)) if category in cif_block else None def category_to_dict(cif_block: CIFBlock, category: str) -> dict[str, np.ndarray]: - """Convert a CIF block to a dictionary.""" + """Convert a CIF block to a dictionary. + + Args: + cif_block: The CIF block to convert. + category: The category name to extract. + + Returns: + Dictionary containing the category data as numpy arrays. + """ if exists(cif_block.get(category)): return toolz.valmap(lambda x: x.as_array(), dict(cif_block[category])) else: @@ -303,7 +318,7 @@ def get_ligand_of_interest_info(cif_block: CIFBlock) -> dict: """Extract ligand of interest information from a CIF block. Reference: - - https://pdb101.rcsb.org/learn/guide-to-understanding-pdb-data/small-molecule-ligands + `PDB101 Small Molecule Ligands Guide `_ """ # Extract binary flag for whether the ligand of interest is specified # NOTE: This is being used in addition to the below as it has slightly higher coverage across the PDB diff --git a/src/atomworks/io/utils/bonds.py b/src/atomworks/io/utils/bonds.py index ff499eda..33e5c145 100644 --- a/src/atomworks/io/utils/bonds.py +++ b/src/atomworks/io/utils/bonds.py @@ -333,8 +333,8 @@ def get_struct_conn_bonds( bonds (np.array[[int, int, struc.BondType]]): A List of bonds to be added to the atom array. leaving (np.ndarray): An array of indices of atoms that are leaving groups for bookkeeping. - References: - - https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_struct_conn.conn_type_id.html + Reference: + `struct_conn.conn_type_id `_ """ # ... validate input invalid_bond_types = set(add_bond_types) - STRUCT_CONN_BOND_TYPES @@ -812,6 +812,8 @@ def spoof_struct_conn_dict_from_string(bonds: list[tuple[str, str]]) -> dict[str NOTE: We only support covalent bonds. + TODO: Use AtomSelection to parse the bond strings + Args: bonds (list[tuple[str, str]]): A list of bond strings. Each bond string should be in the format: diff --git a/src/atomworks/io/utils/ccd.py b/src/atomworks/io/utils/ccd.py index c6b2cee9..3530bab5 100644 --- a/src/atomworks/io/utils/ccd.py +++ b/src/atomworks/io/utils/ccd.py @@ -3,6 +3,7 @@ import os from collections import defaultdict from collections.abc import Iterable +from pathlib import Path from typing import Literal import biotite.structure as struc @@ -31,25 +32,41 @@ @functools.cache def aa_chem_comps() -> frozenset[str]: - """Set of amino acid chemical components. E.g. {'ALA', 'ARG', ...}""" + """Set of amino acid chemical components. + + Returns: + Set of amino acid chemical components (e.g., {'ALA', 'ARG', ...}). + """ return frozenset(struc.info.groups._get_group_members(list(AA_LIKE_CHEM_TYPES))) @functools.cache def na_chem_comps() -> frozenset[str]: - """Set of nucleic acid chemical components. E.g. {'DA', 'DC', ...}""" + """Set of nucleic acid chemical components. + + Returns: + Set of nucleic acid chemical components (e.g., {'DA', 'DC', ...}). + """ return frozenset(struc.info.groups._get_group_members(list(NA_LIKE_CHEM_TYPES))) @functools.cache def rna_chem_comps() -> frozenset[str]: - """Set of RNA chemical components. E.g. 
{'A', 'C', ...}""" + """Set of RNA chemical components. + + Returns: + Set of RNA chemical components (e.g., {'A', 'C', ...}). + """ return frozenset(struc.info.groups._get_group_members(list(RNA_LIKE_CHEM_TYPES))) @functools.cache def dna_chem_comps() -> frozenset[str]: - """Set of DNA chemical components. E.g. {'DA', 'DC', ...}""" + """Set of DNA chemical components. + + Returns: + Set of DNA chemical components (e.g., {'DA', 'DC', ...}). + """ return frozenset(struc.info.groups._get_group_members(list(DNA_LIKE_CHEM_TYPES))) @@ -57,8 +74,16 @@ def dna_chem_comps() -> frozenset[str]: def chem_comp_to_one_letter() -> dict[str, str]: """Dictionary mapping the chemical components to their 1-letter code. - NOTE: Chemical components historically used to be 3-letter codes, - but nowadays longer codes exist. + Note: + Chemical components historically used to be 3-letter codes, + but nowadays longer codes exist. + + Returns: + Dictionary mapping chemical component names to their 1-letter codes. + + References: + `RCSB Chemical Component Dictionary `_ + `Biotite CCD Module `_ """ ccd = struc.info.ccd.get_ccd() three_letter_code = ccd["chem_comp"]["three_letter_code"].as_array() @@ -80,30 +105,66 @@ def get_available_ccd_codes_in_mirror(ccd_mirror_path: os.PathLike = CCD_MIRROR_ """Set of all CCD codes available in the local mirror. Only counts codes when they adhere to the CCD mirror layout (e.g. .../H/HEM/HEM.cif) + + Args: + ccd_mirror_path: Path to the CCD mirror directory. + + Returns: + Set of all available CCD codes in the mirror. + + References: + `RCSB Chemical Component Dictionary `_ + `CCD Mirror Layout `_ """ root = os.fspath(ccd_mirror_path) + + # Check if we have a pre-computed cache file + cache_file = os.path.join(root, ".ccd_codes_cache") + if os.path.exists(cache_file): + try: + # Check if cache is newer than the directory + cache_mtime = os.path.getmtime(cache_file) + dir_mtime = os.path.getmtime(root) + if cache_mtime > dir_mtime: + with open(cache_file) as f: + codes = {line.strip() for line in f if line.strip()} + return frozenset(codes) + except OSError: + # If cache is corrupted, fall back to scanning + pass + + # Fall back to filesystem scan codes: set[str] = set() - # NOTE: The below is an optimized file-system scan since this is run at every - with os.scandir(root) as level1: - for l1 in level1: - if not l1.is_dir(follow_symlinks=False): + root_path = Path(root) + + for level1_dir in root_path.iterdir(): + if not level1_dir.is_dir(): + continue + first_letter = level1_dir.name + if len(first_letter) != 1: + continue + + for level2_dir in level1_dir.iterdir(): + if not level2_dir.is_dir(): continue - first_letter = l1.name - if len(first_letter) != 1: + code = level2_dir.name + if not code or code[0] != first_letter: continue - with os.scandir(l1.path) as level2: - for l2 in level2: - if not l2.is_dir(follow_symlinks=False): - continue - code = l2.name - if not code or code[0] != first_letter: - continue - - expected = os.path.join(l2.path, f"{code}.cif") - if os.path.isfile(expected): - codes.add(code) + expected_file = level2_dir / f"{code}.cif" + if expected_file.is_file(): + codes.add(code) + + # Cache the results for next time + try: + with open(cache_file, "w") as f: + for code in sorted(codes): + f.write(f"{code}\n") + except OSError: + # If we can't write cache, that's okay + pass + return frozenset(codes) @@ -210,25 +271,23 @@ def parse_ccd_cif( add_properties: bool = False, add_mapping: bool = False, ) -> struc.AtomArray: - """ - Parses a Chemical Component 
Dictionary CIF file into a Biotite AtomArray structure. + """Parses a Chemical Component Dictionary CIF file into a Biotite AtomArray structure. Args: - - cif (CIFFile): The CIF file containing the component data. - - coords (Literal["model", "ideal_pdbx", "ideal_rdkit"] | None | tuple[str, ...]): - Type of coordinates to use. Defaults to ("ideal_pdbx", "model", "ideal_rdkit"). + cif: The CIF file containing the component data. + coords: Type of coordinates to use. Defaults to ("ideal_pdbx", "model", "ideal_rdkit"). Can be a single coordinate type or a tuple of fallback preferences (e.g., ("ideal_pdbx", "model", "ideal_rdkit")). - - "model": Use the coordinates that are found in a random (but fixed) pdb file. - - "ideal_pdbx": Use the idealized coordinates computed by the RCSB PDB (sometimes not available). - - "ideal_rdkit": Use the idealized coordinates computed by RDKit (sometimes unrealistic). - - add_properties (bool): Whether to include RDKit-computed properties. Defaults to False. - Properties are available under the `properties` attribute of the returned `AtomArray`. - - add_mapping (bool): Whether to include external resource mappings, such as e.g. the ChEMBL ID. + - "model": Use the coordinates that are found in a random (but fixed) pdb file. + - "ideal_pdbx": Use the idealized coordinates computed by the RCSB PDB (sometimes not available). + - "ideal_rdkit": Use the idealized coordinates computed by RDKit (sometimes unrealistic). + add_properties: Whether to include RDKit-computed properties. Defaults to False. + Properties are available under the ``properties`` attribute of the returned ``AtomArray``. + add_mapping: Whether to include external resource mappings, such as e.g. the ChEMBL ID. Defaults to False. - Mappings are available under the `mapping` attribute of the returned `AtomArray`. + Mappings are available under the ``mapping`` attribute of the returned ``AtomArray``. Returns: - - AtomArray: The parsed atomic structure with requested annotations and properties. + AtomArray: The parsed atomic structure with requested annotations and properties. Example: >>> cif = pdbx.CIFFile.read("path/to/ALA.cif") @@ -371,21 +430,20 @@ def parse_ccd_cif( def get_ccd_component_from_mirror( ccd_code: str, ccd_mirror_path: os.PathLike = CCD_MIRROR_PATH, **parse_ccd_cif_kwargs ) -> struc.AtomArray: - """ - Retrieves and parses a component from a local mirror of the Chemical Component Dictionary. + """Retrieves and parses a component from a local mirror of the Chemical Component Dictionary. Args: - - ccd_code (str): The three-letter code of the chemical component. - - ccd_mirror_path (os.PathLike): Path to the root of the CCD mirror directory. - - **parse_ccd_cif_kwargs: Additional keyword arguments passed to parse_ccd_cif(): - - coords (Literal["model", "ideal_pdbx", "ideal_rdkit"] | None): - Type of coordinates to use. Defaults to "ideal_pdbx". - - add_properties (bool): Whether to include RDKit-computed properties. Defaults to True. - - add_mapping (bool): Whether to include external resource mappings, such as e.g. the ChEMBL ID. + ccd_code: The three-letter code of the chemical component. + ccd_mirror_path: Path to the root of the CCD mirror directory. + **parse_ccd_cif_kwargs: Additional keyword arguments passed to parse_ccd_cif(): + coords: Type of coordinates to use ("model", "ideal_pdbx", "ideal_rdkit", or None). + Defaults to "ideal_pdbx". + add_properties: Whether to include RDKit-computed properties. Defaults to True. 
+ add_mapping: Whether to include external resource mappings, such as e.g. the ChEMBL ID. Defaults to False. Returns: - - AtomArray: The parsed atomic structure of the requested component. + AtomArray: The parsed atomic structure of the requested component. Example: >>> atom_array = get_ccd_component_from_mirror("ALA", coords="ideal_pdbx") diff --git a/src/atomworks/io/utils/io_utils.py b/src/atomworks/io/utils/io_utils.py index 5a0c43c3..0ba58753 100644 --- a/src/atomworks/io/utils/io_utils.py +++ b/src/atomworks/io/utils/io_utils.py @@ -1,6 +1,4 @@ -""" -General utility functions for working with CIF files in Biotite. -""" +"""General utility functions for working with CIF files in Biotite.""" __all__ = ["get_structure", "load_any", "read_any", "to_cif_buffer", "to_cif_file", "to_cif_string"] @@ -36,8 +34,10 @@ def _get_logged_in_user() -> str: - """ - Get the logged in user. + """Get the logged in user. + + Returns: + The username of the logged in user, or "unknown_user" if unavailable. """ try: return os.getlogin() @@ -57,19 +57,20 @@ def load_any( """Convenience function for loading a structure from a file or buffer. Args: - - file_or_buffer: Path to the file or buffer to load the structure from. - - file_type: Type of the file to load. If None, it will be inferred. - - extra_fields: List of extra fields to include as AtomArray annotations. + file_or_buffer: Path to the file or buffer to load the structure from. + file_type: Type of the file to load. If None, it will be inferred. + extra_fields: List of extra fields to include as AtomArray annotations. If "all", all fields in the 'atom_site' category of the file will be included. - - include_bonds: Whether to include bonds in the structure. - - model: The model number to use for loading the structure. If None, all models will be loaded. - - altloc: The altloc ID to use for loading the structure. + include_bonds: Whether to include bonds in the structure. + model: The model number to use for loading the structure. If None, all models will be loaded. + altloc: The altloc ID to use for loading the structure. Returns: - AtomArrayStack: The loaded structure with the specified fields and assumptions. + The loaded structure with the specified fields and assumptions. - Reference: - Biotite documentation (https://www.biotite-python.org/apidoc/biotite.structure.io.pdbx.get_structure.html#biotite.structure.io.pdbx.get_structure) + References: + `Biotite Structure I/O `_ + `mmCIF Format Specification `_ """ file_obj = read_any(file_or_buffer, file_type=file_type) return get_structure( @@ -87,17 +88,17 @@ def _add_bonds( add_bond_types_from_struct_conn: list[str] = ["covale"], fix_bond_types: bool = True, ) -> AtomArray | AtomArrayStack: - """ - Add bonds to the AtomArray and filter by a given altloc strategy. + """Add bonds to the AtomArray and filter by a given altloc strategy. + Avoids the issue where spurious bonds are added due to uninformative label_seq_ids. Args: - - atom_array: The AtomArray to add bonds to. Must contain `auth_seq_id` annotation. - - cif_block: The CIFBlock containing the structure data. - - add_bond_types_from_struct_conn (list, optional): A list of bond types to add to the structure + atom_array: The AtomArray to add bonds to. Must contain `auth_seq_id` annotation. + cif_block: The CIFBlock containing the structure data. + add_bond_types_from_struct_conn: A list of bond types to add to the structure from the `struct_conn` category. Defaults to `["covale"]`. 
This means that we will only add covalent bonds to the structure (excluding metal coordination and disulfide bonds). - - fix_bond_types (bool, optional): Whether to correct for nucleophilic additions on atoms involved in inter-residue bonds. + fix_bond_types: Whether to correct for nucleophilic additions on atoms involved in inter-residue bonds. Returns: AtomArray | AtomArrayStack: The AtomArray or AtomArrayStack with bonds and filtered by altloc. @@ -209,7 +210,7 @@ def get_structure( AtomArray | AtomArrayStack: The loaded structure with the specified fields and assumptions. Reference: - Biotite documentation (https://www.biotite-python.org/apidoc/biotite.structure.io.pdbx.get_structure.html#biotite.structure.io.pdbx.get_structure) + `Biotite documentation `_ """ tmp_altloc = altloc if altloc in {"first", "occupancy", "all"} else "all" @@ -569,6 +570,12 @@ def _to_cif_or_bcif( if not include_nan_coords: structure = ta.remove_nan_coords(structure) + if include_bonds and structure.bonds is not None: + # TODO: Switch to using the `convert_bond_type` method once we upgrade to Biotite v1.4.0 + # structure.bonds.convert_bond_type(struc.bonds.BondType.COORDINATION, struc.bonds.BondType.SINGLE) + mask = structure.bonds._bonds[:, 2] == struc.bonds.BondType.COORDINATION + structure.bonds._bonds[mask, 2] = struc.bonds.BondType.SINGLE + pdbx.set_structure(cif_file, structure, data_block=id, include_bonds=include_bonds, extra_fields=extra_fields) # Add extra categories if provided diff --git a/src/atomworks/io/utils/non_rcsb.py b/src/atomworks/io/utils/non_rcsb.py index 7a44a668..54c1781f 100644 --- a/src/atomworks/io/utils/non_rcsb.py +++ b/src/atomworks/io/utils/non_rcsb.py @@ -120,7 +120,7 @@ def initialize_chain_info_from_atom_array( In particular, this function adds the following information to the chain_info_dict: - The RCSB entity ID for each chain (e.g., 1, 2, 3, etc.), if present in the AtomArray (under the entity_id atom site label) - The unprocessed one-letter entity canonical and non-canonical sequences. - - (OptionallyA boolean flag indicating whether the chain is a polymer. + - (Optionally) A boolean flag indicating whether the chain is a polymer. - (Optionally) The chain type as an IntEnum (e.g., polypeptide(L), non-polymer, etc.) - (Optionally) The residue IDs and residue names, inferred from the AtomArray. @@ -147,8 +147,12 @@ def initialize_chain_info_from_atom_array( res_names = atom_array.res_name[_res_starts] hetero = atom_array.hetero[_res_starts] - # Loop through chains for chain_identifier in np.unique(chain_identifiers): + if not chain_identifier: + raise ValueError( + 'Chain identifier is empty! Please ensure that in your input file, each chain has a unique identifier (e.g., `label_asym_id` in a CIF file cannot be "").' + ) + is_in_chain = chain_identifiers == chain_identifier seq = res_names[is_in_chain] diff --git a/src/atomworks/io/utils/selection.py b/src/atomworks/io/utils/selection.py index b9bcd7c9..c8575861 100644 --- a/src/atomworks/io/utils/selection.py +++ b/src/atomworks/io/utils/selection.py @@ -1,10 +1,21 @@ -"""Utility functions for selecting segments of an AtomArray""" +"""Tools for atom and segment selection on ``AtomArray`` and ``AtomArrayStack``. -__all__ = ["annot_start_stop_idxs", "get_annotation", "get_residue_starts"] +Provides helpers to compute segment boundaries and apply expressive selection syntax to structures. 
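+
+Example (a minimal sketch; ``atom_array`` is assumed to be an already-loaded ``AtomArray``):
+    >>> mask = get_mask_from_selection_string(atom_array, "A/ALA/1/CA")  # doctest: +SKIP
+    >>> last_two_residues = atom_array[ResIdxSlice(-2, None)(atom_array)]  # doctest: +SKIP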
+ +Key public objects: +- :py:class:`~atomworks.io.utils.selection.AtomSelection` +- :py:class:`~atomworks.io.utils.selection.AtomSelectionStack` +- :py:class:`~atomworks.io.utils.selection.SegmentSlice` + +See individual docstrings for usage and examples. +""" + +__all__ = ["AtomSelection", "AtomSelectionStack", "annot_start_stop_idxs", "get_annotation", "get_residue_starts"] import re from abc import ABC, abstractmethod from functools import reduce +from itertools import product from typing import Any, Literal import biotite.structure as struc @@ -18,16 +29,15 @@ def annot_start_stop_idxs( atom_array: AtomArray | AtomArrayStack, annots: str | list[str], add_exclusive_stop: bool = False ) -> np.ndarray: - """ - Computes the start and stop indices for segments in an AtomArray where any of the specified annotation(s) change. + """Computes the start and stop indices for segments in an AtomArray where any of the specified annotation(s) change. Args: - - atom_array (AtomArray): The AtomArray to process. - - annots (str | list[str]): The annotation(s) to consider for determining segment boundaries. - - add_exclusive_stop (bool): If True, an exclusive stop index (the length of the AtomArray) is added to the result. + atom_array: The AtomArray to process. + annots: Annotation name or names to define segments. + add_exclusive_stop: Append an exclusive stop index at the end. Defaults to ``False``. Returns: - - np.ndarray: An array of start and stop indices for segments where the annotations change. + 1D array of start/stop indices that bound segments. Example: >>> atom_array = AtomArray(...) @@ -50,14 +60,21 @@ def annot_start_stop_idxs( def get_residue_starts(atom_array: AtomArray | AtomArrayStack, add_exclusive_stop: bool = False) -> np.ndarray: """Get the start (and optionally stop) indices of residues in an AtomArray. - More robust version of `biotite.structure.residues.get_residue_starts` that also - differentiates between residues resulting from different transformation ids. + This is a more robust version of :py:func:`biotite.structure.residues.get_residue_starts` + that additionally differentiates residues across different ``transformation_id`` values + when present. It is backwards compatible if the annotation is absent. + + Args: + atom_array: Structure to analyze. + add_exclusive_stop: Append an exclusive stop index at the end. Defaults to ``False``. - Backwards compatible with `biotite.structure.residues.get_residue_starts` if the - `transformation_id` annotation is not present. + Returns: + 1D array of residue boundary indices. References: - - https://github.com/biotite-dev/biotite/blob/231eefed334e1d3509c1b7cb3f2bfd71d4b0eeb0/src/biotite/structure/residues.py#L35 + * `Biotite get_residue_starts`_ + + .. _Biotite get_residue_starts: https://github.com/biotite-dev/biotite/blob/231eefed334e1d3509c1b7cb3f2bfd71d4b0eeb0/src/biotite/structure/residues.py#L35 """ _annots_to_check = ["chain_id", "res_name", "res_id", "ins_code", "transformation_id"] existing_annots = atom_array.get_annotation_categories() @@ -66,7 +83,17 @@ def get_residue_starts(atom_array: AtomArray | AtomArrayStack, add_exclusive_sto def _validate_n_body_and_type(atom_array: AtomArray | AtomArrayStack, n_body: int, operation: str) -> None: - """Validate n_body parameter and atom_array type compatibility.""" + """Validate ``n_body`` value and structure type. + + Args: + atom_array: Structure to validate. + n_body: Annotation dimensionality (1 or 2). + operation: Description used in error messages. 
+ + Raises: + ValueError: If ``n_body > 1`` but ``atom_array`` is not ``AtomArrayPlus`` or ``AtomArrayStack``. + NotImplementedError: If ``n_body`` is not 1 or 2. + """ if n_body > 1 and not isinstance(atom_array, (AtomArrayPlus | AtomArrayStack)): raise ValueError(f"Cannot {operation} with n_body={n_body} on non-AtomArrayPlus!") @@ -77,7 +104,19 @@ def _validate_n_body_and_type(atom_array: AtomArray | AtomArrayStack, n_body: in def get_annotation( atom_array: AtomArray | AtomArrayStack, annot: str, n_body: int | None = None, default: Any = None ) -> np.ndarray: - """Get the annotation for an AtomArray or AtomArrayStack if it exists, otherwise return the default value.""" + """Return an annotation array if present, otherwise ``default``. + + If ``n_body`` is ``None``, the dimensionality is auto-detected by probing 1D then 2D annotation categories. + + Args: + atom_array: Structure to query. + annot: Annotation category name. + n_body: 1 for 1D annotations, 2 for 2D annotations; auto-detected if ``None``. + default: Value to return if the annotation is missing. Defaults to ``None``. + + Returns: + The requested annotation array or ``default`` if missing. + """ if n_body is not None: _validate_n_body_and_type(atom_array, n_body, f"get annotation for {annot}") else: @@ -98,11 +137,11 @@ def get_annotation_categories(atom_array: AtomArray | AtomArrayStack, n_body: in """Get annotation categories for the specified n_body. Args: - atom_array: The AtomArray or AtomArrayStack to query. - n_body: 1 for 1D annotations, 2 for 2D annotations, or "all" for all available n_body. + atom_array: Structure to query. + n_body: ``1`` for 1D, ``2`` for 2D, or ``"all"`` for both. Returns: - categories: list[str] List of annotation category names. + Names of available annotation categories for the requested dimensionality. """ # Map n_body to the corresponding method name n_body_to_method = { @@ -127,8 +166,7 @@ def get_annotation_categories(atom_array: AtomArray | AtomArrayStack, n_body: in class SegmentSlice(ABC): - """ - Abstract base class for slicing segments of an AtomArray or AtomArrayStack. + """Abstract base class for slicing segments of an AtomArray or AtomArrayStack. Provides functionality analogous to Python's built-in slice object but operates on structural segments (e.g., residues or chains indices) rather than individual atom indices. To subclass, implement the @@ -140,8 +178,8 @@ class SegmentSlice(ABC): - to slice to the last two residues: `atom_array[ResIdxSlice(-2, None)]` Args: - - start (int | None): The starting segment index. If None, starts from the beginning. - - stop (int | None): The ending segment index (exclusive). If None, continues to the end. + start: Starting segment index. Defaults to ``None``. + stop: Exclusive ending segment index. Defaults to ``None``. """ def __init__(self, start: int | None = None, stop: int | None = None): @@ -153,14 +191,13 @@ def _get_segment_bounds(self, atom_array: AtomArray | AtomArrayStack) -> np.ndar pass def __call__(self, atom_array: AtomArray | AtomArrayStack) -> slice: - """ - Creates a slice object for the specified segment range in the atom array. + """Creates a slice object for the specified segment range in the atom array. Args: - - atom_array (AtomArray | AtomArrayStack): The structure to slice. + atom_array: Structure to slice. Returns: - - slice: A slice object that can be used to index the atom array. + A Python ``slice`` that can be used to index ``atom_array``. 
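+
+        Example (illustrative; ``ResIdxSlice`` is a concrete subclass defined below):
+            >>> first_two_residues = atom_array[ResIdxSlice(0, 2)(atom_array)]  # doctest: +SKIP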
""" seg_bounds = self._get_segment_bounds(atom_array) n_segments = len(seg_bounds) - 1 @@ -175,11 +212,10 @@ def __call__(self, atom_array: AtomArray | AtomArrayStack) -> slice: class ResIdxSlice(SegmentSlice): - """ - Slice atoms by residue indices. + """Slice atoms by residue indices. - Allows for selecting ranges of residues using Python slice-like syntax. Each residue is considered - as a segment, defined by changes in chain_id, res_name, res_id, ins_code, or transformation_id. + Residues are segmented by changes in ``chain_id``, ``res_name``, ``res_id``, + ``ins_code``, or ``transformation_id``. Example: >>> atom_array = AtomArray(...) @@ -192,8 +228,7 @@ def _get_segment_bounds(self, atom_array: AtomArray | AtomArrayStack) -> np.ndar class ChainIdxSlice(SegmentSlice): - """ - Slice atoms by chain indices. + """Slice atoms by chain indices. Allows for selecting ranges of chains using Python slice-like syntax. Each chain is considered as a segment, defined by changes in the chain_id annotation. @@ -209,15 +244,7 @@ def _get_segment_bounds(self, atom_array: AtomArray | AtomArrayStack) -> np.ndar class AtomSelection: - """Class that represents a selection of atoms in a molecular structure. - - We can specify a selection by chain_id, res_name, res_id, atom_name, and (optionally) transformation_id. - - For example: - - If we specify only chain_id, we will select all atoms in that chain (across all transformations) - - If we specify chain_id and res_name, we will select all atoms in that chain and residue - - If we specify only atom_name, we will select all atoms with that name, regardless of chain or residue - """ + """Represent a selection of atoms in a molecular structure.""" def __init__( self, @@ -227,6 +254,15 @@ def __init__( atom_name: str = "*", transformation_id: int | str = "*", ): + """Initialize a selection. + + Args: + chain_id: Chain identifier or ``"*"`` for any. Defaults to ``"*"``. + res_name: Residue name or ``"*"`` for any. Defaults to ``"*"``. + res_id: Residue index (integer) or ``"*"`` for any. Defaults to ``"*"``. + atom_name: Atom name or ``"*"`` for any. Defaults to ``"*"``. + transformation_id: Transformation id or ``"*"`` for any. Defaults to ``"*"``. + """ self.chain_id = chain_id self.res_name = res_name self.atom_name = atom_name @@ -248,7 +284,7 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if isinstance(other, str): # Convert the string to an AtomSelection for comparison - other = self.from_str(other) + other = self.from_selection_str(other) if not isinstance(other, AtomSelection): return False @@ -262,11 +298,24 @@ def __eq__(self, other: Any) -> bool: ) @classmethod - def from_str(cls, selection_string: str) -> "AtomSelection": - """Create a new AtomSelection from a selection string. + def from_selection_str(cls, selection_string: str) -> "AtomSelection": + """Create a selection from ``CHAIN/RES/RESID/ATOM/TRANSFORM`` syntax. + + ``"*"`` acts as a wildcard for any field. Trailing fields may be omitted + and default to ``"*"``. - Selection strings are of the form: `CHAIN_ID/RES_NAME/RES_ID/ATOM_NAME/TRANSFORMATION_ID` - We use "*" as a wildcard to select all atoms in a given granularity. 
+
+        Examples:
+            >>> # Selects the CA atom of the ALA residue at chain A, residue index 1
+            >>> AtomSelection.from_selection_str("A/ALA/1/CA")
+
+            >>> # Selects the CB atom of the ALA residue in any chain at any residue index
+            >>> AtomSelection.from_selection_str("*/ALA/*/CB")
+
+            >>> # Selects all atoms of the ALA residue at chain A
+            >>> AtomSelection.from_selection_str("A/ALA/")
+
+            >>> # Selects the CA atom of the ALA residue at chain A, residue index 1, transformation index 1
+            >>> AtomSelection.from_selection_str("A/ALA/1/CA/1")
         """
         selection = parse_selection_string(selection_string)

@@ -280,12 +329,18 @@

     @classmethod
     def from_pymol_str(cls, pymol_string: str) -> "AtomSelection":
-        """Create a new AtomSelection from a PyMOL string.
+        """Create a selection from a PyMOL atom label string.
+
+        PyMOL strings, as found by clicking on an atom or residue, are of the form
+        ``CHAIN/RES`RESID/ATOM`` and do not support ``transformation_id``.
+
+        We extend the default PyMOL syntax with ``"*"`` as a wildcard to select all
+        atoms at a given granularity.

-        PyMOL strings, found by clicking on an atom or residue, are of the form: CHAIN_ID/RES_NAME`RES_ID/ATOM_NAME
-        For example: "A/ASP`37/OD2"

-        We introduce "*" as a wildcard to select all atoms in a given granularity.
+        Example:
+            >>> # Selects the OD2 atom of the ASP residue at chain A, residue index 37
+            >>> AtomSelection.from_pymol_str("A/ASP`37/OD2")
         """
         selection = parse_pymol_string(pymol_string)
         return cls(
@@ -305,21 +360,13 @@


 def parse_selection_string(selection_string: str) -> AtomSelection:
-    """Convert a selection string into a AtomSelection dataclass.
+    """Parse ``CHAIN/RES/RESID/ATOM/TRANSFORM`` into an :py:class:`AtomSelection`.

-    Selection strings are of the form: `CHAIN_ID/RES_NAME/RES_ID/ATOM_NAME/TRANSFORMATION_ID`
+    ``"*"`` acts as a wildcard for any field. Trailing fields may be omitted
+    and default to ``"*"``.

-    We use "*" as a wildcard to select all atoms in a given granularity.
-
-    Example:
-        >>> parse_selection_string("A/ALA/1/CA")
-        AtomSelection(chain_id='A', res_name='ALA', res_id=1, atom_name='CA')
-        >>> parse_selection_string("*/ALA/*/CB")  # (select all CB atoms in ALA residues)
-        AtomSelection(chain_id='*', res_name='ALA', res_id='*', atom_name='CB')
-        >>> parse_selection_string("A/ALA/")
-        AtomSelection(chain_id='A', res_name='ALA')
-        >>> parse_selection_string("A/*/*/*/1")
-        AtomSelection(chain_id='A', res_name='*', res_id='*', atom_name='*', transformation_id=1)
+    See Also:
+        :py:meth:`~atomworks.io.utils.selection.AtomSelection.from_selection_str`
     """
     granularity_tiers = ["chain_id", "res_name", "res_id", "atom_name", "transformation_id"]
     values = selection_string.split("/")
@@ -331,20 +378,14 @@


 def parse_pymol_string(pymol_string: str) -> AtomSelection:
-    """Convert a PyMOL selection string into an AtomSelection instance.
+    """Parse a PyMOL string ``CHAIN/RES`RESID/ATOM`` into an :py:class:`AtomSelection`.

     PyMOL selection strings are of the form: CHAIN_ID/RES_NAME`RES_ID/ATOM_NAME
-
     Wildcards can be used with "*".

     PyMOL selection strings do not support transformation_id.
- Examples: - >>> parse_pymol_string("A/ASP`37/OD2") - AtomSelection(chain_id='A', res_name='ASP', res_id=37, atom_name='OD2') - >>> parse_pymol_string("A/ASP") - AtomSelection(chain_id='A', res_name='ASP', res_id='*', atom_name='*') - >>> parse_pymol_string("*/ASP`*/OD2") - AtomSelection(chain_id='*', res_name='ASP', res_id='*', atom_name='OD2') + See Also: + :py:meth:`~atomworks.io.utils.selection.AtomSelection.from_pymol_str` """ # Replace backtick with slash to standardize the format standardized_string = pymol_string.replace("`", "/") @@ -354,20 +395,27 @@ def parse_pymol_string(pymol_string: str) -> AtomSelection: def get_mask_from_selection_string(atom_array: AtomArray, selection_string: str) -> np.ndarray: """Create a boolean mask from an AtomArray sequence selection string. - Selection strings are of the form: `CHAIN_ID/RES_NAME/RES_ID/ATOM_NAME/TRANSFORMATION_ID` - - We use "*" as a wildcard to select all atoms in a given granularity. + Selection strings follow ``CHAIN/RES/RESID/ATOM/TRANSFORM`` with ``"*"`` as a + wildcard for any field. Trailing fields may be omitted. Example: >>> atom_array = AtomArray(...) >>> mask = get_mask_from_selection_string(atom_array, "A/ALA/1/CA") [False, True, False, False, ...] + + See Also: + :py:func:`~atomworks.io.utils.selection.parse_selection_string` """ return get_mask_from_atom_selection(atom_array, parse_selection_string(selection_string)) def get_mask_from_atom_selection(atom_array: AtomArray, atom_selection: AtomSelection) -> np.ndarray: - """Create a boolean mask from a AtomSelection dataclass.""" + """Create a boolean mask from an :py:class:`AtomSelection`. + + See Also: + :py:func:`~atomworks.io.utils.selection.parse_selection_string` + """ + # TODO: Refactor using AtomArray query syntax mask = np.ones(atom_array.array_length(), dtype=bool) # ... add the masks @@ -393,21 +441,47 @@ def get_mask_from_atom_selection(atom_array: AtomArray, atom_selection: AtomSele class AtomSelectionStack: - """Class that represents a stack of AtomSelections. + """Manage multiple :py:class:`AtomSelection` objects as a unioned query. + + Supports ranges and comma-separated tokens via :py:meth:`from_query` and + contiguous ranges via :py:meth:`from_contig`. - Useful for managing multiple selections and applying them to an AtomArrayStack. - Notably, enables the use of a single selection string to select multiple segments. + See Also: + :py:meth:`~atomworks.io.utils.selection.AtomSelectionStack.from_query`, + :py:meth:`~atomworks.io.utils.selection.AtomSelectionStack.from_contig` """ def __init__(self, selections: list[AtomSelection]): + """Initialize a stack of selections. + + Args: + selections: Sequence of selections to be unioned. + """ self.selections = selections @classmethod - def from_contig_string(cls, contig_string: str) -> "AtomSelectionStack": + def from_contig(cls, contig: str) -> "AtomSelectionStack": + """Create a stack from contiguous residue ranges. + + Contig strings specify inclusive residue index ranges, e.g. ``"A1-2"`` + or ``"A1-2, B3-10"``. + + Args: + contig: Contiguous residue selection string like ``"A1-2, B3-10"``. 
+
+        Examples:
+            >>> # Selects residues 1..2 in chain A
+            >>> AtomSelectionStack.from_contig("A1-2")
+            >>> # Selects residues 1..2 in chain A and 3..10 in chain B
+            >>> AtomSelectionStack.from_contig("A1-2, B3-10")
+
+        See Also:
+            :py:meth:`~atomworks.io.utils.selection.AtomSelectionStack.from_query`
+        """
         # First define a regex that matches the elements of the contig string
         CONTIG_REGEX = re.compile(r"([A-Za-z]+)(\d+)-(\d+)")  # noqa
         selections = []
-        for selection in contig_string.replace(" ", "").split(","):
+        for selection in contig.replace(" ", "").split(","):
             match = CONTIG_REGEX.match(selection)
             if not match:
                 raise ValueError(f"Invalid contig string: {selection}")
@@ -419,12 +493,137 @@ def from_contig_string(cls, contig_string: str) -> "AtomSelectionStack":
             selections.append(atom_selection)
         return cls(selections)

+    @classmethod
+    def from_query(cls, query: str | list[str]) -> "AtomSelectionStack":
+        """Create a stack from extended query syntax with ranges.
+
+        Extended syntax overview:
+        - **Chains**: ``A`` (all atoms in chain A), ``A/ALA`` (all ALA in chain A)
+        - **Ranges (``res_id`` only)**: ``A/*/5-10`` selects residues 5..10 in chain A
+
+        Grammar per field (``CHAIN/RES/RESID/ATOM/TRANSFORM``):
+        - ``"*"`` wildcard; missing trailing fields default to ``"*"``
+        - Exact value, e.g. ``"A"``, ``"ALA"``, ``"CA"``
+        - Range (``res_id`` only): ``"5-10"`` (inclusive)
+
+        Multiple tokens may be provided as a comma-separated string or a
+        ``list[str]``; tokens are combined by union.
+
+        Examples:
+            >>> # Selects residues 5..10 in chain A
+            >>> AtomSelectionStack.from_query("A/*/5-10")
+            >>> # Selects residues 5..10 in chain A and 3..10 in chain B
+            >>> AtomSelectionStack.from_query("A/*/5-10, B/*/3-10")
+            >>> # Equivalent, using a list of tokens
+            >>> AtomSelectionStack.from_query(["A/*/5-10", "B/*/3-10"])
+        """
+        tokens = cls._parse_query_tokens(query)
+        selections: list[AtomSelection] = []
+
+        for token in tokens:
+            field_values = cls._parse_token_fields(token)
+            token_selections = cls._build_selections_from_fields(field_values)
+            selections.extend(token_selections)
+
+        return cls(selections)
+
+    @classmethod
+    def _parse_query_tokens(cls, query: str | list[str]) -> list[str]:
+        """Parse query input into individual tokens."""
+        if isinstance(query, str):
+            return [tok.strip() for tok in query.split(",") if tok.strip()]
+        else:
+            return [tok.strip() for tok in query if tok and tok.strip()]
+
+    @classmethod
+    def _parse_token_fields(cls, token: str) -> dict[str, list[Any]]:
+        """Parse a single token into field values."""
+        parts = token.split("/")
+
+        # Ensure five fields with '*' defaults
+        while len(parts) < 5:
+            parts.append("*")
+        chain_val, res_name_val, res_id_val, atom_name_val, trans_id_val = parts[:5]
+
+        return {
+            "chain_id": cls._parse_field_value(chain_val, is_res_id=False),
+            "res_name": cls._parse_field_value(res_name_val, is_res_id=False),
+            "res_id": cls._parse_field_value(res_id_val, is_res_id=True),
+            "atom_name": cls._parse_field_value(atom_name_val, is_res_id=False),
+            "transformation_id": cls._parse_field_value(trans_id_val, is_res_id=False),
+        }
+
+    @classmethod
+    def _parse_field_value(cls, value: str, *, is_res_id: bool = False) -> list[Any]:
+        """Parse a field value into a list of options.
+
+        For ``res_id``, values are integers; for others, strings.
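+
+        Example:
+            Illustrative values (a sketch, not exhaustive):
+
+            >>> AtomSelectionStack._parse_field_value("5-10", is_res_id=True)
+            [5, 6, 7, 8, 9, 10]
+            >>> AtomSelectionStack._parse_field_value("CA")
+            ['CA']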
+        """
+        v = value.strip()
+        if v == "*" or v == "":
+            return ["*"]
+
+        return cls._extract_field_options(v, is_res_id=is_res_id)
+
+    @classmethod
+    def _extract_field_options(cls, value: str, *, is_res_id: bool = False) -> list[Any]:
+        """Extract options from a field value (ranges or scalars)."""
+        # Range syntax: 5-10 (res_id only). The dash between the two integers is
+        # required so that plain multi-digit res_ids (e.g., "510") are not
+        # mistaken for ranges.
+        range_match = re.fullmatch(r"(-?\d+)-(-?\d+)", value) if is_res_id else None
+        if range_match:
+            start_i, stop_i = int(range_match.group(1)), int(range_match.group(2))
+            step = 1 if start_i <= stop_i else -1
+            return list(range(start_i, stop_i + step, step))
+
+        # Scalar value ("*" and "" were already handled by `_parse_field_value`)
+        return [int(value)] if is_res_id else [value]
+
+    @classmethod
+    def _build_selections_from_fields(cls, field_values: dict[str, list[Any]]) -> list[AtomSelection]:
+        """Build selections from parsed field values, expanding sets and ranges."""
+        # Extract field values directly as lists
+        chain_vals = field_values["chain_id"]
+        resn_vals = field_values["res_name"]
+        resi_vals = field_values["res_id"]
+        atom_vals = field_values["atom_name"]
+        tran_vals = field_values["transformation_id"]
+
+        # Build selections via Cartesian product ("*" wildcards pass through unchanged)
+        selections = [
+            AtomSelection(
+                chain_id=c,
+                res_name=r,
+                res_id=i,
+                atom_name=a,
+                transformation_id=t,
+            )
+            for c, r, i, a, t in product(chain_vals, resn_vals, resi_vals, atom_vals, tran_vals)
+        ]
+
+        return selections
+
     def get_mask(self, atom_array: AtomArray | AtomArrayStack) -> np.ndarray:
-        """Create a boolean mask using this AtomSelection on an AtomArray."""
-        return reduce(np.logical_or, [selection.get_mask(atom_array) for selection in self.selections])
+        """Create a boolean mask by unioning all selections."""
+        if not self.selections:
+            return np.zeros(atom_array.array_length(), dtype=bool)
+
+        masks = [selection.get_mask(atom_array) for selection in self.selections]
+        return reduce(np.logical_or, masks)

     def get_center_of_mass(self, atom_array: AtomArray | AtomArrayStack) -> np.ndarray:
-        """Get the center of mass of the selected atoms in the AtomArray."""
+        """Return the center of mass of the selected atoms.
+
+        Returns:
+            For :py:class:`~biotite.structure.AtomArray`: ``(3,)`` array.
+            For :py:class:`~biotite.structure.AtomArrayStack`: ``(n_models, 3)`` array
+            (one center of mass per model).
+
+        Raises:
+            ValueError: If no atoms are selected.
+        """
         mask = self.get_mask(atom_array)
         if not np.any(mask):
             raise ValueError("No atoms selected by the AtomSelectionStack.")
@@ -437,10 +636,14 @@ def get_center_of_mass(self, atom_array: AtomArray | AtomArrayStack) -> np.ndarr
         raise ValueError(f"Cannot get center of mass for {type(atom_array)}!")

     def get_principal_components(self, atom_array: AtomArray | AtomArrayStack) -> np.ndarray:
-        """Get the principal components of the selected atoms in the AtomArray.
+        """Return principal axes (eigenvectors) of the selected atoms via SVD.

         Returns:
-            - np.ndarray: Principal axes (eigenvectors). For AtomArray: (3, 3). For AtomArrayStack: (n_models, 3, 3).
+            ``(3, 3)`` array for :py:class:`~biotite.structure.AtomArray`.
+            ``(n_models, 3, 3)`` array for :py:class:`~biotite.structure.AtomArrayStack`.
+
+        Raises:
+            ValueError: If no atoms are selected.
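+
+        Example:
+            A minimal sketch (``atom_array`` and the chain/residue values below
+            are assumptions for illustration):
+
+            >>> stack = AtomSelectionStack.from_query("A/*/5-10")
+            >>> axes = stack.get_principal_components(atom_array)  # shape (3, 3)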
""" mask = self.get_mask(atom_array) if not np.any(mask): diff --git a/src/atomworks/io/utils/sequence.py b/src/atomworks/io/utils/sequence.py index 1c1052a8..869ff045 100644 --- a/src/atomworks/io/utils/sequence.py +++ b/src/atomworks/io/utils/sequence.py @@ -36,8 +36,13 @@ @functools.cache def aa_chem_comp_3to1(standard_only: bool = False) -> dict[str, str]: - """ - Returns a dictionary mapping 3-letter amino acid codes to 1-letter codes. + """Returns a dictionary mapping 3-letter amino acid codes to 1-letter codes. + + Args: + standard_only: If True, only include standard amino acids. + + Returns: + Dictionary mapping 3-letter to 1-letter amino acid codes. """ aa_3to1 = toolz.keyfilter(lambda x: x in aa_chem_comps(), chem_comp_to_one_letter()) if standard_only: @@ -47,8 +52,13 @@ def aa_chem_comp_3to1(standard_only: bool = False) -> dict[str, str]: @functools.cache def na_chem_comp_3to1(standard_only: bool = False) -> dict[str, str]: - """ - Returns a dictionary mapping 3-letter DNA codes to 1-letter codes. + """Returns a dictionary mapping 3-letter DNA codes to 1-letter codes. + + Args: + standard_only: If True, only include standard nucleic acids. + + Returns: + Dictionary mapping 3-letter to 1-letter nucleic acid codes. """ na_3to1 = toolz.keyfilter(lambda x: x in na_chem_comps(), chem_comp_to_one_letter()) if standard_only: diff --git a/src/atomworks/io/utils/testing.py b/src/atomworks/io/utils/testing.py index 937f6aec..582abd3d 100644 --- a/src/atomworks/io/utils/testing.py +++ b/src/atomworks/io/utils/testing.py @@ -196,6 +196,10 @@ def assert_same_atom_array( arr2, AtomArray | AtomArrayStack ), f"arr2 is not an AtomArray or AtomArrayStack but has type {type(arr2)}" + # Copy both arrays to avoid modifying the original arrays + arr1 = arr1.copy() + arr2 = arr2.copy() + # If the input is a stack, only compare the first array if isinstance(arr1, AtomArrayStack): arr1 = arr1[0] @@ -310,6 +314,14 @@ def convert_atom_array_to_sorted_tuples(arr: AtomArray, annotations: list[str]) assert arr1.bonds is not None, "arr1.bonds is None" assert arr2.bonds is not None, "arr2.bonds is None" + # TODO: Switch to using the `convert_bond_type` method once we upgrade to Biotite v1.4.0 + # structure.bonds.convert_bond_type(struc.bonds.BondType.COORDINATION, struc.bonds.BondType.SINGLE) + mask_1 = arr1.bonds._bonds[:, 2] == struc.bonds.BondType.COORDINATION + arr1.bonds._bonds[mask_1, 2] = struc.bonds.BondType.SINGLE + + mask_2 = arr2.bonds._bonds[:, 2] == struc.bonds.BondType.COORDINATION + arr2.bonds._bonds[mask_2, 2] = struc.bonds.BondType.SINGLE + if enforce_order: # Compare bond arrays directly bonds1 = arr1.bonds.as_array() diff --git a/src/atomworks/ml/datasets/datasets.py b/src/atomworks/ml/datasets/datasets.py index 4bfd2493..c42ca6cf 100644 --- a/src/atomworks/ml/datasets/datasets.py +++ b/src/atomworks/ml/datasets/datasets.py @@ -1,8 +1,29 @@ +"""AtomWorks Dataset classes and common APIs. + +At a high level, to train models with AtomWorks, we need a Dataset class that: + (1) Takes as input an item index and returns the corresponding example information; typically includes: + a. Path to a structural file saved on disk (`/path/to/dataset/my_dataset_0.cif`) + b. 
+    (2) Pre-loads structural information from the returned example into an `AtomArray` and assembles inputs for the Transform pipeline
+    (3) Feeds the input dictionary through a Transform pipeline and returns the result
+
+Due to the heterogeneity of biomolecular data, in many cases, we may also want:
+    (4) The ability to fall back to a different example if a failure occurs during the Transform pipeline
+
+For bespoke use cases, users may choose to write a custom Dataset that accomplishes these steps; downstream code makes no assumptions about the Dataset implementation.
+
+To accelerate development, we also provide an off-the-shelf, composable approach following common patterns:
+    - :class:`MolecularDataset`: Base class that handles pre-loading structural data and executing the Transform pipeline with error handling and debugging utilities
+    - :class:`PandasDataset`: A subclass of MolecularDataset for tabular data stored as pandas DataFrames
+    - :class:`FileDataset`: A subclass of MolecularDataset where each file is one example
+"""
+
 import copy
 import os
 import socket
 import time
-from abc import abstractmethod
+import warnings
+from abc import ABC, abstractmethod
 from collections.abc import Callable
 from functools import cached_property
 from os import PathLike
@@ -13,408 +34,456 @@ import pandas as pd
 from torch.utils.data import ConcatDataset, Dataset

-from atomworks.common import default, exists
 from atomworks.ml.datasets import logger
-from atomworks.ml.datasets.parsers import MetadataRowParser, load_example_from_metadata_row
 from atomworks.ml.preprocessing.constants import NA_VALUES
-from atomworks.ml.transforms.base import Compose, Transform, TransformedDict
+from atomworks.ml.transforms.base import TransformedDict
 from atomworks.ml.utils.debug import save_failed_example_to_disk
-from atomworks.ml.utils.io import read_parquet_with_metadata
+from atomworks.ml.utils.io import read_parquet_with_metadata, scan_directory
 from atomworks.ml.utils.rng import capture_rng_states

-_USER = default(os.getenv("USER"), "")
-

-class BaseDataset(Dataset):
-    """
-    Abstract base class for datasets. All dataset types (e.g., Pandas, Polars) should inherit from this class
-    and implement its methods.
+class ExampleIDMixin(ABC):
+    """Mixin providing example ID functionality to a Dataset.

-    In addition to the standard PyTorch Dataset methods (`__getitem__`, `__len__`), this class requires
-    implementations for converting between example IDs and indices, which is necessary for our nested dataset structure.
+    Provides methods for converting between example IDs and indices, and checking
+    if an example ID exists in the dataset.
     """

     @abstractmethod
-    def __getitem__(self, idx: int) -> Any:
-        pass
+    def __contains__(self, example_id: str) -> bool:
+        """Check if the dataset contains the example ID.

-    @abstractmethod
-    def __len__(self) -> int:
-        pass
+        Args:
+            example_id: The ID to check for.

-    @abstractmethod
-    def __contains__(self, example_id: str) -> bool:
-        """Check if the dataset contains the example ID."""
+        Returns:
+            True if the ID exists in the dataset.
+        """
         pass

     @abstractmethod
     def id_to_idx(self, example_id: str | list[str]) -> int | list[int]:
-        """Convert an example ID or list of example IDs to the corresponding index or indices."""
+        """Convert example ID(s) to index(es).
+
+        Args:
+            example_id: Single ID or list of IDs to convert.
+
+        Returns:
+            Corresponding index or list of indices.
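+
+        Example:
+            Illustrative only; the IDs shown are hypothetical:
+
+            >>> dataset.id_to_idx("7abc")
+            0
+            >>> dataset.id_to_idx(["7abc", "1xyz"])
+            [0, 1]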
+ """ pass @abstractmethod def idx_to_id(self, idx: int | list[int]) -> str | list[str]: - """Convert an index or list of indices to the corresponding example ID or IDs.""" + """Convert index(es) to example ID(s). + + Args: + idx: Single index or list of indices to convert. + + Returns: + Corresponding ID or list of IDs. + """ pass -class FileDataset(BaseDataset): +class MolecularDataset(Dataset): + """Base class for AtomWorks molecular datasets. + + Handles Transform pipelines and loader functionality for molecular data. + Subclasses implement :meth:`__getitem__` with their own data access patterns. + """ + def __init__( self, - source: PathLike | list[str | PathLike], - filter_fn: Callable[[PathLike], bool] | None = None, - max_depth: int = 3, + *, + name: str, + transform: Callable | None = None, + loader: Callable | None = None, + save_failed_examples_to_dir: str | Path | None = None, ): - """Initialize a FileDataset that loads files from a directory or uses a pre-provided list. + """Initialize MolecularDataset. Args: - source: Either a directory path to scan for files, or a pre-built list of file paths - filter_fn: Optional function that takes a file path and returns True if the file should be included - max_depth: Maximum directory depth to scan (only used when source is a directory path) + name: Descriptive name for this dataset. Used for debugging and some + downstream functions when using nested datasets. + transform: Transform function or pipeline to apply to loaded data. + Should accept the output of the loader and return featurized data. + loader: Optional function to process raw dataset output into Transform-ready + format. For example, parsing structural files or gathering columns + into structured data. + save_failed_examples_to_dir: Optional directory path where failed examples + will be saved for debugging. Includes RNG state and error information. """ - if isinstance(source, str | Path): - # Directory scanning mode - self.dir_path = Path(source) - assert self.dir_path.is_dir(), f"Directory {source} does not exist." - - # Default filter accepts all files - self.filter_fn = filter_fn if filter_fn is not None else lambda x: True + self.loader = loader - # Scan directory for any files below - file_paths = self._scan_directory(max_depth=max_depth) + self.transform = transform + self.name = name + self.save_failed_examples_to_dir = Path(save_failed_examples_to_dir) if save_failed_examples_to_dir else None - elif isinstance(source, list): - # Pre-provided file list mode - self.dir_path = None + def _apply_loader(self, raw_data: Any) -> Any: + """Apply the loader function to raw data with timing and debugging. - # Convert to strings and apply filter if provided - file_paths = [str(path) for path in source] - if filter_fn is not None: - file_paths = [path for path in file_paths if filter_fn(path)] - self.filter_fn = filter_fn + Args: + raw_data: The raw data to process. - else: - raise ValueError("source must be either a directory path (str/Path) or a list of file paths") + Returns: + Processed data ready for transforms. 
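+
+        Example:
+            A sketch, assuming a ``dataset`` whose loader parses structure files
+            from a path (the path below is a placeholder):
+
+            >>> data = dataset._apply_loader("/path/to/example.cif")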
+ """ + if self.loader is None: + return raw_data + + # Apply loader function with timing + _start_load_time = time.time() + data = self.loader(raw_data) + _stop_load_time = time.time() + + # Add timing information if data supports it (preserving TransformDataset behavior) + if isinstance(data, dict): + data = TransformedDict(data) + data.__transform_history__.append( + { + "name": "apply loader", + "instance": hex(id(self.loader)), + "start_time": _start_load_time, + "end_time": _stop_load_time, + "processing_time": _stop_load_time - _start_load_time, + } + ) - # Sort paths alphabetically for id<>idx consistency - file_paths.sort() + return data - self.file_paths = file_paths - self.path_to_idx = {path: i for i, path in enumerate(file_paths)} + def _apply_transform(self, data: Any, example_id: str | None = None, idx: int | None = None) -> Any: + """Apply the Transform pipeline with error handling and debugging support. - def _scan_directory(self, max_depth: int) -> list[str]: - """Fast directory scan without worrying about order.""" - file_paths = [] + Args: + data: The loaded data ready for transforms. + example_id: Optional example ID for debugging purposes. If not provided, + will generate one using dataset name and index. + idx: Optional dataset index for error reporting. - for root, dirs, files in os.walk(self.dir_path): - current_depth = len(Path(root).relative_to(self.dir_path).parts) + Returns: + Transformed data. - if current_depth >= max_depth: - dirs.clear() - continue + Raises: + KeyboardInterrupt: Always re-raised if encountered. + Exception: Any exception from the transform pipeline is re-raised. + """ + if self.transform is None: + return data - for file in files: - file_path = os.path.join(root, file) - if self.filter_fn(file_path): - file_paths.append(file_path) + # Generate default example_id from idx and dataset name if not provided + if example_id is None and idx is not None: + example_id = f"{self.name}_{idx}" - return file_paths + # Get process id and hostname for debugging + if example_id: + logger.debug(f"({socket.gethostname()}:{os.getpid()}) Processing example: {example_id}") - def __len__(self) -> int: - return len(self.file_paths) + try: + # Capture RNG state for reproducibility before applying Transforms + rng_state_dict = capture_rng_states(include_cuda=False) + data = self.transform(data) + return data - def __contains__(self, example_id: str) -> bool: - return example_id in self.path_to_idx + except KeyboardInterrupt: + # Always re-raise keyboard interrupts + raise + except Exception as e: + logger.error(e) + + if self.save_failed_examples_to_dir and example_id: + save_failed_example_to_disk( + example_id=example_id, + error_msg=e, + rng_state_dict=rng_state_dict, + data={}, # We do not save the data by default, since it may be large + fail_dir=self.save_failed_examples_to_dir, + ) + + # Re-raise the original exception + raise + + def __getitem__(self, index: int) -> Any: + """Return a fully-featurized data example given an index. + + Subclasses should implement this method to: + 1. Query the underlying data source for raw data at the given index + 2. Optionally pre-process data to prepare for the Transform pipeline + 3. 
Feed the input dictionary through a Transform pipeline
+
+        Typical output for an activity prediction network:
+            Step 1: ``{"path": "/path/to/dataset", "class_label": 5}``
+            Step 2: ``{"atom_array": AtomArray, "extra_info": {"class_label": 5}}``
+            Step 3: ``{"features": torch.Tensor, "class_label": torch.Tensor}``

-    def id_to_idx(self, example_id: str | list[str]) -> int | list[int]:
-        if isinstance(example_id, list):
-            return [self.path_to_idx[id] for id in example_id]
-        return self.path_to_idx[example_id]
+        Args:
+            index: The index of the example to retrieve.

-    def idx_to_id(self, idx: int | list[int]) -> str | list[str]:
-        if isinstance(idx, list):
-            return [self.file_paths[i] for i in idx]
-        return self.file_paths[idx]
+        Returns:
+            Fully-featurized data example.
+        """
+        raise NotImplementedError

-    def __getitem__(self, idx: int) -> Any:
-        """Return the file path at the given index.
+    def __len__(self) -> int:
+        """Return the number of examples in the dataset.

-        Subclasses can override this to load and process the file content instead.
+        Returns:
+            The dataset length.
         """
-        return self.file_paths[idx]
-
+        raise NotImplementedError

-class StructuralFileDataset(FileDataset):
-    """FileDataset with StructuralDatasetWrapper compatibility.

-    Inherits all functionality from FileDataset but adds:
-    - .data property that returns a pandas DataFrame for compatibility
-    - __getitem__ returns pandas Series instead of just file paths
-    - Optional name attribute for logging/debugging
+class FileDataset(MolecularDataset, ExampleIDMixin):
+    """Dataset that loads molecular data from individual files.

-    Allows integration with StructuralDatasetWrapper, samplers, and weight calculation.
+    Each file represents one example in the dataset. If creating a dataset from a
+    directory, use the :meth:`from_directory` class method instead of the default
+    constructor.
     """

     def __init__(
         self,
-        source: PathLike | list[str | PathLike],
+        *,
+        file_paths: list[str | PathLike],
+        name: str,
         filter_fn: Callable[[PathLike], bool] | None = None,
-        max_depth: int = 3,
-        name: str | None = None,
+        **kwargs: Any,
     ):
-        """
+        """Initialize FileDataset.
+
         Args:
-            source: Either a directory path to scan for files, or a pre-built list of file paths
-            filter_fn: Optional function that takes a file path and returns True if the file should be included
-            max_depth: Maximum directory depth to scan (only used when source is a directory path)
-            name: Optional name for the dataset (useful for logging and debugging)
+            file_paths: List of file paths for the dataset. Each file represents
+                one example.
+            name: Descriptive name for this dataset. Used for debugging and some
+                downstream functions when using nested datasets.
+            filter_fn: Optional function to filter file paths. Should return True
+                for files to include.
+            **kwargs: Additional arguments passed to :class:`MolecularDataset`.
+
+        Examples:
+            Create from explicit file list:
+            >>> files = ["/path/to/file1.cif", "/path/to/file2.cif"]
+            >>> dataset = FileDataset(file_paths=files, name="my_dataset")
         """
-        super().__init__(source, filter_fn, max_depth)
-        self.name = name if name is not None else f"StructuralFileDataset({source})"
+        super().__init__(name=name, **kwargs)

-        assert len(self.file_paths) == len(set(self.file_paths)), "File paths must be unique."
+        self.filter_fn = filter_fn or (lambda x: True)

-    @cached_property
-    def data(self) -> pd.DataFrame:
-        """Return a pandas DataFrame with file paths and generated example IDs.
+ # Convert to Path objects and filter + file_paths = [Path(path) for path in file_paths if self.filter_fn(path)] + if not file_paths: + raise ValueError("No files found after applying filters") + if len(file_paths) != len(set(file_paths)): + raise ValueError("File paths must be unique") - This property makes StructuralFileDataset compatible with StructuralDatasetWrapper - and other components that expect a .data attribute. - """ - # Generate example IDs from file paths (use filename without extension) - example_ids = [] - for file_path in self.file_paths: - filename = Path(file_path).stem # filename without extension - # If filename has multiple extensions (e.g., .cif.gz), remove them all - while "." in filename: - filename = Path(filename).stem - example_ids.append(filename) - - # Create DataFrame with path and example_id columns - df = pd.DataFrame( - { - "path": self.file_paths, - "example_id": example_ids, - } - ) + # Sort for consistent id<>idx mapping + file_paths.sort() + self.file_paths = file_paths - # Set example_id as index for fast lookups - df.set_index("example_id", inplace=True, drop=False, verify_integrity=True) # No duplicates allowed + # Create ID mapping + self.id_to_idx_map = {self._get_example_id(i): i for i, _ in enumerate(self.file_paths)} - return df + # Verify that all example IDs are unique + if len(self.id_to_idx_map) != len(self.file_paths): + raise ValueError("Example IDs must be unique. Found duplicate example IDs.") - def __getitem__(self, idx: int) -> Any: - return self.data.iloc[idx] + @classmethod + def from_directory( + cls, + *, + directory: PathLike, + name: str, + max_depth: int = 3, + **kwargs: Any, + ) -> "FileDataset": + """Create a FileDataset by scanning a directory for files. + Args: + directory: Path to directory to scan for files. + name: Descriptive name for this dataset. + max_depth: Maximum depth to scan for files in subdirectories. + **kwargs: Additional arguments passed to :class:`FileDataset`. -class StructuralDatasetWrapper(BaseDataset): - def __init__( - self, - dataset: Dataset, - dataset_parser: MetadataRowParser, - cif_parser_args: dict | None = None, - transform: Transform | Compose | None = None, - return_key: str | None = None, - save_failed_examples_to_dir: PathLike | str | None = None, - ): - """ - Decorator (wrapper) for an arbitrary Dataset (e.g., PandasDataset, PolarsDataset, etc.) to handle loading of structural data from PDB or CIF files, - parsing, and applying a Transformation pipeline to the data. + Returns: + FileDataset instance with files discovered from the directory. - Designed to be used with a Transforms pipeline to process the data and a MetadataRowParser to convert the dataset rows into a common dictionary format. + Example: + Create from directory: + >>> dataset = FileDataset.from_directory(directory="/path/to/files", name="my_dataset", max_depth=2) + """ + dir_path = Path(directory) + if not dir_path.exists(): + raise FileNotFoundError(f"Directory {directory} does not exist.") + if not dir_path.is_dir(): + raise ValueError(f"Path {directory} is not a directory.") + + file_paths = scan_directory(dir_path=dir_path, max_depth=max_depth) + return cls(file_paths=file_paths, name=name, **kwargs) + + @classmethod + def from_file_list( + cls, + *, + file_paths: list[str | PathLike], + name: str, + **kwargs: Any, + ) -> "FileDataset": + """Create a FileDataset from an explicit list of file paths. - For more detail, see the README in the `datasets` directory. 
+ This is an alias for the main constructor for clarity and consistency + with :meth:`from_directory`. Args: - dataset (Dataset): The dataset to wrap. For example, a PandasDataset, PolarsDataset, or standard PyTorch Dataset. - dataset_parser (MetadataRowParser): Parser to convert dataset metadata rows into a common dictionary format. See `atomworks.ml.datasets.dataframe_parsers`. - cif_parser_args (dict, optional): Arguments to pass to `atomworks.io.parse` (will override the defaults). Defaults to None. - transform (Transform | Compose, optional): Transformation pipeline to apply to the data. See `atomworks.ml.transforms.base`. - return_key (str, optional): Key to return from the data dictionary instead of the entire dictionary. - save_failed_examples_to_dir (PathLike | str | None, optional): Directory to save failed examples. - - Example usage: - ```python - dataset = StructuralDatasetDecorator(dataset=PandasDataset(data="path/to/data.csv"), ...) - dataset[0] # Returns the processed data for the first example. - ``` + file_paths: List of file paths for the dataset. Each file represents one example. + name: Descriptive name for this dataset. + **kwargs: Additional arguments passed to :class:`FileDataset`. + + Returns: + FileDataset instance with the provided file paths. """ - # ...basic assignments - self.transform = transform - self.return_key = return_key - self.save_failed_examples_to_dir = ( - Path(save_failed_examples_to_dir) if exists(save_failed_examples_to_dir) else None - ) - self.cif_parser_args = cif_parser_args - self.dataset_parser = dataset_parser - self.dataset = dataset + return cls(file_paths=file_paths, name=name, **kwargs) - # ...carry forward the data - self.data = self.dataset.data + def __len__(self) -> int: + """Return the number of files in the dataset.""" + return len(self.file_paths) - # ...carry forward the name - self.name = self.dataset.name if hasattr(self.dataset, "name") else repr(self.dataset) + def __contains__(self, example_id: str) -> bool: + """Check if the dataset contains the example ID.""" + return example_id in self.id_to_idx_map + + def id_to_idx(self, example_id: str | list[str]) -> int | list[int]: + """Convert example ID(s) to index(es).""" + if isinstance(example_id, list): + return [self.id_to_idx_map[id] for id in example_id] + return self.id_to_idx_map[example_id] + + def idx_to_id(self, idx: int | list[int]) -> str | list[str]: + """Convert index(es) to example ID(s).""" + if isinstance(idx, list): + return [self._get_example_id(i) for i in idx] + return self._get_example_id(idx) def __getitem__(self, idx: int) -> Any: - """ - Performs the following steps: - (1) Retrieve the row at the specified index from the dataset using the __getitem__ method. - (2) Parse the row into a common dictionary format using the dataset parser. - (3) Load the CIF file from the information in the common dictionary format (i.e., the "path" key). - (4) Apply the transformation pipeline to the data which, at a minimum, contains the output of `atomworks.io` parsing. + """Load and transform an example by file index. Args: - idx (int): The index of the item to retrieve. + idx: The index of the file to load. Returns: - Any: The processed item. + Transformed data from the file. 
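+
+        Example:
+            A sketch; the path below is a placeholder:
+
+            >>> dataset = FileDataset(file_paths=["/data/1abc.cif"], name="demo")
+            >>> example = dataset[0]  # applies the loader, then the Transform pipeline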
""" + file_path = str(self.file_paths[idx]) + example_id = self._get_example_id(idx) + data = self._apply_loader(file_path) + return self._apply_transform(data, example_id=example_id, idx=idx) - # Capture example ID & current rng state (for reproducibility & debugging) - if hasattr(self, "idx_to_id"): - # ...if the dataset has a custom idx_to_id method, use it (e.g., for a PandasDataset) - example_id = self.idx_to_id(idx) - else: - # ...otherwise, fallback to a the `id_column` or a string representation of the index - example_id = self.dataset[idx][self.id_column] if self.id_column else f"row_{idx}" - - # Get process id and hostname (for debugging) - logger.debug(f"({socket.gethostname()}:{os.getpid()}) Processing example ID: {example_id}") - - # Load the row, using the __getitem__ method of the dataset - row = self.dataset[idx] - - # Process the row into a transform-ready dictionary with the given CIF and dataset parsers - # We require the "data" dictionary output from `load_example_from_metadata_row` to contain, at a minimum: - # (a) An "id" key, which uniquely identifies the example within the dataframe; and, - # (b) The "path" key, which is the path to the CIF file - _start_parse_time = time.time() - data = load_example_from_metadata_row(row, self.dataset_parser, cif_parser_args=self.cif_parser_args) - _stop_parse_time = time.time() - - # Manually add timing for cif-parsing - data = TransformedDict(data) - data.__transform_history__.append( - { - "name": "load_example_from_metadata_row", - "instance": hex(id(load_example_from_metadata_row)), - "start_time": _start_parse_time, - "end_time": _stop_parse_time, - "processing_time": _stop_parse_time - _start_parse_time, - } - ) + def _get_example_id(self, idx: int) -> str: + """Get example ID from index - returns filename stem without extensions. - # Apply the transformation pipeline to the data - if exists(self.transform): - try: - rng_state_dict = capture_rng_states(include_cuda=False) - data = self.transform(data) - except KeyboardInterrupt as e: - raise e - except Exception as e: - # Log the error and save the failed example to disk (optional) - logger.info(f"Error processing row {idx} ({example_id}): {e}") - - if exists(self.save_failed_examples_to_dir): - save_failed_example_to_disk( - example_id=example_id, - error_msg=e, - rng_state_dict=rng_state_dict, - data={}, # We do not save the data, since it may be large. - fail_dir=self.save_failed_examples_to_dir, - ) - raise e - - # Return the specified key or the entire data dict (i.e., only "feats" key from the Transform dictionary) - if exists(self.return_key): - return data[self.return_key] - else: - return data - - def __len__(self) -> int: - """Pass through the length of the wrapped dataset.""" - return len(self.dataset) - - def __contains__(self, example_id: str) -> bool: - """Pass through the contains method of the wrapped dataset.""" - return example_id in self.dataset - - def id_to_idx(self, example_id: str) -> int: - """Pass through the id_to_idx method of the wrapped dataset.""" - return self.dataset.id_to_idx(example_id) - - def idx_to_id(self, idx: int) -> str: - """Pass through the idx_to_id method of the wrapped dataset.""" - return self.dataset.idx_to_id(idx) - - def __getattr__(self, name: str) -> Any: - """Delegate attribute access to the wrapped dataset.""" - try: - # `object.__getattribute__(self, "dataset")` bypasses the custom `__getattr__` and safely retrieves the attribute, - # avoiding infinite recursion. 
-            dataset = object.__getattribute__(self, "dataset")
-            return getattr(dataset, name)
-        except AttributeError:
-            raise AttributeError(f"'{type(self).__name__}' object (or its wrapped dataset) has no attribute '{name}'")  # noqa: B904
+        Args:
+            idx: The index of the file.

+        Returns:
+            Filename stem without any extensions.
+        """
+        file_path = self.file_paths[idx]
+        filename = Path(file_path).stem
+        # If filename has multiple extensions (e.g., .cif.gz), remove them all
+        while "." in filename:
+            filename = Path(filename).stem
+        return filename

-class PandasDataset(BaseDataset):
-    """
-    A wrapper around PyTorch's Dataset class that allows for easy loading, filtering, and indexing of datasets stored as Pandas DataFrames.
-    The underlying DataFrame can be accessed via the `data` property.
-    For example usage, see the tests in `tests/datasets/test_datasets.py`.
+class PandasDataset(MolecularDataset, ExampleIDMixin):
+    """Dataset for tabular data stored as pandas DataFrames.

-    Args:
-        data (pd.DataFrame | PathLike): The dataset, either as a Pandas DataFrame or a path to a file.
-        id_column (str | None, optional): The column to use as the index; must be unique within the DataFrame. Defaults to None.
-            For example, we use the `example_id` column as the index in the `PDBDataset`. By setting the dataframe index to the `example_id`
-            column, we can retrieve the row corresponding to a specific example ID by calling `dataset.data.loc[example_id]` in O(1) time.
-        filters (list[str] | None, optional): A list of query strings to filter the data. Defaults to None. For examples on how to specify filters,
-            see the docstring for `_apply_filters`.
-        name (str | None, optional): The name of the dataset. Defaults to None. Useful for debugging and logging.
-        columns_to_load (list[str] | None, optional): Specific columns to load if data is provided as a file path. Defaults to None. Helpful for
-            large datasets where only a subset of columns is needed (if using `parquet` or other columnar storage formats).
-        **load_kwargs (Any): Additional keyword arguments for loading the data.
-
-    Attributes:
-        data (pd.DataFrame): The underlying DataFrame, accessible via the `data` property.
+    Inherits all functionality from :class:`MolecularDataset` with additional
+    DataFrame-specific features for filtering and ID-based access.
     """

     def __init__(
         self,
         *,
         data: pd.DataFrame | PathLike,
-        id_column: str | None = None,
+        name: str,
+        id_column: str | None = "example_id",
         filters: list[str] | None = None,
-        name: str | None = None,
         columns_to_load: list[str] | None = None,
-        **load_kwargs: Any,
+        # MolecularDataset parameters
+        transform: Callable | None = None,
+        loader: Callable | None = None,
+        save_failed_examples_to_dir: str | Path | None = None,
+        load_kwargs: dict | None = None,
     ):
-        if name is not None:
-            self.name = name
-        else:
-            self.name = repr(self)
+        """Initialize PandasDataset.
+
+        Args:
+            data: Either a pandas DataFrame or path to a CSV/Parquet file containing
+                the tabular data. Each row represents one example.
+            name: Descriptive name for this dataset. Used for debugging and some
+                downstream functions when using nested datasets.
+            id_column: Optional column name to use as the DataFrame index for
+                example ID lookups. If provided, this column will be set as the index.
+            filters: Optional list of pandas query strings to filter the data.
+                Applied in order during initialization.
+            columns_to_load: Optional list of column names to load when reading
+                from a file. If None, all columns are loaded.
Can dramatically reduce + memory usage and load time if loading from a columnar format like Parquet. + transform: Transform pipeline to apply to loaded data. + loader: Optional function to process raw DataFrame rows into Transform-ready format. + save_failed_examples_to_dir: Optional directory path where failed examples + will be saved for debugging. Includes RNG state and error information. + load_kwargs: Additional keyword arguments passed to pandas' read functions + (read_csv, read_parquet) when loading from file. + + Examples: + Load from DataFrame: + >>> df = pd.DataFrame({"path": [...], "label": [...]}) + >>> dataset = PandasDataset(data=df, name="my_dataset") + + Load from file with filtering: + >>> dataset = PandasDataset(data="data.csv", name="filtered_dataset", filters=["label > 0", "path.str.contains('.pdb')"]) + """ + super().__init__( + name=name, + transform=transform, + loader=loader, + save_failed_examples_to_dir=save_failed_examples_to_dir, + ) - # Load the data from the path, if provided (and load only the specified columns) + # Load data from path if needed if isinstance(data, PathLike | str): - data = self._load_from_path(data, columns_to_load, **load_kwargs) - self._data = data + data = self._load_from_path(data, columns_to_load, **(load_kwargs or {})) + self.data = data - # Apply filters, if provided + # Apply filters self.filters = filters self._already_filtered = False - if exists(filters): + if filters: self._apply_filters(filters) self._already_filtered = True + # Set index column if specified if id_column is not None: - assert id_column in self._data.columns, f"Column {id_column} not found in dataset." - self._data.set_index(id_column, inplace=True, drop=False, verify_integrity=True) + assert id_column in self.data.columns, f"Column {id_column} not found in dataset." + self.data.set_index(id_column, inplace=True, drop=False, verify_integrity=True) def _load_from_path( self, path: PathLike | str, columns_to_load: list[str] | None = None, **load_kwargs: Any ) -> pd.DataFrame: + """Load data from file path. + + Args: + path: Path to the file to load. + columns_to_load: Optional list of column names to load. + **load_kwargs: Additional arguments for pandas read functions. + + Returns: + Loaded DataFrame. + + Raises: + ValueError: If file type is unsupported. + """ path = Path(path) if path.suffix == ".csv": data = pd.read_csv(path, usecols=columns_to_load, keep_default_na=False, na_values=NA_VALUES, **load_kwargs) @@ -424,27 +493,36 @@ def _load_from_path( raise ValueError(f"Unsupported file type: {path.suffix}") return data - @property - def data(self) -> pd.DataFrame: - """Expose underlying dataframe as property to discourage changing it (can lead to unexpected behavior with torch ConcatDatasets).""" - return self._data - def __getitem__(self, idx: int) -> Any: - return self._data.iloc[idx] + """Get an example by index, applying specified loader and Transforms. + + Args: + idx: The index of the example to retrieve. + + Returns: + Transformed data from the row. 
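+
+        Example:
+            A sketch; the DataFrame ``df`` is assumed to exist:
+
+            >>> dataset = PandasDataset(data=df, name="demo", id_column=None)
+            >>> example = dataset[0]  # row -> loader -> Transform pipeline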
+        """
+        raw_data = self.data.iloc[idx]
+        example_id = self._get_example_id(idx)
+        data = self._apply_loader(raw_data)
+        return self._apply_transform(data, example_id=example_id, idx=idx)

     def __len__(self) -> int:
-        return len(self._data)
+        """Return the number of rows in the dataset."""
+        return len(self.data)

     def __contains__(self, example_id: str) -> bool:
         """Check if the dataset contains the example ID."""
-        return example_id in self._data.index
+        return example_id in self.data.index

     def _id_to_index_single(self, example_id: str) -> int:
-        return self._data.index.get_loc(example_id)
+        """Convert single example ID to index."""
+        return self.data.index.get_loc(example_id)

     def _id_to_index_multiple(self, example_ids: list[str]) -> list[int]:
-        idxs = np.arange(len(self._data))
-        return [idxs[self._data.index.get_loc(example_id)] for example_id in example_ids]
+        """Convert multiple example IDs to indices."""
+        idxs = np.arange(len(self.data))
+        return [idxs[self.data.index.get_loc(example_id)] for example_id in example_ids]

     def id_to_idx(self, example_id: str | list[str]) -> int | list[int]:
         """Convert an example ID to the corresponding local index."""
@@ -462,31 +540,29 @@ def idx_to_id(self, idx: int | list[int]) -> str | np.ndarray:
             _return_single = True
             idx = idx.item() if isinstance(idx, np.ndarray) else idx
             idx = slice(idx, idx + 1)
-        ids = self._data.iloc[idx].index.values
+        ids = self.data.iloc[idx].index.values
         return ids[0] if _return_single else ids

     def _apply_filters(self, filters: list[str]) -> pd.DataFrame:
-        """
-        Apply filters to the data based on the provided list of query strings.
+        """Apply filters to the data based on the provided list of query strings.
+
         For documentation on pandas query syntax, see:
         https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html

         Args:
-            filters (List[str]): List of query strings to apply to the data.
+            filters: List of query strings to apply to the data.

         Raises:
             ValueError: If the data is not initialized or if a query removes all rows.
             Warning: If a query does not remove any rows.

-        Example:
-            queries = [
-                "deposition_date < '2020-01-01'",
-                "resolution < 2.5 and ~method.str.contains('NMR')",
-                "cluster.notnull()",
-                "method in ['X-RAY_DIFFRACTION', 'ELECTRON_MICROSCOPY']"
-            ]
+        Example:
+            >>> queries = [
+            ...     "deposition_date < '2020-01-01'",
+            ...     "resolution < 2.5 and ~method.str.contains('NMR')",
+            ...     "cluster.notnull()",
+            ...     "method in ['X-RAY_DIFFRACTION', 'ELECTRON_MICROSCOPY']",
+            ... ]
+            >>> dataset = PandasDataset(data="data.csv", name="filtered_dataset", filters=queries)
         """
         assert not self._already_filtered, "Filters cannot be applied after initialization."

@@ -495,30 +571,27 @@ def _apply_filters(self, filters: list[str]) -> pd.DataFrame:
             self._apply_query(query)

     def _apply_query(self, query: str) -> None:
-        """
-        Apply a single query to the data.
+        """Apply a single query to the data.

         Args:
-            query (str): A query string to apply to the data.
+            query: The pandas query string to apply.
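+
+        Example:
+            Illustrative; the column name is an assumption:
+
+            >>> dataset._apply_query("resolution < 2.5")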
""" # Filter using query and validate impact - original_num_rows = len(self._data) - self._data = self._data.query(query) - filtered_num_rows = len(self._data) + original_num_rows = len(self.data) + self.data = self.data.query(query) + filtered_num_rows = len(self.data) self._validate_filter_impact(query, original_num_rows, filtered_num_rows) def _validate_filter_impact(self, query: str, original_num_rows: int, filtered_num_rows: int) -> None: - """ - Validate the impact of the filter. + """Validate the impact of the filter. Args: - query (str): The query string that was applied. - original_num_rows (int): The number of rows before applying the filter. - filtered_num_rows (int): The number of rows after applying the filter. + query: The query that was applied. + original_num_rows: Number of rows before filtering. + filtered_num_rows: Number of rows after filtering. Raises: - Warning: If the filter did not remove any rows. - ValueError: If the filter removed all rows. + ValueError: If the query removes all rows. """ rows_removed = original_num_rows - filtered_num_rows percent_removed = (rows_removed / original_num_rows) * 100 @@ -538,50 +611,97 @@ def _validate_filter_impact(self, query: str, original_num_rows: int, filtered_n f"+-------------------------------------------+\n" ) + def _get_example_id(self, idx: int) -> str: + """Get example ID from index - returns the index value from the DataFrame. + + Args: + idx: The index of the row. + + Returns: + The index value as a string. + """ + return str(self.data.iloc[idx].name) # .name gets the index value + class ConcatDatasetWithID(ConcatDataset): - """Equivalent to `torch.utils.data.ConcatDataset` but allows accessing examples by ID.""" + """Equivalent to :class:`torch.utils.data.ConcatDataset` but allows accessing examples by ID. + + Provides ID-based access across multiple datasets that implement :class:`ExampleIDMixin`. + """ + + # TODO: Do I need all of these _raise_if etc. etc. here? Can I just check that the wrapped datasets inherit somehow from ExampleIDMixin? - datasets: list[Dataset] + datasets: list[ExampleIDMixin] - def __init__(self, datasets: list[Dataset]): + def __init__(self, datasets: list[ExampleIDMixin]): + """Initialize ConcatDatasetWithID. + + Args: + datasets: List of datasets that implement ExampleIDMixin. 
+ """ super().__init__(datasets) - # Print the length of each dataset + # Log the length of each dataset for i, dataset in enumerate(datasets): logger.info(f"Dataset {i} ({type(dataset)}): {len(dataset):,} examples") @cached_property def _can_convert_ids_and_idx(self) -> bool: + """Check if all sub-datasets can convert between IDs and indices.""" has_id_to_idx = all(hasattr(sub_dataset, "id_to_idx") for sub_dataset in self.datasets) has_idx_to_id = all(hasattr(sub_dataset, "idx_to_id") for sub_dataset in self.datasets) return has_id_to_idx and has_idx_to_id and self._can_check_contains @cached_property def _can_check_contains(self) -> bool: + """Check if all sub-datasets support contains operations.""" return all(hasattr(sub_dataset, "__contains__") for sub_dataset in self.datasets) def _raise_if_cannot_check_contains(self) -> None: + """Raise error if dataset cannot check contains.""" if not self._can_check_contains: raise ValueError("This dataset cannot check if it contains an example ID.") def _raise_if_cannot_convert_ids_and_idx(self) -> None: + """Raise error if dataset cannot convert IDs and indices.""" if not self._can_convert_ids_and_idx: raise ValueError("This dataset cannot convert example IDs to indices.") def _raise_if_idx_out_of_bounds(self, idx: int) -> None: + """Raise error if index is out of bounds. + + Args: + idx: The index to check. + """ if idx < 0 or idx >= len(self): raise ValueError(f"Index {idx} out of bounds for dataset of length {len(self)}.") def __contains__(self, example_id: str) -> bool: - """Check if the dataset contains the example ID.""" + """Check if the dataset contains the example ID. + + Args: + example_id: The ID to check for. + + Returns: + True if the ID exists in any sub-dataset. + """ self._raise_if_cannot_check_contains() return any(example_id in sub_dataset for sub_dataset in self.datasets) def id_to_idx(self, example_id: str) -> int: """Retrieves the index corresponding to the example ID. - WARNING: Assumes that the example ID is unique within the dataset. If not, + Args: + example_id: The ID to convert. + + Returns: + The corresponding index. + + Raises: + ValueError: If the example ID is not found. + + Warning: + Assumes that the example ID is unique within the dataset. If not, the first occurrence of the example ID is returned. """ # TODO: Generalize to list[str] @@ -594,7 +714,17 @@ def id_to_idx(self, example_id: str) -> int: raise ValueError(f"Example ID {example_id} not found in any sub-dataset.") def idx_to_id(self, idx: int) -> str: - """Retrieves the example ID corresponding to the index.""" + """Retrieves the example ID corresponding to the index. + + Args: + idx: The index to convert. + + Returns: + The corresponding example ID. + + Raises: + ValueError: If the index is out of bounds. + """ # TODO: Generalize to list[int] self._raise_if_cannot_convert_ids_and_idx() self._raise_if_idx_out_of_bounds(idx) @@ -606,7 +736,17 @@ def idx_to_id(self, idx: int) -> str: raise ValueError(f"Index {idx} out of bounds for any sub-dataset.") def get_dataset_by_idx(self, idx: int) -> Dataset: - """Retrieves the dataset containing the index.""" + """Retrieves the dataset containing the index. + + Args: + idx: The index to find. + + Returns: + The sub-dataset containing the index. + + Raises: + ValueError: If the index is out of bounds. 
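+
+        Example:
+            Illustrative only:
+
+            >>> sub_dataset = combined.get_dataset_by_idx(0)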
+ """ self._raise_if_idx_out_of_bounds(idx) for sub_dataset in self.datasets: if idx < len(sub_dataset): @@ -618,24 +758,30 @@ def get_dataset_by_idx(self, idx: int) -> Dataset: def get_dataset_by_id(self, example_id: str) -> Dataset: """Retrieves the dataset containing the example ID. - WARNING: Assumes that the example ID is unique within the dataset. If not, + Args: + example_id: The ID to find. + + Returns: + The sub-dataset containing the ID. + + Warning: + Assumes that the example ID is unique within the dataset. If not, the first occurrence of the example ID is returned. """ idx = self.id_to_idx(example_id) return self.get_dataset_by_idx(idx) -def get_row_and_index_by_example_id(dataset: ConcatDatasetWithID, example_id: str) -> dict: - """ - Retrieve a row and its index from a nested dataset structure by its example ID. +def get_row_and_index_by_example_id(dataset: ExampleIDMixin, example_id: str) -> dict: + """Retrieve a row and its index from a nested dataset structure by its example ID. - Parameters: - dataset (PandasDataset | ConcatDataset): The dataset or concatenated dataset to search. + Args: + dataset: The dataset or concatenated dataset to search. Must have the `id_to_idx` method. - example_id (str): The example ID to search for. + example_id: The example ID to search for. Returns: - tuple: A tuple containing the row (pd.Series) and the (global)index (int) corresponding to the example ID. + Dictionary containing the row and global index corresponding to the example ID. """ assert hasattr(dataset, "id_to_idx"), "Dataset must have the `id_to_idx` method." idx = dataset.id_to_idx(example_id) @@ -650,28 +796,25 @@ def get_row_and_index_by_example_id(dataset: ConcatDatasetWithID, example_id: st class FallbackDatasetWrapper(Dataset): - """ - A wrapper around a dataset that allows for a fallback dataset to be used when an error occurs. + """A wrapper around a dataset that allows for a fallback dataset to be used when an error occurs. Meant to be used with a FallbackSamplerWrapper. """ def __init__(self, dataset: Dataset, fallback_dataset: Dataset): - """ - FallbackDatasetWrapper is a wrapper around a dataset that provides a fallback mechanism - to another dataset in case of errors during data retrieval. + """Initialize FallbackDatasetWrapper. - Attributes: - - dataset (Dataset): The primary dataset to retrieve data from. - - fallback_dataset (Dataset): The fallback dataset to use when an error occurs. This + Args: + dataset: The primary dataset to retrieve data from. + fallback_dataset: The fallback dataset to use when an error occurs. This may be the same as the primary dataset, or a different one. """ self.dataset = dataset self.fallback_dataset = fallback_dataset def __getitem__(self, idxs: tuple[int, ...]) -> Any: - """ - Attempt to retrieve an item from the primary dataset, falling back to additional indices if errors occur. + """Attempt to retrieve an item from the primary dataset, falling back to additional indices if errors occur. + If all attempts fail, raises a RuntimeError containing all encountered exceptions. 
        Args:
@@ -717,4 +860,61 @@ def __getitem__(self, idxs: tuple[int, ...]) -> Any:
             )

     def __len__(self):
+        """Return the length of the primary dataset."""
         return len(self.dataset)
+
+
+# Backwards Compatibility
+# TODO: Deprecate
+def StructuralDatasetWrapper(  # noqa: N802
+    dataset_parser: Callable,
+    transform: Callable | None = None,
+    dataset: PandasDataset | None = None,
+    cif_parser_args: dict | None = None,
+    save_failed_examples_to_dir: str | Path | None = None,
+    **kwargs,
+) -> PandasDataset:
+    """Backwards-compatible wrapper for the deprecated StructuralDatasetWrapper.
+
+    This function is deprecated and will be removed in a future version.
+    Use :class:`PandasDataset` with the appropriate loader function instead.
+
+    Args:
+        dataset_parser: The dataset parser to use (e.g., PNUnitsDFParser, InterfacesDFParser).
+        transform: Transform pipeline to apply to loaded data.
+        dataset: The underlying PandasDataset containing the tabular data.
+        cif_parser_args: Arguments to pass to the CIF parser.
+        save_failed_examples_to_dir: Directory to save failed examples for debugging.
+        **kwargs: Additional arguments passed to PandasDataset.
+
+    Returns:
+        PandasDataset instance configured with the deprecated parameters.
+
+    Raises:
+        ValueError: If the ``dataset`` parameter is not provided.
+    """
+    from atomworks.ml.datasets.parsers import load_example_from_metadata_row
+
+    warnings.warn(
+        "StructuralDatasetWrapper is deprecated. Use PandasDataset with a loader function instead. "
+        "See atomworks.ml.datasets.loaders for functional alternatives to dataset parsers.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
+    if dataset is None:
+        raise ValueError("dataset parameter is required for StructuralDatasetWrapper")
+
+    # Create loader from deprecated parameters
+    def loader(row: pd.Series) -> dict[str, Any]:
+        return load_example_from_metadata_row(row, dataset_parser, cif_parser_args=cif_parser_args or {})
+
+    # Create a new PandasDataset with the loader
+    return PandasDataset(
+        data=dataset.data,
+        name=dataset.name if hasattr(dataset, "name") else "structural_dataset",
+        transform=transform,
+        loader=loader,
+        save_failed_examples_to_dir=save_failed_examples_to_dir,
+        **kwargs,
+    )
diff --git a/src/atomworks/ml/datasets/loaders.py b/src/atomworks/ml/datasets/loaders.py
new file mode 100644
index 00000000..43d8e968
--- /dev/null
+++ b/src/atomworks/ml/datasets/loaders.py
@@ -0,0 +1,265 @@
+"""Functional loader implementations for AtomWorks datasets.
+
+Loaders are functions that process raw dataset output (e.g., a pandas Series) into a Transform-ready format.
+For example, a loader converts dataset-specific metadata into the standard format expected by AtomWorks Transform pipelines.
+"""
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+from toolz import keyfilter
+
+from atomworks.io.parser import parse
+from atomworks.ml.utils.io import apply_sharding_pattern
+
+
+def _construct_metadata_hierarchy(row: pd.Series, attrs: dict | None = None) -> dict[str, Any]:
+    """Construct metadata dictionary with precedence hierarchy.
+
+    Assembles metadata from multiple sources with the following precedence (lowest to highest priority):
+        1. DataFrame-level attributes (row.attrs)
+        2. Row-level data (row.to_dict())
+        3. 
Loader-specific attributes (attrs parameter) + + Args: + row: pandas Series representing one dataset example + attrs: Optional loader-specific attributes to merge with highest precedence + + Returns: + Dictionary containing merged metadata with proper hierarchy + """ + # Start with DataFrame-level attributes (lowest precedence) + extra_info = row.attrs.copy() if hasattr(row, "attrs") else {} + + # Add row-level data (middle precedence) + extra_info.update(row.to_dict()) + + # Add loader-specific attributes (highest precedence) + extra_info.update(attrs or {}) + + return extra_info + + +def _construct_structure_path( + path: str, base_path: str | None, extension: str | None, sharding_pattern: str | None = None +) -> Path: + """Construct file path with optional base_path, extension, and sharding pattern. + + Args: + path: The base path or identifier (e.g., PDB ID) + base_path: Base directory to prepend + extension: File extension to add/replace + sharding_pattern: Pattern for organizing files in subdirectories + - "/1:2/": Use characters 1-2 for first directory level + - "/1:2/0:2/": Use chars 1-2 for first dir, then chars 0-2 for second dir + - None: No sharding (default) + """ + sharded_path = apply_sharding_pattern(path, sharding_pattern) + + if base_path: + final_path = Path(base_path) / sharded_path + else: + final_path = sharded_path + + if extension: + final_path = final_path.with_suffix(extension) + + return final_path + + +def _load_structure_from_path(path: Path, assembly_id: str, parser_args: dict | None = None) -> dict[str, Any]: + """Load structure from file path using the CIF parser.""" + result_dict = parse( + filename=path, + build_assembly=(assembly_id,), + **(parser_args or {}), + ) + return result_dict + + +def create_base_loader( + example_id_colname: str = "example_id", + path_colname: str = "path", + assembly_id_colname: str | None = "assembly_id", + attrs: dict | None = None, + base_path: str = "", + extension: str = "", + sharding_pattern: str | None = None, + parser_args: dict | None = None, +) -> Callable[[pd.Series], dict[str, Any]]: + """Factory function that creates a base loader with common logic for many AtomWorks datasets. + + Args: + example_id_colname: Name of column containing unique example identifiers + path_colname: Name of column containing paths to structure files + assembly_id_colname: Optional column name containing assembly IDs. + If None, assembly_id defaults to "1" for all examples. + attrs: Additional attributes to merge with highest precedence into the metadata hierarchy + (and ultimately included in the output dictionary's "extra_info" key). + base_path: Base path to prepend to file paths if not included in path column + extension: File extension to add/replace if not included in path column + sharding_pattern: Pattern for how files are organized in subdirectories, if not specified in the path + - "/1:2/": Use characters 1-2 for first directory level + - "/1:2/0:2/": Use chars 1-2 for first dir, then chars 0-2 for second dir + - None: No sharding (default) + parser_args: Optional dictionary of arguments to pass to the CIF parser when loading the structure file. + + Returns: + A function that takes a pandas Series and returns a dictionary of the loaded structure. 
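The three-level precedence described above can be reproduced with nothing but pandas; the sketch below uses hypothetical column names and paths purely to show which value survives each `update`.

```python
import pandas as pd

row = pd.Series({"path": "1abc", "resolution": 2.1})
row.attrs = {"base_path": "/data/pdb", "resolution": 9.9}  # DataFrame-level defaults

merged = row.attrs.copy()                  # 1. lowest precedence
merged.update(row.to_dict())               # 2. row-level data wins over attrs
merged.update({"base_path": "/fast/pdb"})  # 3. loader-specific attrs win overall

print(merged)  # {'base_path': '/fast/pdb', 'resolution': 2.1, 'path': '1abc'}
```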
+ """ + + def loader_function(row: pd.Series) -> dict[str, Any]: + # Prepare loader-specific attributes + loader_attrs = attrs.copy() if attrs else {} + if base_path and "base_path" not in loader_attrs: + loader_attrs["base_path"] = base_path + if extension and "extension" not in loader_attrs: + loader_attrs["extension"] = extension + + extra_info = _construct_metadata_hierarchy(row, loader_attrs) + + assembly_id = ( + row[assembly_id_colname] if assembly_id_colname is not None and assembly_id_colname in row else "1" + ) + path = _construct_structure_path( + row[path_colname], extra_info.get("base_path"), extra_info.get("extension"), sharding_pattern + ) + result_dict = _load_structure_from_path(path, assembly_id, parser_args) + + # Remove used columns from extra_info to avoid duplication in the output dictionary + exclude_cols = ( + [example_id_colname, path_colname] + + ([assembly_id_colname] if assembly_id_colname else []) + + ["base_path", "extension"] + ) + extra_info = keyfilter(lambda k: k not in exclude_cols, extra_info) + + return { + # ... from the row and metadata hierarchy + "example_id": row[example_id_colname], + "path": path, + "assembly_id": assembly_id, + "extra_info": extra_info, + # ... from the CIF parser + "atom_array": result_dict["assemblies"][assembly_id][0], # First model + "atom_array_stack": result_dict["assemblies"][assembly_id], # All models + "chain_info": result_dict["chain_info"], + "ligand_info": result_dict["ligand_info"], + "metadata": result_dict["metadata"], + } + + return loader_function + + +def create_loader_with_query_pn_units( + example_id_colname: str = "example_id", + path_colname: str = "path", + pn_unit_iid_colnames: str | list[str] | None = None, + assembly_id_colname: str | None = "assembly_id", + base_path: str = "", + extension: str = "", + sharding_pattern: str | None = None, + attrs: dict | None = None, + parser_args: dict | None = None, +) -> Callable[[pd.Series], dict[str, Any]]: + """Factory function that creates a generic loader for pipelines with query pn_units (chains). + + For instance, in the interfaces dataset, each sampled row contains two pn_unit instance IDs + that should be included in the cropped structure. + + Examples: + Interfaces dataset: + >>> loader = create_loader_with_query_pn_units( + ... pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"], assembly_id_colname="assembly_id" + ... ) + + Chains dataset: + >>> loader = create_loader_with_query_pn_units( + ... pn_unit_iid_colnames="pn_unit_iid", base_path="/data/structures", extension=".cif.gz" + ... 
) + """ + # Normalize pn_unit_iid_colnames to list format + if isinstance(pn_unit_iid_colnames, str): + pn_unit_iid_colnames = [pn_unit_iid_colnames] + pn_unit_iid_colnames = pn_unit_iid_colnames or [] + + # Create base loader with common parameters + base_loader = create_base_loader( + example_id_colname=example_id_colname, + path_colname=path_colname, + assembly_id_colname=assembly_id_colname, + attrs=attrs, + base_path=base_path, + extension=extension, + sharding_pattern=sharding_pattern, + parser_args=parser_args, + ) + + def loader_function(row: pd.Series) -> dict[str, Any]: + # Get base loader dictionary with common functionality + result = base_loader(row) + result["extra_info"] = keyfilter(lambda k: k not in pn_unit_iid_colnames, result["extra_info"]) + + # Add query-specific fields + query_pn_unit_iids = [row[colname] for colname in pn_unit_iid_colnames] + result["query_pn_unit_iids"] = query_pn_unit_iids + + return result + + return loader_function + + +def create_loader_with_interfaces_and_pn_units_to_score( + example_id_colname: str = "example_id", + path_colname: str = "path", + assembly_id_colname: str | None = "assembly_id", + interfaces_to_score_colname: str | None = "interfaces_to_score", + pn_units_to_score_colname: str | None = "pn_units_to_score", + base_path: str = "", + extension: str = "", + sharding_pattern: str | None = None, + attrs: dict | None = None, + parser_args: dict | None = None, +) -> Callable[[pd.Series], dict[str, Any]]: + """Factory function that creates a loader that adds interfaces and pn_units to score for validation datasets. + + Example: + >>> loader = create_loader_with_interfaces_and_pn_units_to_score( + ... interfaces_to_score_colname="interfaces_to_score", pn_units_to_score_colname="pn_units_to_score" + ... 
) + """ + # Create base loader with common parameters + base_loader = create_base_loader( + example_id_colname=example_id_colname, + path_colname=path_colname, + assembly_id_colname=assembly_id_colname, + attrs=attrs, + base_path=base_path, + extension=extension, + sharding_pattern=sharding_pattern, + parser_args=parser_args, + ) + + def loader_function(row: pd.Series) -> dict[str, Any]: + # Get base loader dictionary with common functionality + result = base_loader(row) + result["extra_info"] = keyfilter( + lambda k: k not in [interfaces_to_score_colname, pn_units_to_score_colname], result["extra_info"] + ) + + # Add validation-specific fields + interfaces_to_score = row[interfaces_to_score_colname] if interfaces_to_score_colname is not None else None + pn_units_to_score = row[pn_units_to_score_colname] if pn_units_to_score_colname is not None else None + + result.update( + { + "interfaces_to_score": interfaces_to_score, + "pn_units_to_score": pn_units_to_score, + } + ) + + return result + + return loader_function diff --git a/src/atomworks/ml/datasets/parsers/base.py b/src/atomworks/ml/datasets/parsers/base.py index 2125463f..613bc32d 100644 --- a/src/atomworks/ml/datasets/parsers/base.py +++ b/src/atomworks/ml/datasets/parsers/base.py @@ -7,16 +7,18 @@ from atomworks.constants import CRYSTALLIZATION_AIDS from atomworks.io import parse -DEFAULT_CIF_PARSER_ARGS = { +DEFAULT_PARSER_ARGS = { "add_missing_atoms": True, "add_id_and_entity_annotations": True, "add_bond_types_from_struct_conn": ["covale"], "remove_ccds": CRYSTALLIZATION_AIDS, "remove_waters": True, - "hydrogen_policy": "remove", "fix_ligands_at_symmetry_centers": True, - "convert_mse_to_met": True, "fix_arginines": True, + "fix_formal_charges": True, + "fix_bond_types": True, + "convert_mse_to_met": True, + "hydrogen_policy": "remove", "model": None, # all models } """Default cif parser arguments for `atomworks.io.parse`. @@ -115,12 +117,13 @@ def load_example_from_metadata_row( cif_parser_args = {} # Convenience utilities to default to loading from and saving to cache if a cache_dir is provided, unless explicitly overridden + # TODO: Move to DEFAULT_CIF_PARSER_ARGS, but set to False by default not True if cif_parser_args.get("cache_dir"): cif_parser_args.setdefault("load_from_cache", True) cif_parser_args.setdefault("save_to_cache", True) # Merge DEFAULT_CIF_PARSER_ARGS with cif_parser_args, overriding with the keys present in cif_parser_args - merged_cif_parser_args = {**DEFAULT_CIF_PARSER_ARGS, **cif_parser_args} + merged_cif_parser_args = {**DEFAULT_PARSER_ARGS, **cif_parser_args} # Use the parse function with the merged CIF parser arguments result_dict = parse( diff --git a/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py b/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py index 1718538e..20fa2727 100644 --- a/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py +++ b/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py @@ -96,7 +96,7 @@ class PNUnitsDFParser(MetadataRowParser): def __init__( self, - base_dir: Path | str | list[Path | str] | tuple[Path | str, ...] = Path(PDB_MIRROR_PATH), + base_dir: Path | str | list[Path | str] | tuple[Path | str, ...] = PDB_MIRROR_PATH, file_extension: str | list[str] | tuple[str, ...] = ".cif.gz", path_template: str | list[str] | tuple[str, ...] 
= "{base_dir}/{pdb_id[1:3]}/{pdb_id}{file_extension}", ): @@ -148,7 +148,7 @@ class InterfacesDFParser(MetadataRowParser): def __init__( self, - base_dir: Path | str | list[Path | str] | tuple[Path | str, ...] = Path(PDB_MIRROR_PATH), + base_dir: Path | str | list[Path | str] | tuple[Path | str, ...] = PDB_MIRROR_PATH, file_extension: str | list[str] | tuple[str, ...] = ".cif.gz", path_template: str | list[str] | tuple[str, ...] = "{base_dir}/{pdb_id[1:3]}/{pdb_id}{file_extension}", ): @@ -193,38 +193,48 @@ class GenericDFParser(MetadataRowParser): We parse an input row (e.g., a Pandas Series) and return a dictionary containing pertinent information for the Transform pipeline. Args: - example_id_colname (str): Name of the column containing a unique identifier for each example (across ALL datasets, not just this dataset). - By convention, the columns values should be generated with `atomworks.ml.common.generate_example_id`. Default: "example_id" - path_colname (str): Name of the column containing paths (relative or absolute) to the relevant structure files. Default: "path" - pn_unit_iid_colnames (str | List[str]): The name(s) of the column(s) containing the CIFUtils pn_unit_iid(s); used for cropping. + example_id_colname: Name of the column containing a unique identifier for each example (across ALL datasets, not just this dataset). + By convention, the columns values should be generated with ``atomworks.ml.common.generate_example_id``. Default: "example_id" + path_colname: Name of the column containing paths (relative or absolute) to the relevant structure files. Default: "path" + pn_unit_iid_colnames: The name(s) of the column(s) containing the CIFUtils pn_unit_iid(s); used for cropping. If given as a list, should contain one element for a monomers dataset and two for an interfaces dataset. Default: None (crop randomly) - assembly_id_colname (str | None): Optional parameter giving the name of the column containing the assembly ID. + assembly_id_colname: Optional parameter giving the name of the column containing the assembly ID. If None, the assembly ID will be set to "1" for all examples. Default: None - base_path (str): The base path to the files, if not included in the path. - extension (str): The file extension of the structure files, if not included in the path. - attrs (dict): Additional attributes to be merged with the dataframe-level attributes stored in the DF (if present). Attributes + base_path: The base path to the files, if not included in the path. + extension: The file extension of the structure files, if not included in the path. + attrs: Additional attributes to be merged with the dataframe-level attributes stored in the DF (if present). Attributes in this dictionary will take precedence over those in the dataset-level attributes and will be returned in the "extra_info" key. Returns: - - example_id: The unique identifier for the example. Must be unique across all datasets. - - path: The composed path to the structure file, including the base path and extension if specified. - - query_pn_unit_iids: The pn_unit_iid(s) that inform where to crop the structure. - During TRAINING, we typically want to specify the chain(s) or interface at which to center our crop. If not given (i.e., None), + dict: A dictionary containing: + + example_id + The unique identifier for the example. Must be unique across all datasets. + path + The composed path to the structure file, including the base path and extension if specified. 
+ query_pn_unit_iids + The pn_unit_iid(s) that inform where to crop the structure. + During TRAINING, we typically want to specify the chain(s) or interface at which to center our crop. If not given (i.e., None), then we will crop the structure at a random location, if a crop is required. - During VALIDATION, then we do not crop, and query_pn_unit_iids should be None. - - assembly_id: The assembly ID. Used to load the correct assembly from the CIF file. If not given, the assembly ID will be set to "1". - - extra_info: A dictionary containing all additional information that should be passed to the Transform pipeline. Contains, in order of precedence: - - Any additional key-value pairs specified by the `attrs` parameter - - All unused dataframe columns (i.e., those not used for example_id, path, query_pn_unit_iids, or assembly_id) - - Dataset-level attributes (if present), found in the `attrs` attribute of the Dataframe (or Series) - For example, the "extra_info" key could contain information about which chain(s) to score during validation, metadata for specific metrics, etc. - - NOTE: We must avoid duplication of interfaces due to order inversion. If not using the preprocessing - scripts in `atomworks.ml`, ensure that the interfaces dataframe has been checked for duplicates. + During VALIDATION, we do not crop, and query_pn_unit_iids should be None. + assembly_id + The assembly ID. Used to load the correct assembly from the CIF file. If not given, the assembly ID will be set to "1". + extra_info + A dictionary containing all additional information that should be passed to the Transform pipeline. Contains, in order of precedence: + + - Any additional key-value pairs specified by the ``attrs`` parameter + - All unused dataframe columns (i.e., those not used for example_id, path, query_pn_unit_iids, or assembly_id) + - Dataset-level attributes (if present), found in the ``attrs`` attribute of the Dataframe (or Series) + For example, the "extra_info" key could contain information about which chain(s) to score during validation, metadata for specific metrics, etc. + + Note: + We must avoid duplication of interfaces due to order inversion. If not using the preprocessing + scripts in ``atomworks.ml``, ensure that the interfaces dataframe has been checked for duplicates. For example, [A, B] and [B, A] should be considered the same interface. - Example dataframe: + Example: + Example dataframe: example_id path pn_unit_1_iid pn_unit_2_iid {['my-dataset']}{ex_1}{1}{[A_1,B_1]} /path/to/structure_1.cif A_1 B_1 {['my-dataset']}{ex_2}{2}{[C_1,B_1]} /path/to/structure_2.cif C_1 B_1 diff --git a/src/atomworks/ml/encoding_definitions.py b/src/atomworks/ml/encoding_definitions.py index 7621d223..41b6fbea 100644 --- a/src/atomworks/ml/encoding_definitions.py +++ b/src/atomworks/ml/encoding_definitions.py @@ -35,34 +35,36 @@ @dataclass class TokenEncoding: - """A class to represent an fixed length token encoding. + """A class to represent a fixed-length token encoding. Args: - token_atoms (dict[str, np.ndarray]): A dictionary mapping token names to atom names. + token_atoms: A dictionary mapping token names to atom names. The order of the tokens in the sequence determines the integer encoding of the token. The order of the atom names in the tuple determines the integer encoding of the atom name within the token. - chemcomp_type_to_unknown (dict[str, str]): A dictionary mapping chemical component types + chemcomp_type_to_unknown: A dictionary mapping chemical component types to unknown token names.
This is used to map unknown residues to the respective unknown token. Different chemical component types may map to different unknown token names. - Defaults to `{}`, meaning that no unknown tokens are defined, leading to a `KeyError` + Defaults to ``{}``, meaning that no unknown tokens are defined, leading to a ``KeyError`` if an unknown residue is encountered. - NOTE: We follow these conventions for tokens to make them compatible with the CCD for - robust and easy tokenization. If you want to use the Transforms written for automatically - tokenizing and encoding, you need to follow these conventions. + Note: + We follow these conventions for tokens to make them compatible with the CCD for + robust and easy tokenization. If you want to use the Transforms written for automatically + tokenizing and encoding, you need to follow these conventions: + - When encoding a residue, we use the standardized (up to) 3-letter residue name from the CCD, - e.g. 'ALA' for Alanine, or `DA` for Deoxyadenosine, or `U` for Uracil. + e.g. ``'ALA'`` for Alanine, or ``'DA'`` for Deoxyadenosine, or ``'U'`` for Uracil. - When encoding unknown tokens, we may define different unknown tokens for different chemical components (e.g. a different unknown for proteins, vs. dna, ...). The - unkown tokens can take on any arbitrary 3-letter code that we want to map to, but + unknown tokens can take on any arbitrary 3-letter code that we want to map to, but they should not clash with existing residue names in the CCD. - When encoding an atom, we use the atomic number of the element as a string as the - token name. E.g. '1' for Hydrogen, '6' for Carbon, '9' for Fluorine, ... - For unknown atoms, we use '0' as the token name. - # TODO: Deal with ligand names such as `100` which is also an atomic number - - To denote masked tokens, we use a '<...>' syntax. E.g. '' for a generic mask token, - or '' for a mask token for proteins. The ... can be any arbitrary string. We + token name. E.g. ``'1'`` for Hydrogen, ``'6'`` for Carbon, ``'9'`` for Fluorine, ... + For unknown atoms, we use ``'0'`` as the token name. + # TODO: Deal with ligand names such as ``'100'`` which is also an atomic number + - To denote masked tokens, we use a ``'<...>'`` syntax. E.g. ``''`` for a generic mask token, + or ``''`` for a mask token for proteins. The ... can be any arbitrary string. We use the angle brackets to avoid clashes with existing residue names in the CCD. """ @@ -282,7 +284,7 @@ def __repr__(self): """AF2's atom14 encoding. Reference: - - https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/common/residue_constants.py#L505 + `AlphaFold residue_constants.py <https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/common/residue_constants.py#L505>`_ """ AF2_ATOM37_ENCODING = TokenEncoding( @@ -315,7 +317,7 @@ def __repr__(self): """AF2's atom37 encoding Reference: - - https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/common/residue_constants.py#L492-L544 + `AlphaFold residue_constants.py <https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/common/residue_constants.py#L492-L544>`_ (extracted via: ```python atom37 = {} @@ -447,8 +449,10 @@ def __repr__(self): chemcomp_type_to_unknown={chem_type: "UNK" for chem_type in AA_LIKE_CHEM_TYPES}, ) """RF2 atom14 encoding for proteins. - - Encodes only the heavy atoms (max 14, for `TRP`) - - Includes 1 unknown tokens: `UNK` + +- Encodes only the heavy atoms (max 14, for ``TRP``) +- Includes 1 unknown token: ``UNK`` + Print it out to see a visual representation of the encoding. """ @@ -494,8 +498,10 @@ def __repr__(self): ), ) """RF2 atom23 encoding for proteins and nucleic acids.
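As a toy illustration of the naming conventions just listed (CCD residue names, atomic numbers as strings, and angle-bracketed mask tokens), consider the hypothetical table below; `'<mask>'` is an arbitrary example of the `'<...>'` syntax, not a token defined by the library.

```python
token_names = [
    "ALA",     # residue token: standardized CCD residue name
    "DA",      # residue token: deoxyadenosine
    "6",       # atom token: carbon, i.e. the atomic number as a string
    "0",       # atom token: unknown element
    "<mask>",  # masked token: angle brackets avoid clashes with CCD names
]
token_to_index = {name: i for i, name in enumerate(token_names)}
print(token_to_index["6"])  # -> 2
```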
- - Encodes only the heavy atoms (max 22, for `RG`) - - Includes 3 unknown tokens: `UNK` for proteins, `DN` for dna, `N` for RNA + +- Encodes only the heavy atoms (max 22, for ``RG``) +- Includes 3 unknown tokens: ``UNK`` for proteins, ``DN`` for DNA, ``N`` for RNA + Print it out to see a visual representation of the encoding. """ @@ -687,18 +693,11 @@ def __repr__(self): class AF3SequenceEncoding: - """ - Encodes and decodes sequence tokens for AlphaFold 3. + """Encodes and decodes sequence tokens for AlphaFold 3. This class provides functionality to convert between residue names and their corresponding integer encodings as used in AlphaFold 3. It handles standard amino acids, RNA, DNA, and unknown residues. - - Methods: - encode(res_names): Encode residue names to integer indices. - decode(res_indices): Decode integer indices to residue names. - tokens: Property that returns the list of AF3 tokens. - n_tokens: Property that returns the number of AF3 tokens. """ def __init__(self): diff --git a/src/atomworks/ml/pipelines/af3.py b/src/atomworks/ml/pipelines/af3.py index 511f2e7c..df28b348 100644 --- a/src/atomworks/ml/pipelines/af3.py +++ b/src/atomworks/ml/pipelines/af3.py @@ -161,9 +161,8 @@ def build_af3_transform_pipeline( The pipeline includes steps for processing the structure, adding annotations, and generating features required for AF3-like predictions. - References: - - AlphaFold 3 Supplementary Information. - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf + Reference: + `AlphaFold 3 Supplementary Information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_ """ if (crop_contiguous_probability > 0 or crop_spatial_probability > 0) and not is_inference: diff --git a/src/atomworks/ml/preprocessing/utils/clustering.py b/src/atomworks/ml/preprocessing/utils/clustering.py index a8d4cd12..f7a2d408 100644 --- a/src/atomworks/ml/preprocessing/utils/clustering.py +++ b/src/atomworks/ml/preprocessing/utils/clustering.py @@ -77,9 +77,9 @@ def run_mmseqs2_clustering( afe56282ba3, 19f7ce1eed1 References: - - PDB clustering approach: https://www.rcsb.org/docs/grouping-structures/sequence-based-clustering - - MMseqs2 documentation: https://github.com/soedinglab/mmseqs2/wiki - - CLI documentation for the `easy-cluster` command: `mmseqs easy-cluster -h` + `PDB clustering approach <https://www.rcsb.org/docs/grouping-structures/sequence-based-clustering>`_ + `MMseqs2 documentation <https://github.com/soedinglab/mmseqs2/wiki>`_ + CLI documentation for the `easy-cluster` command: `mmseqs easy-cluster -h` """ # If input is a Path object, convert it to a string if isinstance(input_fasta, Path): diff --git a/src/atomworks/ml/preprocessing/utils/structure_utils.py b/src/atomworks/ml/preprocessing/utils/structure_utils.py index 43f21f87..c833c370 100644 --- a/src/atomworks/ml/preprocessing/utils/structure_utils.py +++ b/src/atomworks/ml/preprocessing/utils/structure_utils.py @@ -395,8 +395,8 @@ def get_ligand_validity_scores_from_pdb_id(pdb_id: str) -> list[dict[str, str | residue name, chain ID, and entity ID. Can easily be converted to a pandas DataFrame for easier handling via `pd.DataFrame(records)`.
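As the docstring above suggests, the returned records drop straight into pandas; the field names below are hypothetical stand-ins for the documented residue name / chain ID / entity ID fields.

```python
import pandas as pd

records = [
    {"pdb_id": "1abc", "res_name": "HEM", "chain_id": "A", "entity_id": "2"},
    {"pdb_id": "1abc", "res_name": "ZN", "chain_id": "B", "entity_id": "3"},
]
df = pd.DataFrame(records)
print(df.columns.tolist())  # ['pdb_id', 'res_name', 'chain_id', 'entity_id']
```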
- References: - - https://www.rcsb.org/docs/general-help/ligand-structure-quality-in-pdb-structures + Reference: + `RCSB Ligand Structure Quality Guide <https://www.rcsb.org/docs/general-help/ligand-structure-quality-in-pdb-structures>`_ """ pdb_graphql_url: Final[str] = "https://data.rcsb.org/graphql" diff --git a/src/atomworks/ml/samplers.py b/src/atomworks/ml/samplers.py index 20ef3e7f..76535fe5 100644 --- a/src/atomworks/ml/samplers.py +++ b/src/atomworks/ml/samplers.py @@ -238,8 +238,8 @@ class DistributedMixedSampler(Sampler): Returns: iter: An iterator over indices of the dataset for the current process (of length n_samples, not n_examples_per_epoch) - References: - - PyTorch DistributedSampler (https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L68) + Reference: + `PyTorch DistributedSampler <https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L68>`_ """ def __init__( diff --git a/src/atomworks/ml/transforms/af3_reference_molecule.py b/src/atomworks/ml/transforms/af3_reference_molecule.py index 10acaeef..b98aea63 100644 --- a/src/atomworks/ml/transforms/af3_reference_molecule.py +++ b/src/atomworks/ml/transforms/af3_reference_molecule.py @@ -1,5 +1,6 @@ import logging from collections import defaultdict +from functools import lru_cache from typing import Any, ClassVar, Literal import biotite.structure as struc @@ -24,8 +25,13 @@ logger = logging.getLogger("atomworks.ml") -# UNL is a special CCD code for unknown ligands; we do not consider it "known" as it has no structure -KNOWN_CCD_CODES = get_available_ccd_codes(CCD_MIRROR_PATH) - {UNKNOWN_LIGAND} + +# (Lazy-load this expensive computation to avoid slow imports) +@lru_cache(maxsize=1) +def get_known_ccd_codes() -> frozenset[str]: + """Get the set of known CCD codes, computing it lazily on first access.""" + # UNL is a special CCD code for unknown ligands; we do not consider it "known" as it has no structure + return get_available_ccd_codes(CCD_MIRROR_PATH) - {UNKNOWN_LIGAND} def _extract_cached_conformers( @@ -82,11 +88,11 @@ def _get_rdkit_mols_with_conformers( to using the idealized conformer from the CCD entry if available. Reference: - - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf + `AF3 Supplementary Information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_ """ ref_mols = {} for res_name, count in res_stochiometry.items(): - if res_name not in KNOWN_CCD_CODES: + if res_name not in get_known_ccd_codes(): ref_mols[res_name] = None # placeholder so that the unknown CCD codes are still counted later on continue @@ -114,7 +120,7 @@ def _encode_atom_names_like_af3(atom_names: np.ndarray) -> np.ndarray: length 4. Reference: - - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf + `AF3 Supplementary Information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_ """ # Ensure uppercase atom_names = np.char.upper(atom_names) @@ -224,8 +230,7 @@ def get_af3_reference_molecule_features( - is_atomized_atom_level: [N_atoms] Whether the atom is atomized (atom-level version of "is_ligand") Reference: - - Section 2.8 of the AF3 supplementary information - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf + `Section 2.8 of the AF3 supplementary information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_ """ _has_ground_truth_conformer_policy = "ground_truth_conformer_policy" in atom_array.get_annotation_categories() _has_global_res_id = "res_id_global" in atom_array.get_annotation_categories() @@ -263,8 +268,8 @@ def get_af3_reference_molecule_features( # ...
generate conformers for CCD codes that are unknown (including UNL) unknown_ccd_conformers = defaultdict(list) - if not all(res_name in KNOWN_CCD_CODES for res_name in res_stochiometry): - res_indices_with_unknown = np.where(~np.isin(_res_names, list(KNOWN_CCD_CODES)))[0] + if not all(res_name in get_known_ccd_codes() for res_name in res_stochiometry): + res_indices_with_unknown = np.where(~np.isin(_res_names, list(get_known_ccd_codes())))[0] for res_index in res_indices_with_unknown: res_name = _res_names[res_index] @@ -307,7 +312,7 @@ def get_af3_reference_molecule_features( conf_idx = _next_conf_idx[res_name] # ... turn conformer into an atom array - if res_name not in KNOWN_CCD_CODES: + if res_name not in get_known_ccd_codes(): # (conformers for unknown CCD codes are already atom arrays, since we generated them directly) conformer = unknown_ccd_conformers[res_name][conf_idx % len(unknown_ccd_conformers[res_name])] else: @@ -457,8 +462,7 @@ class GetAF3ReferenceMoleculeFeatures(Transform): Note: This transform should be applied after cropping. Reference: - - Section 2.8 of the AF3 supplementary information - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf + `Section 2.8 of the AF3 supplementary information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_ """ def __init__( diff --git a/src/atomworks/ml/transforms/atom_array.py b/src/atomworks/ml/transforms/atom_array.py index 29c06f06..e226b498 100644 --- a/src/atomworks/ml/transforms/atom_array.py +++ b/src/atomworks/ml/transforms/atom_array.py @@ -660,9 +660,10 @@ def sort_like_rf2aa(atom_array: AtomArray) -> AtomArray: class SortLikeRF2AA(Transform): """Sort the atom array in 3 groups (in this order). Within each group the atoms are ordered by their pn_unit_iid (and within a pn_unit their order is preserved). - - (1) polymer atoms - - (2) non-poly atoms of a pn-unit bonded to a polymer (covalent modifications) - - (3) non-poly atoms of a free-floating pn-unit (free-floating ligands) + + - (1) polymer atoms + - (2) non-poly atoms of a pn-unit bonded to a polymer (covalent modifications) + - (3) non-poly atoms of a free-floating pn-unit (free-floating ligands) """ requires_previous_transforms: ClassVar[list[str | Transform]] = ["AtomizeByCCDName"] diff --git a/src/atomworks/ml/transforms/atom_frames.py b/src/atomworks/ml/transforms/atom_frames.py index 29d41ac8..f21b2b31 100644 --- a/src/atomworks/ml/transforms/atom_frames.py +++ b/src/atomworks/ml/transforms/atom_frames.py @@ -159,21 +159,19 @@ def find_all_paths_of_length_n( graph: nx.Graph, n: int, order_independent_atom_frame_prioritization: bool = True ) -> list: - """ - Find all paths of a given length n in a NetworkX graph. + """Find all paths of a given length n in a NetworkX graph. - Parameters: - G (nx.Graph): The input graph. - n (int): The length of the paths to find. - order_independent_frame_prioritization (bool, optional): - If True, considers paths with the same nodes but in different orders as equivalent. - Defaults to True. + Args: + graph: The input graph. + n: The length of the paths to find. + order_independent_atom_frame_prioritization: If True, considers paths with the same nodes but in different orders as equivalent. + Defaults to True. Returns: - np.ndarray: A tensor containing all unique paths of length n. + A tensor containing all unique paths of length n.
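A minimal sketch of the recursive enumeration cited in the reference below, including the order-independent deduplication that the `order_independent_atom_frame_prioritization` flag describes; this is a standalone illustration, not the library's exact code.

```python
import networkx as nx


def find_paths(graph: nx.Graph, u, n: int) -> list[list]:
    # Recursively enumerate all walks of n edges starting at u that never
    # revisit a node (the StackOverflow approach cited below).
    if n == 0:
        return [[u]]
    return [
        [u, *rest]
        for neighbor in graph.neighbors(u)
        for rest in find_paths(graph, neighbor, n - 1)
        if u not in rest
    ]


G = nx.path_graph(4)  # 0-1-2-3
paths = [p for node in G for p in find_paths(G, node, 2)]
# Treat [0, 1, 2] and [2, 1, 0] as the same path (order-independent mode).
unique = {tuple(min(p, list(reversed(p)))) for p in paths}
print(sorted(unique))  # [(0, 1, 2), (1, 2, 3)]
```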
Reference: - https://stackoverflow.com/questions/28095646/finding-all-paths-walks-of-given-length-in-a-networkx-graph''' + `StackOverflow: Finding all paths of given length <https://stackoverflow.com/questions/28095646/finding-all-paths-walks-of-given-length-in-a-networkx-graph>`_ """ def find_paths(graph: nx.Graph, u: Any, n: int) -> list[list[Any]]: diff --git a/src/atomworks/ml/transforms/base.py b/src/atomworks/ml/transforms/base.py index 3aa1f9fb..9b881938 100644 --- a/src/atomworks/ml/transforms/base.py +++ b/src/atomworks/ml/transforms/base.py @@ -31,31 +31,42 @@ class TransformPipelineError(Exception): - """A custom error class for Transform pipelines (via `Compose`).""" + """A custom error class for Transform pipelines (via :class:`Compose`). + + Attributes: + rng_state_dict: Optional RNG state dictionary for debugging purposes. + """ def __init__(self, message: str, rng_state_dict: dict[str, Any] | None = None): + """Initialize TransformPipelineError. + + Args: + message: The error message. + rng_state_dict: Optional RNG state dictionary for debugging purposes. + """ super().__init__(message) # expose RNG state dict for debugging self.rng_state_dict = rng_state_dict class TransformedDict(dict): - """A thin wrapper around a dictionary that can be used to track the transform history.""" + """A thin wrapper around a dictionary that can be used to track the transform history. + + Behaves just like a regular dictionary but includes a ``__transform_history__`` attribute + that tracks the sequence of transforms applied to the data. + """ def __new__(cls, __existing_dict_to_wrap: dict[str, Any] | None = None, **kwargs): """Create a new instance or return the existing TransformedDict instance. - NOTE: To get a pure dictionary, simply use `dict(transformed_dict)` on a TransformedDict instance. - TransformedDict's behave just like dicts for all intents and purposes, so you can use them just like - a regular dictionary. + Note: + To get a pure dictionary, simply use ``dict(transformed_dict)`` on a TransformedDict instance. + TransformedDicts behave just like dicts for all intents and purposes. Args: - __existing_dict_to_wrap (dict, optional): This is useful for wrapping an existing dictionary. - The odd name `__existing_dict_to_wrap` is used as an unlikely name to avoid conflicts - with the `dict` class. - **kwargs: Additional keyword arguments to pass to the dictionary constructor. This ensures - that a TransformedDict can be initialized just like a regular dictionary if no existing - dictionary to wrap is provided. + __existing_dict_to_wrap: This is useful for wrapping an existing dictionary. + The odd name is used as an unlikely name to avoid conflicts with the dict class. + **kwargs: Additional keyword arguments to pass to the dictionary constructor. """ # if the argument is already a TransformedDict, return it if isinstance(__existing_dict_to_wrap, TransformedDict): @@ -79,21 +90,18 @@ def __new__(cls, __existing_dict_to_wrap: dict[str, Any] | None = None, **kwargs class Transform(ABC): - """ - Abstract base class for transformations on dictionary objects. - - Class level attributes: - - validate_input (bool): Whether to validate the input. - - raise_if_invalid_input (bool): Whether to raise an error if the input is invalid. - - requires_previous_transforms (list[str]): Transforms that must have been applied before this transform. - - incompatible_previous_transforms (list[str]): Transforms that cannot have preceeded this transform. - - previous_transforms_order_matters (bool): Whether the order of the transforms is important.
- - _track_transform_history (bool): Whether to track the transform history. - - To write a subclass, you need to implement the following methods: - - check_input(data: dict): Validates the input data. Should raise an error if the input is invalid. - The returned value is not used. - - forward(data: dict): Applies the transformation to the input data and returns the transformed data. + """Abstract base class for transformations on dictionary objects. + + To write a subclass, you need to implement the :meth:`forward` method. + Optionally, you can override :meth:`check_input` for input validation. + + Attributes: + validate_input: Whether to validate the input. + raise_if_invalid_input: Whether to raise an error if the input is invalid. + requires_previous_transforms: Transforms that must have been applied before this transform. + incompatible_previous_transforms: Transforms that cannot have preceded this transform. + previous_transforms_order_matters: Whether the order of the transforms is important. + _track_transform_history: Whether to track the transform history. """ validate_input: bool = True @@ -105,28 +113,39 @@ class Transform(ABC): # To be implemented by subclasses (optional) def check_input(self, data: dict[str, Any]) -> None: # noqa: B027 - """ - Check if the input dictionary is valid for the transform. Raises an error if the input is invalid. + """Check if the input dictionary is valid for the transform. + + Args: + data: The input dictionary to validate. + + Raises: + Exception: If the input is invalid. """ pass @abstractmethod def forward(self, data: dict[str, Any], *args, **kwargs) -> dict[str, Any]: - """ - Apply a transformation to the input dictionary and return the transformed dictionary. + """Apply a transformation to the input dictionary and return the transformed dictionary. - Parameters: - data (dict): The input dictionary to transform. + Args: + data: The input dictionary to transform. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. Returns: - dict: The transformed dictionary. + The transformed dictionary. """ pass # Internal logic for formatting error messages, debugging, logging and transform history tracking def _format_error_msg(self, e: Exception) -> str: - """ - Formats the error message with optional traceback when in DEBUG mode. + """Format the error message with optional traceback when in DEBUG mode. + + Args: + e: The exception that occurred. + + Returns: + Formatted error message. """ msg = f"Invalid input for {self.__class__.__name__}: {e}" if DEBUG: @@ -134,8 +153,13 @@ def _format_error_msg(self, e: Exception) -> str: return msg def _transform_to_str(self, t: str | Transform | ABCMeta) -> str: - """ - Convert a transform to a string. + """Convert a transform to a string. + + Args: + t: The transform to convert (string, Transform instance, or Transform class). + + Returns: + String representation of the transform. """ if isinstance(t, str): # case: transform was provided as string, e.g. as `"RemoveKeys"` @@ -150,19 +174,36 @@ def _transform_to_str(self, t: str | Transform | ABCMeta) -> str: raise ValueError(f"Transform `{t}` cannot be converted to a string form for comparison of history.") def _ensure_has_transform_history(self, data: dict[str, Any] | TransformedDict) -> TransformedDict: - """Ensure that the data dictionary has a transform history by wrapping it in a `TransformedDict`.""" + """Ensure that the data dictionary has a transform history by wrapping it in a TransformedDict. 
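A toy analogue of the `TransformedDict` idea described above (the real class uses `__new__` so that wrapping an existing instance returns it unchanged): a plain `dict` subclass that carries a `__transform_history__` attribute and round-trips to a pure dict.

```python
class HistoryDict(dict):
    """Illustrative only: a dict plus a history attribute."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__transform_history__ = []


data = HistoryDict({"atom_array": None})
data.__transform_history__.append({"name": "AtomizeByCCDName"})
plain = dict(data)  # a pure dict again; the history attribute is dropped
print(plain, data.__transform_history__)
```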
+ + Args: + data: The data dictionary to wrap. + + Returns: + TransformedDict instance with transform history. + """ data = TransformedDict(data) return data def _get_transform_history(self, data: TransformedDict) -> list[str]: - """ - Get the transform history from the data. + """Get the transform history from the data. + + Args: + data: The TransformedDict containing the history. + + Returns: + List of transform names in the history. """ return data.__transform_history__ def _maybe_update_transform_history(self, data: TransformedDict) -> dict[str, Any]: - """ - Update the transform history by appending the current transform to the transform history. + """Update the transform history by appending the current transform to the transform history. + + Args: + data: The TransformedDict to update. + + Returns: + The updated data dictionary. """ if self._track_transform_history: this_transform_record = { @@ -178,8 +219,14 @@ def _maybe_update_transform_history(self, data: TransformedDict) -> dict[str, An return data def _maybe_restore_transform_history(self, data: TransformedDict, transform_history: list[str]) -> dict[str, Any]: - """ - Restore the transform history, in case the data was copied. + """Restore the transform history, in case the data was copied. + + Args: + data: The TransformedDict to restore history for. + transform_history: The history to restore. + + Returns: + The data with restored history. """ if not hasattr(data, "__transform_history__") or len(data.__transform_history__) == 0: # restore previous transform history if it is not present (e.g. if the data was copied) @@ -187,8 +234,13 @@ def _maybe_restore_transform_history(self, data: TransformedDict, transform_hist return data def _maybe_record_processing_time(self, data: TransformedDict) -> dict[str, Any]: - """ - Record the processing time for the transform. + """Record the processing time for the transform. + + Args: + data: The TransformedDict to record timing for. + + Returns: + The data with updated timing information. """ if self._track_transform_history and len(data.__transform_history__) > 0: for reverse_idx in range(len(data.__transform_history__) - 1, -1, -1): @@ -202,9 +254,13 @@ def _maybe_record_processing_time(self, data: TransformedDict) -> dict[str, Any] return data def _check_transform_history(self, data: TransformedDict) -> None: - """ - Check if the previous transforms are valid for the transform. - Raises an error if the input is invalid. + """Check if the previous transforms are valid for the transform. + + Args: + data: The TransformedDict to check. + + Raises: + TransformPipelineError: If the transform history is invalid. """ # extract the transform history history = [record["name"] for record in data.__transform_history__] @@ -243,11 +299,18 @@ def _check_transform_history(self, data: TransformedDict) -> None: ) def __call__(self, data: dict[str, Any], *args, **kwargs) -> dict[str, Any]: - """ - Validate and apply the transformation to the given dictionary. + """Validate and apply the transformation to the given dictionary. + + Args: + data: The input dictionary to transform. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + The transformed dictionary. Raises: - ValueError: If the input is invalid and raise_if_invalid_input is True. + TransformPipelineError: If the input is invalid and raise_if_invalid_input is True. 
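The subclassing contract described above (implement `forward`, optionally override `check_input`) can be sketched without the library; the `AddOne` transform and the simplified `__call__` below are illustrative, omitting the history, timing, and ordering machinery of the real base class.

```python
from typing import Any


class AddOne:  # stand-in for subclassing atomworks.ml.transforms.base.Transform
    def check_input(self, data: dict[str, Any]) -> None:
        if "x" not in data:
            raise KeyError("expected key 'x'")

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        data["x"] = data["x"] + 1
        return data

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        # The real base class also tracks history, timing, and ordering
        # constraints before delegating to forward().
        self.check_input(data)
        return self.forward(data)


print(AddOne()({"x": 41}))  # {'x': 42}
```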
""" # enable history tracking if it is not already enabled data = self._ensure_has_transform_history(data) @@ -291,7 +354,11 @@ def __call__(self, data: dict[str, Any], *args, **kwargs) -> dict[str, Any]: return data def __repr__(self) -> str: - """String representation of the transform for debugging, notebooks and logging.""" + """String representation of the transform for debugging, notebooks and logging. + + Returns: + String representation of the transform. + """ # Get all the attributes of the class repr_str = f"{self.__class__.__name__} at {hex(id(self))}" @@ -304,6 +371,17 @@ def __repr__(self) -> str: return repr_str def __add__(self, other: Transform) -> Compose: + """Add two transforms together to create a Compose instance. + + Args: + other: Another Transform or Compose instance. + + Returns: + A new Compose instance containing both transforms. + + Raises: + ValueError: If other is not a Transform or Compose instance. + """ # Case 1: self & other are `Compose` instances # ... overridden in `Compose` class # Case 2: self is a `Compose` instance and other is a `Transform` instance @@ -321,38 +399,36 @@ def __add__(self, other: Transform) -> Compose: class Compose(Transform): - """ - Compose multiple transformations together. + """Compose multiple transformations together. This class allows you to chain multiple transformations and apply them sequentially to a data dictionary. It is particularly useful for preprocessing pipelines where multiple steps need to be applied in a specific order. Attributes: - - transforms (list[Transform]): A list of transformations to be applied. - - track_rng_state (bool): Whether to track and serialize the random number generator (RNG) state. This is + transforms: A list of transformations to be applied. + track_rng_state: Whether to track and serialize the random number generator (RNG) state. This is useful for debugging when dealing with probabilistic transformations. The RNG state is returned with the error message if the transform pipeline fails, allowing you to instantiate the same RNG state - with `eval` for debugging. + with ``eval`` for debugging. """ _track_transform_history: bool = False # Compose does not show up in the transform history def __init__(self, transforms: list[Transform], track_rng_state: bool = True, print_rng_state: bool = False): - """ - Initialize the Compose transformation pipeline. + """Initialize the Compose transformation pipeline. Args: - - transforms (list[Transform]): A list of transformations to be applied sequentially. - - track_rng_state (bool): Whether to track and serialize the random number generator (RNG) state. + transforms: A list of transformations to be applied sequentially. + track_rng_state: Whether to track and serialize the random number generator (RNG) state. This is useful for debugging when dealing with probabilistic transformations. The RNG state is returned with the error message if the transform pipeline fails, allowing you to instantiate - the same RNG state with `eval` for debugging. - - print_rng_state (bool): Whether to print the RNG state upon failure. This can be useful + the same RNG state with ``eval`` for debugging. + print_rng_state: Whether to print the RNG state upon failure. This can be useful for debugging and reproducing specific states for transforms with stochasticity. Raises: - ValueError: If `transforms` is not a list or tuple, if it is empty, or if it contains elements that - are not instances of `Transform`. 
+ ValueError: If transforms is not a list or tuple, if it is empty, or if it contains elements that + are not instances of Transform. """ if not isinstance(transforms, list | tuple): raise ValueError(f"Expected a list or tuple of Transforms, but got a {type(transforms)}") @@ -370,6 +446,17 @@ def __init__(self, transforms: list[Transform], track_rng_state: bool = True, pr self.print_rng_state = print_rng_state def __add__(self, other: Transform | list[Transform] | Compose) -> Compose: + """Add another transform or compose to this compose. + + Args: + other: Another Transform, list of Transforms, or Compose instance. + + Returns: + A new Compose instance containing all transforms. + + Raises: + ValueError: If other is not a valid type. + """ if isinstance(other, Compose): return Compose( self.transforms + other.transforms, track_rng_state=self.track_rng_state or other.track_rng_state @@ -382,6 +469,13 @@ def __add__(self, other: Transform | list[Transform] | Compose) -> Compose: raise ValueError(f"Expected a Transform or list of Transforms or Compose, but got a {type(other)}") def check_input(self, data: dict) -> None: + """Check if the input is valid for the compose. + + Compose is always valid, so this method does nothing. + + Args: + data: The input data to check. + """ # Compose is always valid pass @@ -391,6 +485,19 @@ def _stop_transforms( next_transform_idx: int, stop_before: Transform | int | str | None = None, ) -> bool: + """Check if transforms should stop before the next transform. + + Args: + next_transform: The next transform to apply. + next_transform_idx: The index of the next transform. + stop_before: The transform, name, or index to stop before. + + Returns: + True if transforms should stop before the next transform. + + Raises: + ValueError: If stop_before is not a valid type. + """ if stop_before is None: return False elif isinstance(stop_before, int): @@ -408,19 +515,17 @@ def forward( rng_state_dict: dict[str, Any] | None = None, _stop_before: Transform | str | int | None = None, ) -> dict: - """ - Apply a series of transformations to the input data. + """Apply a series of transformations to the input data. Args: - data (dict): The input data to be transformed. - rng_state_dict (dict[str, Any] | None, optional): Random number generator state dictionary. - If provided, sets the RNG state before applying transforms. Defaults to None. - _stop_before (Transform | str | int | None, optional): Specifies a point to stop the transformation + data: The input data to be transformed. + rng_state_dict: Random number generator state dictionary. + If provided, sets the RNG state before applying transforms. + _stop_before: Specifies a point to stop the transformation process. Can be a Transform instance, a string (transform class name), or an integer (index). - Defaults to None. Returns: - dict: The transformed data. + The transformed data. Raises: Exception: If any transform in the pipeline fails, with details about the failure point and RNG state. diff --git a/src/atomworks/ml/transforms/bonds.py b/src/atomworks/ml/transforms/bonds.py index 84499ba1..5aef7bff 100644 --- a/src/atomworks/ml/transforms/bonds.py +++ b/src/atomworks/ml/transforms/bonds.py @@ -135,41 +135,42 @@ def forward(self, data: dict) -> dict: class AddRF2AABondFeaturesMatrix(Transform): - """ - Adds a matrix indicating the RF2AA bond type between two nodes to the data. + """Adds a matrix indicating the RF2AA bond type between two nodes to the data. 
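A self-contained sketch of the sequential semantics that `Compose` (and the `+` operator on transforms) provides; the toy `Scale`/`Shift` classes and the `compose` helper are stand-ins for `Compose([t1, t2])`, not the real API.

```python
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, data):
        data["x"] *= self.factor
        return data


class Shift:
    def __init__(self, offset):
        self.offset = offset

    def __call__(self, data):
        data["x"] += self.offset
        return data


def compose(*transforms):
    def pipeline(data):
        for transform in transforms:  # applied strictly in order, like Compose.forward
            data = transform(data)
        return data

    return pipeline


pipeline = compose(Scale(2), Shift(1))
print(pipeline({"x": 10}))  # {'x': 21}: scaled first, then shifted
```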
+ This transform builds from the Biotite bond type, modifying as needed for residue-residue and residue-atom mappings. - We then add the matrix to the data dictionary under the key `rf2aa_bond_features_matrix`. + We then add the matrix to the data dictionary under the key "rf2aa_bond_features_matrix". From the RF2AA supplement, Supplementary Methods Table 8: Inputs to RFAA: - ------------------------------------------------------------------------------------------------ - bond_feats | (L, L, 7) Pairwise bond adjacency matrix. Pairs of residues are either - single, double, triple, aromatic, residue-residue, residue-atom, or other. - ------------------------------------------------------------------------------------------------ + + bond_feats | (L, L, 7) Pairwise bond adjacency matrix. Pairs of residues are either + | single, double, triple, aromatic, residue-residue, residue-atom, or other. Specifically, we map to the following enum, as described in ChemData: - - 0 = No bonds - - 1 = Single bond - - 2 = Double bond - - 3 = Triple bond - - 4 = Aromatic - - 5 = Residue-residue - - 6 = Residue-atom - - 7 = Other + - 0 = No bonds + - 1 = Single bond + - 2 = Double bond + - 3 = Triple bond + - 4 = Aromatic + - 5 = Residue-residue + - 6 = Residue-atom + - 7 = Other We build the matrix from the Biotite bond types. - The Biotite `BondType` enum contains the following mapping: - - ANY = 0 - - SINGLE = 1 - - DOUBLE = 2 - - TRIPLE = 3 - - QUADRUPLE = 4 - - AROMATIC_SINGLE = 5 - - AROMATIC_DOUBLE = 6 - - AROMATIC_TRIPLE = 7 + The Biotite BondType enum contains the following mapping: + + - ANY = 0 + - SINGLE = 1 + - DOUBLE = 2 + - TRIPLE = 3 + - QUADRUPLE = 4 + - AROMATIC_SINGLE = 5 + - AROMATIC_DOUBLE = 6 + - AROMATIC_TRIPLE = 7 + The index -1 is used for non-bonded interactions. Reference: - - Biotite documentation (https://www.biotite-python.org/apidoc/biotite.structure.BondType.html#biotite.structure.BondType) + `Biotite BondType Documentation <https://www.biotite-python.org/apidoc/biotite.structure.BondType.html#biotite.structure.BondType>`_ """ requires_previous_transforms: ClassVar[list[str | Transform]] = [AtomizeByCCDName, AddTokenBondAdjacency] @@ -202,16 +203,15 @@ def forward(self, data: dict) -> dict: class AddRF2AATraversalDistanceMatrix(Transform): - """ - Generates a matrix indicating the minimum amount of bonds to traverse between two nodes. + """Generates a matrix indicating the minimum number of bonds to traverse between two nodes. + We define the traversal distance between two protein nodes as zero. Sets the "traversal_distance_matrix" key in the data dictionary. From the RF2AA supplement, Supplementary Methods Table 8: Inputs to RFAA: - ------------------------------------------------------------------------------------------------ - dist_matrix | (L, L) Minimum amount of bonds to traverse between two nodes. - This is 0 between all protein nodes. - ------------------------------------------------------------------------------------------------ + + dist_matrix | (L, L) Minimum amount of bonds to traverse between two nodes. + | This is 0 between all protein nodes. """ def check_input(self, data: dict) -> None: diff --git a/src/atomworks/ml/transforms/chirals.py b/src/atomworks/ml/transforms/chirals.py index 4769b94d..0dccada0 100644 --- a/src/atomworks/ml/transforms/chirals.py +++ b/src/atomworks/ml/transforms/chirals.py @@ -90,7 +90,7 @@ def _get_plane_pair_keys_for_planes_between_chiral_center_and_tetrahedral_side( AssertionError: If the length of `bonded_atoms` is not 4.
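Since the chirals code below hinges on signed dihedral angles, here is a minimal numpy implementation of a signed dihedral between the planes (p0, p1, p2) and (p1, p2, p3); reflecting one atom flips the sign, which is exactly the property used to encode chirality. This is standard geometry, not the library's code.

```python
import numpy as np


def signed_dihedral(p0, p1, p2, p3) -> float:
    """Signed dihedral angle (radians) between planes (p0, p1, p2) and (p1, p2, p3)."""
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1)
    v = b0 - np.dot(b0, b1) * b1  # components perpendicular to the central bond
    w = b2 - np.dot(b2, b1) * b1
    return float(np.arctan2(np.dot(np.cross(b1, v), w), np.dot(v, w)))


pts = [np.array(p, dtype=float) for p in ((0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1))]
print(signed_dihedral(*pts))  # ~ +1.571 (+pi/2)
pts[3][2] *= -1               # reflect the last atom through the plane
print(signed_dihedral(*pts))  # ~ -1.571: the mirror image flips the sign
```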
Reference: - - RF2AA supplementary notes figure S1 (https://www.science.org/doi/10.1126/science.adl2528#supplementary-materials) + `RF2AA supplementary notes figure S1 <https://www.science.org/doi/10.1126/science.adl2528#supplementary-materials>`_ Example: >>> chiral_center = 1 @@ -159,11 +159,13 @@ def get_rf2aa_chiral_features( NOTE: Each row of output features contains the indices of the plane pairs and the signed ideal dihedral angle for each chiral center. For example, the entry: - `[c, i, j, k, angle]` - means that the atom at index `c` is a chiral center with atoms at indices `(i, j, k)` bonded - to it. The signed dihedral angle `angle` is the signed angle between the planes `(cij)` and - `(ijk)`. The sign of the angle determines the chirality of the chiral center. + [c, i, j, k, angle] + means that the atom at index c is a chiral center with atoms at indices (i, j, k) bonded + to it. The signed dihedral angle ``angle`` is the signed angle between the planes (cij) and + (ijk). The sign of the angle determines the chirality of the chiral center. + NOTE: Each chiral center will result in more than one feature. In particular: + - 3 features if one of the 4 atoms bonded to the chiral center is an implicit hydrogen (as we do not look at any pair of planes where one plane contains an implicit hydrogen). - 12 features if all 4 atoms bonded to the chiral center are explicit atoms. @@ -176,19 +178,19 @@ network to iteratively refine predictions to match ideal tetrahedral geometry. Args: - chiral_centers (list[dict]): A list of dictionaries, of the form: + chiral_centers: A list of dictionaries, of the form: {"chiral_center_idx": int, "bonded_explicit_atom_idxs": list[int]} - where `chiral_center_idx` is the index of the chiral center atom, and `bonded_explicit_atom_idxs` + where chiral_center_idx is the index of the chiral center atom, and bonded_explicit_atom_idxs is a list of the indices of the atoms bonded to the chiral center (excluding implicit hydrogens). - coords (np.ndarray): A numpy array of atomic coordinates. - take_first_chiral_subordering (bool): If True, only the first subordering is considered (when four + coords: A numpy array of atomic coordinates. + take_first_chiral_subordering: If True, only the first subordering is considered (when four bonded non-hydrogen atoms are present). If False, all orderings are considered (leading to 12 unique plane pairs in the case of 4 bonded atoms, or 3 unique plane pairs in the case of 3 bonded atoms). Returns: - torch.Tensor: A tensor of shape [n_chirals, 5] where each row contains the indices of the plane pairs - and the *signed* ideal dihedral angle for each chiral center. The sign of the dihedral + A tensor of shape [n_chirals, 5] where each row contains the indices of the plane pairs + and the signed ideal dihedral angle for each chiral center. The sign of the dihedral angle determines the chirality of the chiral center (+1 for clockwise, -1 for counterclockwise). If no stereocenters are found, returns an empty tensor of shape [0, 5]. """ @@ -233,46 +235,43 @@ def get_rf2aa_chiral_features( class AddRF2AAChiralFeatures(Transform): - """ - AddRF2AAChiralFeatures adds chiral features to the atom array data under the `"chiral_feats"` key. - Chiral centers are taken from `data["chiral_centers"]`, which is a list of dictionaries, of the form: - {"chiral_center_atom_id": int, "bonded_explicit_atom_ids": list[int]} - This metadata can be added by running e.g.
the `AddOpenBabelMoleculesForAtomizedMolecules` and - `GetChiralCentersFromOpenBabel` transforms.This transform also requires the `AtomizeByCCDName` transform + """AddRF2AAChiralFeatures adds chiral features to the atom array data under the "chiral_feats" key. + + Chiral centers are taken from data["chiral_centers"], which is a list of dictionaries, of the form: + {"chiral_center_atom_id": int, "bonded_explicit_atom_ids": list[int]} + + This metadata can be added by running e.g. the AddOpenBabelMoleculesForAtomizedMolecules and + GetChiralCentersFromOpenBabel transforms. This transform also requires the AtomizeByCCDName transform to be applied previously to ensure the atom array is properly atomized. Args: - data (dict[str, Any]): A dictionary containing the input data, including the atom array and chiral centers. + data: A dictionary containing the input data, including the atom array and chiral centers. Returns: - dict[str, Any]: The updated `data` dictionary with the added chiral features under the `"chiral_feats"` key. + The updated data dictionary with the added chiral features under the "chiral_feats" key. Example: - data = { - "atom_array": atom_array, - "chiral_centers": [ - { - "chiral_center_atom_id": 5, - "bonded_explicit_atom_ids": [1, 2, 3, 4] - }, - { - "chiral_center_atom_id": 10, - "bonded_explicit_atom_ids": [6, 7, 8, 9] - } - ] - } - - transform = AddRF2AAChiralFeatures() - result = transform.forward(data) - - print(result["chiral_feats"]) - # Output might look like: - # (assuming the atom_id s above also correspond to the indices in the atom array, - # otherwise the first 4 columns look different as they are the indices in the atom array) - # tensor([[ 5., 1., 2., 3., 0.61546...], - # [ 5., 2., 3., 4., -0.61546...], - # ... - # [10., 7., 8., 9., -0.61546...]]) + .. code-block:: python + + data = { + "atom_array": atom_array, + "chiral_centers": [ + {"chiral_center_atom_id": 5, "bonded_explicit_atom_ids": [1, 2, 3, 4]}, + {"chiral_center_atom_id": 10, "bonded_explicit_atom_ids": [6, 7, 8, 9]}, + ], + } + + transform = AddRF2AAChiralFeatures() + result = transform.forward(data) + + print(result["chiral_feats"]) + # Output might look like: + # (assuming the atom_id s above also correspond to the indices in the atom array, + # otherwise the first 4 columns look different as they are the indices in the atom array) + # tensor([[ 5., 1., 2., 3., 0.61546...], + # [ 5., 2., 3., 4., -0.61546...], + # ... + # [10., 7., 8., 9., -0.61546...]]) """ requires_previous_transforms: ClassVar[list[str | Transform]] = ["AtomizeByCCDName"] @@ -401,12 +400,14 @@ def add_af3_chiral_features( class AddAF3ChiralFeatures(Transform): - """Adds chiral features into the `feats` dictionary. + """Adds chiral features into the feats dictionary. Adds the following features to the data dictionary under the 'feats' key: - - chiral_feats: [N_chiral_centers, 5] A listing of chiral centers of the format: - tensor([[ 5., 1., 2., 3., 0.61546...],...]) - Here, the first 4 columns define atom indices of chiral center; the 5th is target dihedral + + chiral_feats + [N_chiral_centers, 5] A listing of chiral centers of the format: + tensor([[ 5., 1., 2., 3., 0.61546...],...]) + Here, the first 4 columns define atom indices of chiral center; the 5th is target dihedral Metadata from GetRDKitChiralCenters, held in the "chiral_centers" key, is needed for this transform. 
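For context, metadata in the documented `chiral_centers` shape might be assembled with RDKit roughly as follows; the diff's pipelines use the `GetRDKitChiralCenters` and OpenBabel transforms for this, so the snippet is only a sketch of the data shape.

```python
from rdkit import Chem

mol = Chem.MolFromSmiles("C[C@@H](C(=O)O)N")  # L-alanine: one stereocenter
print(Chem.FindMolChiralCenters(mol, includeUnassigned=True))  # [(1, 'S')]

chiral_centers = [
    {
        "chiral_center_atom_id": idx,
        "bonded_explicit_atom_ids": [n.GetIdx() for n in mol.GetAtomWithIdx(idx).GetNeighbors()],
    }
    for idx, _label in Chem.FindMolChiralCenters(mol, includeUnassigned=True)
]
print(chiral_centers)  # [{'chiral_center_atom_id': 1, 'bonded_explicit_atom_ids': [0, 2, 5]}]
```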
""" diff --git a/src/atomworks/ml/transforms/covalent_modifications.py b/src/atomworks/ml/transforms/covalent_modifications.py index cfbc568e..99a09187 100644 --- a/src/atomworks/ml/transforms/covalent_modifications.py +++ b/src/atomworks/ml/transforms/covalent_modifications.py @@ -91,13 +91,12 @@ class FlagAndReassignCovalentModifications(Transform): """Handles covalent modifications within the AtomArray. Covalent modifications, e.g., glycosylation, are handled by the following algorithm: - ------------------------------------------------------------------------------------------------ + for polymer residues with atoms covalently bound to a NON-POLYMER: for ALL atoms in the polymer residue: set the pn_unit_iid and pn_unit_id identifying annotations to that of the NON-POLYMER polymer/non-polymer unit set atomize = true (thus, this transform must be run before the Atomize transform) set is_covalent_modification = true (for the entire pn_unit) - ------------------------------------------------------------------------------------------------ TODO: Break into two Transforms - one that flags, one that reassigns. Atomizing covalent modifications is a design choice that may not be desired in all pipelines. Annotating covalent modifications, however, is broadly useful. diff --git a/src/atomworks/ml/transforms/crop.py b/src/atomworks/ml/transforms/crop.py index c691530a..6c3e9521 100644 --- a/src/atomworks/ml/transforms/crop.py +++ b/src/atomworks/ml/transforms/crop.py @@ -99,8 +99,8 @@ def crop_contiguous_af2_multimer(iids: list[int | str], instance_lens: list[int] (iids) to crop masks (i.e. boolean arrays) indicating which tokens to keep. References: - - AF2 Multimer https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf - - AF3 https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf + `AF2 Multimer `_ + `AF3 `_ Example: >>> iids = [1, 2, 3] @@ -185,7 +185,6 @@ def get_spatial_crop_center( Sample a crop center from a spatial region of the atom array. Implements the selection of a crop center as described in AF3. - ``` In this procedure, polymer residues and ligand atoms are selected that are within close spatial distance of an interface atom. The interface atom is selected at random from the set of token centre atoms (defined @@ -195,7 +194,6 @@ def get_spatial_crop_center( provided (subsection 2.5), the reference atom is selected at random from interfacial token centre atoms that exist within this chain or interface. - ``` Args: atom_array (AtomArray): The array containing atom information. @@ -289,8 +287,8 @@ def get_spatial_crop_mask( crop_mask (np.ndarray): A boolean mask of shape (N,) where True indicates that the token is within the crop. References: - - AF2 Multimer https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf - - AF3 https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf + `AF2 Multimer `_ + `AF3 `_ Example: >>> coord = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]) diff --git a/src/atomworks/ml/transforms/encoding.py b/src/atomworks/ml/transforms/encoding.py index 19de7da4..4cdd65c1 100644 --- a/src/atomworks/ml/transforms/encoding.py +++ b/src/atomworks/ml/transforms/encoding.py @@ -197,8 +197,7 @@ def atom_array_from_encoding( **other_annotations: np.ndarray | None, # TODO: Allow passing a res_id ) -> AtomArray: - """ - Create an AtomArray from encoded coordinates, mask, and sequence. 
+ """Create an AtomArray from encoded coordinates, mask, and sequence. This function takes encoded data and reconstructs an AtomArray, which is a structured representation of atomic information. The encoded coordinates, @@ -206,23 +205,24 @@ def atom_array_from_encoding( relevant annotations are included. Args: - - encoded_coord (torch.Tensor | np.ndarray): Encoded coordinates tensor. - - encoded_mask (torch.Tensor | np.ndarray): Encoded mask tensor. - - encoded_seq (torch.Tensor | np.ndarray): Encoded sequence tensor. - - encoding (TokenEncoding): The encoding to use for encoding the atom array. - - chain_id (str | np.ndarray, optional): Chain ID. Can be a single string (e.g., "A") + encoded_coord: Encoded coordinates tensor. + encoded_mask: Encoded mask tensor. + encoded_seq: Encoded sequence tensor. + encoding: The encoding to use for encoding the atom array. + chain_id: Chain ID. Can be a single string (e.g., "A") or a numpy array of shape (n_res,) corresponding to each residue. Defaults to "A". - - token_is_atom (torch.Tensor | np.ndarray | None, optional): Boolean mask indicating + token_is_atom: Boolean mask indicating whether each token corresponds to an atom. - - **other_annotations (np.ndarray | None): Additional annotations to include in the + **other_annotations: Additional annotations to include in the AtomArray. The shape must match one of the following: + - scalar, for global annotations - (n_atom,) for per-atom annotations, - (n_res,) for per-residue annotations, - (n_chain,) for per-chain annotations. Returns: - - atom_array (AtomArray): The created AtomArray containing the encoded atomic information. + The created AtomArray containing the encoded atomic information. """ # Turn tensors into numpy arrays if necessary _from_tensor = lambda x: x.cpu().numpy() if isinstance(x, torch.Tensor) else x # noqa E731 @@ -401,39 +401,53 @@ def forward(self, data: dict[str, Any]) -> dict[str, Any]: class EncodeAF3TokenLevelFeatures(Transform): - """ - A transform that encodes token-level features like AF3. The token-level features are returned as: + """A transform that encodes token-level features like AF3. The token-level features are returned as: - - feats: - # (Standard AF3 token-level features) - - `residue_index`: Residue number in the token's original input chain (pre-crop) - - `token_index`: Token number. Increases monotonically; does not restart at 1 for new + feats: + (Standard AF3 token-level features) + + residue_index + Residue number in the token's original input chain (pre-crop) + token_index + Token number. Increases monotonically; does not restart at 1 for new chains. (Runs from 0 to N_tokens) - - `asym_id`: Unique integer for each distinct chain (pn_unit_iid) - NOTE: We use `pn_unit_iid` rather than `chain_iid` to be more consistent + asym_id + Unique integer for each distinct chain (pn_unit_iid) + NOTE: We use pn_unit_iid rather than chain_iid to be more consistent with handling of multi-residue/multi-chain ligands (especially sugars) - - `entity_id`: Unique integer for each distinct sequence (pn_unit entity) - - `sym_id`: Unique integer within chains of this sequence. E.g. if pn_units A, B and C - share a sequence but D does not, their `sym_id`s would be [0, 1, 2, 0]. - - `restype`: Integer encoding of the sequence. 32 possible values: 20 AA + unknown, + entity_id + Unique integer for each distinct sequence (pn_unit entity) + sym_id + Unique integer within chains of this sequence. E.g. 
if pn_units A, B and C
+            share a sequence but D does not, their sym_ids would be [0, 1, 2, 0].
+        restype
+            Integer encoding of the sequence. 32 possible values: 20 AA + unknown,
             4 RNA nucleotides + unknown, 4 DNA nucleotides + unknown, and gap. Ligands are
-            represented as unknown amino acid (`UNK`)
-        - `is_protein`: whether a token is of protein type
-        - `is_rna`: whether a token is of RNA type
-        - `is_dna`: whether a token is of DNA type
-        - `is_ligand`: whether a token is a ligand residue
-
-        # (Custom token-level features)
-        - `is_atomized`: whether a token is an atomized token
-
-    - feat_metadata:
-        - `asym_name`: The asymmetric unit name for each id in `asym_id`. Acts as a legend.
-        - `entity_name`: The entity name for each id in `entity_id`. Acts as a legend.
-        - `sym_name`: The symmetric unit name for each id in `sym_id`. Acts as a legend.
+            represented as unknown amino acid (UNK)
+        is_protein
+            whether a token is of protein type
+        is_rna
+            whether a token is of RNA type
+        is_dna
+            whether a token is of DNA type
+        is_ligand
+            whether a token is a ligand residue
+
+        (Custom token-level features)
+
+        is_atomized
+            whether a token is an atomized token
+
+    feat_metadata:
+        asym_name
+            The asymmetric unit name for each id in asym_id. Acts as a legend.
+        entity_name
+            The entity name for each id in entity_id. Acts as a legend.
+        sym_name
+            The symmetric unit name for each id in sym_id. Acts as a legend.

     Reference:
-        - Section 2.8 of the AF3 supplementary (Table 5)
-          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
+        `Section 2.8 of the AF3 supplementary (Table 5) <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
     """

     def __init__(self, sequence_encoding: AF3SequenceEncoding):

diff --git a/src/atomworks/ml/transforms/filters.py b/src/atomworks/ml/transforms/filters.py
index 5f431d65..16cd8906 100644
--- a/src/atomworks/ml/transforms/filters.py
+++ b/src/atomworks/ml/transforms/filters.py
@@ -30,6 +30,7 @@ def remove_unresolved_pn_units(atom_array: AtomArray) -> AtomArray:
     """
     Filters PN units that have all unresolved atoms (i.e., atoms with occupancy 0) from the AtomArray.
+
     Can be applied before or after cropping, since cropping may lead to PN units with all unresolved atoms
     that were previously not entirely unresolved. At training time, these unresolved PN units provide minimal
     value and can lead to errors in the model.
     """
@@ -72,6 +73,7 @@ def remove_unresolved_tokens(atom_array: AtomArray) -> AtomArray:


 class RemoveUnresolvedPNUnits(Transform):
     """
     Filters PN units that have all unresolved atoms (i.e., atoms with occupancy 0) from the AtomArray.
+
     Can be applied before or after cropping, since cropping may lead to PN units with all unresolved atoms
     that were previously not entirely unresolved. At training time, these unresolved PN units provide minimal
     value and can lead to errors in the model.
     """
@@ -222,10 +224,10 @@ def check_input(self, data: dict) -> None:

     def forward(self, data: dict) -> dict:
         if ("extra_info" not in data) or (self.pn_unit_iid_key not in data["extra_info"]):
-            # ...short-circuit if the key does not exist in the `extra_info` dictionary
+            # ... short-circuit if the key does not exist in the `extra_info` dictionary
             return data
         else:
-            # ...otherwise, filter the atom array
+            # ... otherwise, filter the atom array
            data["atom_array"] = filter_to_specified_pn_units(
                 data["atom_array"], eval(data["extra_info"][self.pn_unit_iid_key])
             )

diff --git a/src/atomworks/ml/transforms/msa/_msa_constants.py b/src/atomworks/ml/transforms/msa/_msa_constants.py
index c958308f..4b48a16c 100644
--- a/src/atomworks/ml/transforms/msa/_msa_constants.py
+++ b/src/atomworks/ml/transforms/msa/_msa_constants.py
@@ -78,8 +78,8 @@ def create_lookup_table(one_letter_to_int: dict, fallback_letter: str) -> np.nda
 Ordered list of protein amino acid one-letter codes, including gaps, ambiguous, and rare amino acids.

 References:
-    - https://iupac.qmul.ac.uk/AminoAcid/A2021.html#AA21 (for IUPAC amino acid codes)
-    - https://www.cup.uni-muenchen.de/ch/compchem/tink/as.html (for Pyrollisine)
+    `IUPAC Amino Acid Codes <https://iupac.qmul.ac.uk/AminoAcid/A2021.html#AA21>`_
+    `Pyrrolysine <https://www.cup.uni-muenchen.de/ch/compchem/tink/as.html>`_
 """

 RNA_NUCLEOTIDE_ONE_LETTER_TO_INT = {
@@ -109,8 +109,8 @@ def create_lookup_table(one_letter_to_int: dict, fallback_letter: str) -> np.nda
 """
 Ordered list of RNA nucleotide one-letter codes, including gaps, ambiguous, and rare residues.

-References:
-    - https://www.promega.com/resources/guides/nucleic-acid-analysis/restriction-enzyme-resource/restriction-enzyme-resource-tables/iupac-ambiguity-codes-for-nucleotide-degeneracy/
+Reference:
+    `IUPAC Ambiguity Codes for Nucleotide Degeneracy <https://www.promega.com/resources/guides/nucleic-acid-analysis/restriction-enzyme-resource/restriction-enzyme-resource-tables/iupac-ambiguity-codes-for-nucleotide-degeneracy/>`_
 """

 # Create lookup tables from MSA one letter codes to integers, based on the above mappings

diff --git a/src/atomworks/ml/transforms/msa/_msa_featurizing_utils.py b/src/atomworks/ml/transforms/msa/_msa_featurizing_utils.py
index b682ea17..afab7fcd 100644
--- a/src/atomworks/ml/transforms/msa/_msa_featurizing_utils.py
+++ b/src/atomworks/ml/transforms/msa/_msa_featurizing_utils.py
@@ -151,8 +151,8 @@ def mask_msa_like_bert(
         - masked_msa (torch.Tensor): Tensor [n_rows, n_tokens_across_chains] representing the masked MSA, with
             the mask only applied to indices where `index_can_be_masked` is True.
         - mask_position (torch.Tensor): Boolean tensor [n_rows, n_tokens_across_chains] indicating positions where
             a mask was applied (i.e., one of the outcomes of the mask behavior)

-    References:
-        - AF2 Supplement https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf
+    Reference:
+        `AF2 Supplement <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf>`_
     """

     # We start by defining the probabilities for each masking behavior:
@@ -318,8 +318,8 @@ def summarize_clusters(
     Examples:
         See the test cases in `test_featurize_msa`.

-    References:
-        - AlphaFold2 (https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/model/tf/data_transforms.py#L292)
+    Reference:
+        `AlphaFold2 data_transforms.py <https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/model/tf/data_transforms.py#L292>`_
     """
     n_clust = selected_indices.shape[0]
     n_rows, n_seq = encoded_msa.shape

diff --git a/src/atomworks/ml/transforms/msa/_msa_loading_utils.py b/src/atomworks/ml/transforms/msa/_msa_loading_utils.py
index e186de38..abafd84c 100644
--- a/src/atomworks/ml/transforms/msa/_msa_loading_utils.py
+++ b/src/atomworks/ml/transforms/msa/_msa_loading_utils.py
@@ -70,8 +70,8 @@ def parse_fasta(filename: PathLike, maxseq: int = 10000, query_tax_id: str = "qu
         ins (np.ndarray): Array of shape (N, L) where N is the number of sequences and L is the length of sequences.
         tax_ids (np.ndarray): Array of shape (N,) containing the taxonomy IDs for each sequence in the MSA.
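    A note on where these tax IDs typically come from: per the UniProt FASTA header
    documentation referenced below, the taxonomy identifier is carried in the OX= field of the
    description line. A hedged sketch of extracting it (the function name is illustrative, not
    the library's parser):

    .. code-block:: python

        import re

        def tax_id_from_uniprot_header(header: str) -> str | None:
            """Extract the NCBI taxonomy ID from a UniProt FASTA header line, if present."""
            match = re.search(r"\bOX=(\d+)", header)
            return match.group(1) if match else None

        assert tax_id_from_uniprot_header(
            ">sp|P69905|HBA_HUMAN Hemoglobin subunit alpha OS=Homo sapiens OX=9606 GN=HBA1"
        ) == "9606"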
-    References:
-        - UniProt FASTA Header Documentation (https://www.uniprot.org/help/fasta-headers)
+    Reference:
+        `UniProt FASTA Header Documentation <https://www.uniprot.org/help/fasta-headers>`_
     """
     msa = []
     ins = []
@@ -149,8 +149,8 @@ def parse_a3m(
         tax_ids (np.ndarray): Array of shape (N,) containing the taxonomy IDs for each
             sequence in the MSA.

-    References:
-        - A3M Format Documentation (https://yanglab.qd.sdu.edu.cn/trRosetta/msa_format.html#a3m)
+    Reference:
+        `A3M Format Documentation <https://yanglab.qd.sdu.edu.cn/trRosetta/msa_format.html#a3m>`_
     """
     msa = []
     ins = []

diff --git a/src/atomworks/ml/transforms/msa/msa.py b/src/atomworks/ml/transforms/msa/msa.py
index 97f03cbb..78b3158f 100644
--- a/src/atomworks/ml/transforms/msa/msa.py
+++ b/src/atomworks/ml/transforms/msa/msa.py
@@ -992,8 +992,8 @@ class FeaturizeMSALikeAF3(Transform):
         - "profile": Shape [n_tokens_across_chains, n_tokens]. Distribution across restypes in the main MSA.
             Computed before MSA truncation.
         - "insertion_mean": Shape [n_tokens_across_chains]. Mean number of insertions to the left of each
             position in the main MSA. Computed before MSA truncation.

-    References:
-        - AF3 Supplement, Table 5: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
+    Reference:
+        `AF3 Supplement, Table 5 <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
     """

     requires_previous_transforms: ClassVar[list[str | Transform]] = [

diff --git a/src/atomworks/ml/transforms/openbabel_utils.py b/src/atomworks/ml/transforms/openbabel_utils.py
index 78d89e53..9fa54758 100644
--- a/src/atomworks/ml/transforms/openbabel_utils.py
+++ b/src/atomworks/ml/transforms/openbabel_utils.py
@@ -6,8 +6,8 @@
 allows using OpenBabel for identifying e.g. stereochemistry, automorphisms, etc.

 References:
-- OpenBabel documentation: https://open-babel.readthedocs.io/
-- Biotite documentation: https://www.biotite-python.org/
+    `OpenBabel documentation <https://open-babel.readthedocs.io/>`_
+    `Biotite documentation <https://www.biotite-python.org/>`_
 """

 import logging
@@ -74,7 +74,7 @@
 or lone pairs.

 Reference:
-    - https://open-babel.readthedocs.io/en/latest/Stereochemistry/stereo.html#accessing-stereochemistry-information
+    `OpenBabel Stereochemistry Documentation <https://open-babel.readthedocs.io/en/latest/Stereochemistry/stereo.html#accessing-stereochemistry-information>`_
 """


@@ -86,66 +86,69 @@ def atom_array_to_openbabel(
     annotations_to_keep: list[str] = _BIOTITE_DEFAULT_ANNOTATIONS,
     ph_for_inferred_hydrogens: float = 7.4,
 ) -> openbabel.OBMol:
-    """
-    Convert a Biotite AtomArray to an OpenBabel OBMol with the option of keeping custom AtomArray annotations.
+    """Convert a Biotite AtomArray to an OpenBabel OBMol with the option of keeping custom AtomArray annotations.
+
+    For easier interfacing with the OBMol object, you can wrap it into a pybel.Molecule object.

-    For easier interfacing with the `OBMol` object, you can wrap it into a `pybel.Molecule` object.
-    - https://open-babel.readthedocs.io/en/latest/UseTheLibrary/Python_PybelAPI.html
-    - https://github.com/openbabel/documentation/blob/master/pybel.py
+    - https://open-babel.readthedocs.io/en/latest/UseTheLibrary/Python_PybelAPI.html
+    - https://github.com/openbabel/documentation/blob/master/pybel.py

     Args:
-        atom_array (AtomArray): The Biotite AtomArray to convert.
-        set_coords (bool, optional): If True, set the atomic coordinates from the AtomArray in the OBMol. Defaults to True.
-        infer_aromaticity (bool, optional): If True, infer aromaticity in the OBMol or take the aromaticity annotations from the AtomArray. Defaults to False.
-        infer_hydrogens (bool, optional): If True, infer hydrogens in the OBMol or take the hydrogens annotations from the AtomArray. Defaults to False.
- annotations_to_keep (list[str], optional): List of annotation categories to keep from the AtomArray. Defaults to _BIOTITE_DEFAULT_ANNOTATIONS. - ph_for_inferred_hydrogens (float, optional): The pH value to use for inferred hydrogens. Defaults to pH 7.4 which is the openbabel default. + atom_array: The Biotite AtomArray to convert. + set_coords: If True, set the atomic coordinates from the AtomArray in the OBMol. Defaults to True. + infer_aromaticity: If True, infer aromaticity in the OBMol or take the aromaticity annotations from the AtomArray. Defaults to False. + infer_hydrogens: If True, infer hydrogens in the OBMol or take the hydrogens annotations from the AtomArray. Defaults to False. + annotations_to_keep: List of annotation categories to keep from the AtomArray. Defaults to _BIOTITE_DEFAULT_ANNOTATIONS. + ph_for_inferred_hydrogens: The pH value to use for inferred hydrogens. Defaults to pH 7.4 which is the openbabel default. The pH value is exposed here explicitly, but we recommend using the default value and only changing it if you have a good reason, as this will likely make it out of sync with other parts of the codebase which use the default pH value. Returns: - openbabel.OBMol: The converted OpenBabel OBMol. The custom annotations are stored in the `_annotations` attribute. + The converted OpenBabel OBMol. The custom annotations are stored in the _annotations attribute. Example: - >>> from biotite.structure import AtomArray, BondType - >>> import numpy as np - >>> from atomworks.ml.transforms.openbabel_utils import atom_array_to_openbabel - >>> # Create AtomArray - >>> atom_array = AtomArray(5) - >>> atom_array.element = np.array(["C", "C", "O", "N", "H"]) - >>> atom_array.coord = np.array( - ... [ - ... [0.0, 0.0, 0.0], - ... [1.5, 0.0, 0.0], - ... [1.5, 1.5, 0.0], - ... [0.0, 1.5, 0.0], - ... [0.0, 0.0, 1.5], - ... ] - ... ) - >>> # Add bonds - >>> atom_array.bonds = struc.BondList(len(atom_array)) - >>> atom_array.bonds.add_bond(0, 1, BondType.SINGLE) - >>> atom_array.bonds.add_bond(1, 2, BondType.DOUBLE) - >>> atom_array.bonds.add_bond(1, 3, BondType.SINGLE) - >>> atom_array.bonds.add_bond(0, 4, BondType.SINGLE) - >>> # Convert to OpenBabel molecule - >>> obmol = atom_array_to_openbabel(atom_array) - >>> # Print number of atoms - >>> print(f"Number of atoms: {obmol.NumAtoms()}") - Number of atoms: 5 - >>> # Print atom information - >>> print("\nAtom information:") - >>> for atom in openbabel.OBMolAtomIter(obmol): - ... print( - ... f"Atomic number: {atom.GetAtomicNum()}, Coordinates: ({atom.GetX():.1f}, {atom.GetY():.1f}, {atom.GetZ():.1f})" - ... ) - - Atom information: - Atomic number: 6, Coordinates: (0.0, 0.0, 0.0) - Atomic number: 6, Coordinates: (1.5, 0.0, 0.0) - Atomic number: 8, Coordinates: (1.5, 1.5, 0.0) - Atomic number: 7, Coordinates: (0.0, 1.5, 0.0) - Atomic number: 1, Coordinates: (0.0, 0.0, 1.5) + .. 
code-block:: python
+
+            from biotite.structure import AtomArray, BondType
+            import numpy as np
+            from atomworks.ml.transforms.openbabel_utils import atom_array_to_openbabel
+
+            # Create AtomArray
+            atom_array = AtomArray(5)
+            atom_array.element = np.array(["C", "C", "O", "N", "H"])
+            atom_array.coord = np.array(
+                [
+                    [0.0, 0.0, 0.0],
+                    [1.5, 0.0, 0.0],
+                    [1.5, 1.5, 0.0],
+                    [0.0, 1.5, 0.0],
+                    [0.0, 0.0, 1.5],
+                ]
+            )
+            # Add bonds
+            atom_array.bonds = struc.BondList(len(atom_array))
+            atom_array.bonds.add_bond(0, 1, BondType.SINGLE)
+            atom_array.bonds.add_bond(1, 2, BondType.DOUBLE)
+            atom_array.bonds.add_bond(1, 3, BondType.SINGLE)
+            atom_array.bonds.add_bond(0, 4, BondType.SINGLE)
+            # Convert to OpenBabel molecule
+            obmol = atom_array_to_openbabel(atom_array)
+            # Print number of atoms
+            print(f"Number of atoms: {obmol.NumAtoms()}")
+            # Number of atoms: 5
+            # Print atom information
+            print("\nAtom information:")
+            for atom in openbabel.OBMolAtomIter(obmol):
+                print(
+                    f"Atomic number: {atom.GetAtomicNum()}, Coordinates: ({atom.GetX():.1f}, {atom.GetY():.1f}, {atom.GetZ():.1f})"
+                )
+
+            # Atom information:
+            # Atomic number: 6, Coordinates: (0.0, 0.0, 0.0)
+            # Atomic number: 6, Coordinates: (1.5, 0.0, 0.0)
+            # Atomic number: 8, Coordinates: (1.5, 1.5, 0.0)
+            # Atomic number: 7, Coordinates: (0.0, 1.5, 0.0)
+            # Atomic number: 1, Coordinates: (0.0, 0.0, 1.5)
     """
     # Initialize empty OpenBabel molecule
     obmol = openbabel.OBMol()
@@ -410,8 +413,8 @@ def find_automorphisms(obmol: openbabel.OBMol, max_automorphs: int = 1000, max_m
         a single automorphism representing the identity (no swaps).

     References:
-        - https://openbabel.org/api/3.0/group__substructure.shtml#ga16841a730cf92c8e51a804ad8d746307
-        - https://baoilleach.blogspot.com/2010/11/automorphisms-isomorphisms-symmetry.html
+        `OpenBabel Substructure API <https://openbabel.org/api/3.0/group__substructure.shtml#ga16841a730cf92c8e51a804ad8d746307>`_
+        `Automorphisms and Symmetry Blog <https://baoilleach.blogspot.com/2010/11/automorphisms-isomorphisms-symmetry.html>`_

     Example:
         >>> from openbabel import pybel
@@ -546,48 +549,50 @@ def forward(self, data: dict[str, Any]) -> dict[str, Any]:


 class GetChiralCentersFromOpenBabel(Transform):
-    """
-    Identify chiral centers in the OpenBabel molecules stored in the `data["openbabel"]` dictionary.
-    These molecules typically correspond to the atomized molecules in the `data["atom_array"]` (c.f.
-    `AddOpenBabelMoleculesForAtomizedMolecules`).
+    """Identify chiral centers in the OpenBabel molecules stored in the data["openbabel"] dictionary.
+
+    These molecules typically correspond to the atomized molecules in the data["atom_array"] (c.f.
+    AddOpenBabelMoleculesForAtomizedMolecules).

     Chiral centers are mapped to the global atom IDs in the atom array to enable tracking chiral
     centers regardless of cropping or reshuffling operations that may modify the atom_array.

     Args:
-        data (dict[str, Any]): A dictionary containing the input data, including the atom array and
-            OpenBabel molecules under the `data["openbabel"]` key.
+        data: A dictionary containing the input data, including the atom array and
+            OpenBabel molecules under the data["openbabel"] key.

     Returns:
-        dict[str, Any]: The updated `data` dictionary with the identified chiral centers under the
-        `"chiral_centers"` key. The chiral centers are stored as a list of dictionaries, where each
+        The updated data dictionary with the identified chiral centers under the
+        "chiral_centers" key. The chiral centers are stored as a list of dictionaries, where each
         dictionary contains the chiral center global atom ID and the atom IDs of the (3 to 4) atoms
         bonded to it.
Example: - data = { - "atom_array": atom_array, - "openbabel": { - 1: obmol1, - 2: obmol2, + .. code-block:: python + + data = { + "atom_array": atom_array, + "openbabel": { + 1: obmol1, + 2: obmol2, + }, } - } - transform = GetChiralCentersFromOpenBabel() - result = transform.forward(data) - - print(result["chiral_centers"]) - # Output might look like: - # [ - # { - # "chiral_center_atom_id": 5, - # "bonded_explicit_atom_ids": [1, 2, 3, 4] - # }, - # { - # "chiral_center_atom_id": 10, - # "bonded_explicit_atom_ids": [6, 7, 8, 9] - # } - # ] + transform = GetChiralCentersFromOpenBabel() + result = transform.forward(data) + + print(result["chiral_centers"]) + # Output might look like: + # [ + # { + # "chiral_center_atom_id": 5, + # "bonded_explicit_atom_ids": [1, 2, 3, 4] + # }, + # { + # "chiral_center_atom_id": 10, + # "bonded_explicit_atom_ids": [6, 7, 8, 9] + # } + # ] """ requires_previous_transforms: ClassVar[list[str | Transform]] = [ diff --git a/src/atomworks/ml/transforms/rdkit_utils.py b/src/atomworks/ml/transforms/rdkit_utils.py index ab7c7ab0..fe9c55b8 100644 --- a/src/atomworks/ml/transforms/rdkit_utils.py +++ b/src/atomworks/ml/transforms/rdkit_utils.py @@ -60,35 +60,34 @@ def generate_conformers( attempts_with_random_coordinates: int = 10_000, **uff_optimize_kwargs: dict, ) -> Mol: - """ - Generate conformations for the given molecule. + """Generate conformations for the given molecule. Args: - - mol (rdkit.Chem.Mol): The RDKit molecule to generate conformations for. - - seed (int | None): Random seed for reproducibility. If None, a random seed is used. - - n_conformers (int): Number of conformations to generate. - - method (str): The method to use for conformer generation. Default is "ETKDGv3". + mol: The RDKit molecule to generate conformations for. + seed: Random seed for reproducibility. If None, a random seed is used. + n_conformers: Number of conformations to generate. + method: The method to use for conformer generation. Default is "ETKDGv3". Allowed methods are: "ETDG", "ETKDG", "ETKDGv2", "ETKDGv3", "srETKDGv3" See https://rdkit.org/docs/RDKit_Book.html#conformer-generation for details. - - num_threads (int): Number of threads to use for parallel computation. Default is 1. - - hydrogen_policy (Literal["infer", "remove", "keep", "auto"]): Whether to add explicit + num_threads: Number of threads to use for parallel computation. Default is 1. + hydrogen_policy: Whether to add explicit hydrogens to the molecule. If "remove", hydrogens are temporarily added for conformer generation, but removed again before returning the molecule. If "keep" the molecule is used as-is (without adding or removing hydrogens). If "auto", the policy is set to "keep" if the molecule already has explicit hydrogens, otherwise it is set to "remove". If "infer", we follow the same behavior as "remove," but do not remove added hydrogens prior to returning the molecule. - - optimize (bool): Whether to optimize the generated conformers using UFF. + optimize: Whether to optimize the generated conformers using UFF. Default is True. - - **uff_optimize_kwargs (dict): Additional keyword arguments for UFF optimization: - - maxIters (int): Maximum number of iterations (default 200). - - vdwThresh (float): Used to exclude long-range van der Waals interactions + **uff_optimize_kwargs: Additional keyword arguments for UFF optimization: + - maxIters: Maximum number of iterations (default 200). + - vdwThresh: Used to exclude long-range van der Waals interactions (default 10.0). 
-            - ignoreInterfragInteractions (bool): If True, nonbonded terms between
+            - ignoreInterfragInteractions: If True, nonbonded terms between
                fragments will not be added to the forcefield (default True).

     Returns:
-        rdkit.Chem.Mol: The molecule with generated conformations.
+        The molecule with generated conformations.

     Note:
         - Optimizing conformers (optimize_conformers=True) is recommended for obtaining
@@ -114,11 +113,10 @@ def generate_conformers(
         maxIterations or use more advanced sampling techniques.

     References:
-        1. Conformer tutorial: https://rdkit.org/docs/RDKit_Book.html#conformer-generation
-        1. RDKit Cookbook: https://www.rdkit.org/docs/Cookbook.html
-        2. Riniker and Landrum, "Better Informed Distance Geometry: Using What We Know To
-           Improve Conformation Generation", JCIM, 2015.
-
+        `Conformer tutorial <https://rdkit.org/docs/RDKit_Book.html#conformer-generation>`_
+        `RDKit Cookbook <https://www.rdkit.org/docs/Cookbook.html>`_
+        Riniker and Landrum, "Better Informed Distance Geometry: Using What We Know To
+        Improve Conformation Generation", JCIM, 2015.
     """
     # Ensure that all properties are being pickled (needed when we use timeout)
     assert (
@@ -326,8 +324,8 @@ def find_automorphisms_with_rdkit(
     If the search fails (e.g. due to running out of memory), returns an array with
     a single automorphism representing the identity (no swaps).

-    References:
-        - https://sourceforge.net/p/rdkit/mailman/message/27897393/
+    Reference:
+        `RDKit Mailman Discussion <https://sourceforge.net/p/rdkit/mailman/message/27897393/>`_

     Example:
         >>> from openbabel import pybel
@@ -397,21 +395,21 @@ def sample_rdkit_conformer_for_atom_array(
     """Sample a conformer for a Biotite AtomArray using RDKit.

     Args:
-        - atom_array: The Biotite AtomArray to sample a conformer for.
-        - n_conformers: The number of conformers to sample.
-        - timeout: The timeout for conformer generation. If None,
+        atom_array: The Biotite AtomArray to sample a conformer for.
+        n_conformers: The number of conformers to sample.
+        timeout: The timeout for conformer generation. If None,
            no timeout is applied. If a tuple, the first element is the offset
            and the second element is the slope.
-        - seed: The seed for conformer generation. If None, a random seed
+        seed: The seed for conformer generation. If None, a random seed
            is generated using the global numpy RNG.
-        - timeout_strategy: The strategy to use for the timeout.
+        timeout_strategy: The strategy to use for the timeout.
            Defaults to "subprocess".
-        - **generate_conformers_kwargs: Additional keyword arguments to pass to the
+        **generate_conformers_kwargs: Additional keyword arguments to pass to the
            generate_conformers function.

     Returns:
-        - AtomArray: The AtomArray with updated coordinates from the sampled conformer.
-        - Chem.Mol: The RDKit molecule with the generated conformer.
+        The AtomArray with updated coordinates from the sampled conformer.
+        The RDKit molecule with the generated conformer.

     Note:
         This function preserves the original atom order and properties of the input AtomArray.
@@ -464,13 +462,13 @@ def ccd_code_to_rdkit_with_conformers(
     skip_rdkit_conformer_generation: bool = False,
     **generate_conformers_kwargs,
 ) -> Chem.Mol:
-    """
-    Generate an RDKit molecule with conformers for a given residue name.
+    """Generate an RDKit molecule with conformers for a given residue name.

     This function attempts to generate the specified number of conformers for the given CCD code
     using RDKit's conformer generation (based on ETKDGv3 per default). If conformer generation
     fails or times out, it falls back to using the idealized conformer from the CCD entry if one
     is available.
+
     Args:
         ccd_code: The CCD code to generate conformers for. E.g.
'ALA' or 'GLY', '9RH' etc. n_conformers: The number of conformers to generate for the given CCD code. @@ -485,7 +483,7 @@ def ccd_code_to_rdkit_with_conformers( generate_conformers function. Returns: - Chem.Mol: An RDKit molecule with the specified number of conformers. + An RDKit molecule with the specified number of conformers. """ # ... get molecule from CCD with its idealized conformer (default conformer 0) mol = ccd_code_to_rdkit(ccd_code, hydrogen_policy="remove") @@ -699,28 +697,32 @@ def get_rdkit_chiral_centers(rdkit_mols: dict[str, Mol]) -> dict: class GetRDKitChiralCenters(Transform): - """ - Identify chiral centers in the RDKit molecules stored in the `data["rdkit"]` dictionary. + """Identify chiral centers in the RDKit molecules stored in the data["rdkit"] dictionary. + Returns a dictionary mapping each residue name to a list of chiral centers, e.g: - data["chiral_centers"] = { - ... - "ILE": [ - {'chiral_center_idx': 1, 'bonded_explicit_atom_idxs': [0, 2, 4], 'chirality': 'S'}, - {'chiral_center_idx': 4, 'bonded_explicit_atom_idxs': [1, 5, 6], 'chirality': 'S'} - ], - ... - } + + .. code-block:: python + + data["chiral_centers"] = { + ... + "ILE": [ + {'chiral_center_idx': 1, 'bonded_explicit_atom_idxs': [0, 2, 4], 'chirality': 'S'}, + {'chiral_center_idx': 4, 'bonded_explicit_atom_idxs': [1, 5, 6], 'chirality': 'S'} + ], + ... + } + Each chiral center is a dict with a center atom index, 3 or 4 bonded atom indices, and the RDKit-determined chirality. Uses RDKit molecules first computed in GetAF3ReferenceMoleculeFeatures. Args: - data (dict[str, Any]): A dictionary containing the input data, including RDKit molecules - under the `"rdkit"` key. + data: A dictionary containing the input data, including RDKit molecules + under the "rdkit" key. Returns: - dict[str, Any]: The updated `data` dictionary with `chiral_centers` containing chiral + The updated data dictionary with chiral_centers containing chiral centers for each molecule. """ diff --git a/src/atomworks/ml/transforms/sasa.py b/src/atomworks/ml/transforms/sasa.py index 4b6226eb..bb9b63ce 100644 --- a/src/atomworks/ml/transforms/sasa.py +++ b/src/atomworks/ml/transforms/sasa.py @@ -133,16 +133,14 @@ def check_input(self, data: dict[str, Any]) -> None: check_atom_array_annotation(data, ["res_name"]) def forward(self, data: dict, key_to_add_sasa_to: str = "atom_array") -> dict: - """ - Calculates SASA and adds it to the data dictionary under the key `atom_array`. + """Calculates SASA and adds it to the data dictionary under the key "atom_array". + Args: - data: dict - A dictionary containing the input data atomarray. - key_to_add_sasa_to: str - The key in the data dictionary to add the SASA values to. + data: A dictionary containing the input data atomarray. + key_to_add_sasa_to: The key in the data dictionary to add the SASA values to. Returns: - dict: The data dictionary with SASA values added. + The data dictionary with SASA values added. """ atom_array: AtomArray = data[key_to_add_sasa_to] sasa = calculate_atomwise_sasa( diff --git a/src/atomworks/ml/transforms/symmetry.py b/src/atomworks/ml/transforms/symmetry.py index 8b245483..d9dd1477 100644 --- a/src/atomworks/ml/transforms/symmetry.py +++ b/src/atomworks/ml/transforms/symmetry.py @@ -34,44 +34,45 @@ def find_automorphisms(atom_array: AtomArray) -> np.ndarray: def apply_automorphs(data: torch.Tensor, automorphs: np.ndarray | torch.Tensor) -> torch.Tensor: - """ - Create data permutations of the input data for each of the automorphs. 
+ """Create data permutations of the input data for each of the automorphs. - This function generates permutations of the input tensor `data` based on the provided automorphisms. + This function generates permutations of the input tensor data based on the provided automorphisms. Each permutation corresponds to a different automorphism, effectively reordering the data according to the automorphisms. Args: - - data (torch.Tensor): The input tensor to be permuted. The first dimension has to correspond to + data: The input tensor to be permuted. The first dimension has to correspond to the number of atoms. - - automorphs (np.ndarray | torch.Tensor): A tensor or numpy array of shape [n_automorphs, n_atoms, 2] + automorphs: A tensor or numpy array of shape [n_automorphs, n_atoms, 2] representing the automorphisms. Each automorphism is a list of paired atom indices - (from_idx, to_idx). The `from_idx` column is essentially just a repetition of np.arange(n_atoms). + (from_idx, to_idx). The from_idx column is essentially just a repetition of np.arange(n_atoms). Returns: - - data_automorphs (torch.Tensor): A tensor of shape [n_automorphs, *data.shape] containing the permuted + A tensor of shape [n_automorphs, ``*data.shape``] containing the permuted data for each automorphism. Example: - >>> data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - >>> # Example automorphisms (2 automorphisms for 3 atoms) - >>> automorphs = np.array([ + .. code-block:: python + + data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + # Example automorphisms (2 automorphisms for 3 atoms) + automorphs = np.array([ [[0, 0], [1, 1], [2, 2]], [[0, 2], [1, 0], [2, 1]] - ... ]) - >>> permuted_data = create_automorph_permutations(data, automorphs) - >>> print(permuted_data) - tensor([[[1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0]], - - [[5.0, 6.0], - [1.0, 2.0], - [3.0, 4.0]]]) + ... ]) + permuted_data = create_automorph_permutations(data, automorphs) + print(permuted_data) + # tensor([[[1.0, 2.0], + # [3.0, 4.0], + # [5.0, 6.0]], + # + # [[5.0, 6.0], + # [1.0, 2.0], + # [3.0, 4.0]]]) """ automorphs = torch.as_tensor(automorphs) n_automorphs, n_atoms, _ = automorphs.shape @@ -518,33 +519,33 @@ def handle_polymer_isomorphisms( post_poly_array: AtomArray, crop_tmask: np.ndarray, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Handles polymer symmetries by computing all swaps between isomorphic (i.e., equivalent) polymers + """Handles polymer symmetries by computing all swaps between isomorphic (i.e., equivalent) polymers that are at least partially in the crop. NOTE: This function only swaps full chains. Swaps within atoms of a polymer (e.g., residue naming ambiguities) are not considered and are handled elsewhere. The process involves the following steps: + 1. Subset the crop mask and pre-cropped atom array to only include polymers that are at least partially in the crop. 2. Among these, identify polymers that are equal to each other (i.e., symmetry groups). 3. Generate all possible combinations of in-group permutations (isomorphisms). - 3. Apply these isomorphisms to the coordinates and masks of the pre-cropped, encoded polymers. - 4. Crop to the relevant bits that appear in the crop. - 5. De-duplicate the isomorphisms to remove any redundancies. + 4. Apply these isomorphisms to the coordinates and masks of the pre-cropped, encoded polymers. + 5. Crop to the relevant bits that appear in the crop. + 6. De-duplicate the isomorphisms to remove any redundancies. 
Args: - - pre_poly_array (AtomArray): The atom array representing the state before cropping, + pre_poly_array: The atom array representing the state before cropping, containing polymer tokens. - - post_poly_array (AtomArray): The atom array representing the state after cropping, + post_poly_array: The atom array representing the state after cropping, containing polymer tokens. - - crop_tmask (np.ndarray): A boolean mask indicating which tokens are included in the crop. + crop_tmask: A boolean mask indicating which tokens are included in the crop. Returns: - - poly_xyz (torch.Tensor): The xyz coordinates of the polymers after applying the isomorphisms. + poly_xyz: The xyz coordinates of the polymers after applying the isomorphisms. It has shape [n_permutations, n_crop_tokens, n_atoms_per_token, 3]. - - poly_mask (torch.Tensor): The mask of the polymers after applying the isomorphisms. + poly_mask: The mask of the polymers after applying the isomorphisms. It has shape [n_permutations, n_crop_tokens, n_atoms_per_token]. """ # NOTATION: a = atom-level, t = token-level, tidx = token-level index, tmask = token-level mask @@ -626,21 +627,21 @@ def handle_nonpoly_automorphisms( crop_tmask: np.ndarray, openbabel_data: dict[int, Any], ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Handles non-polymer symmetries by computing automorphs within each non-polymer that + """Handles non-polymer symmetries by computing automorphs within each non-polymer that is at least partially in the crop. This function calculates the swapped coordinate and mask values for each molecule and - concatenates all automorphs together, padding the `n_permutations` dimension to the + concatenates all automorphs together, padding the n_permutations dimension to the maximum number of automorphs for any molecule. WARNING: Unlike polymer symmetries, inter-molecule symmetries are not considered here, as they are managed by the RF2AA loss function through a greedy search. For non-polymers, the following steps are performed: + 1. Subset the pre-cropped non-poly array to only include the non-poly molecules that are at least partially in the crop. - 2. Compute the automorphs for each identified full molecule (i.e. *BEFORE* cropping). + 2. Compute the automorphs for each identified full molecule (i.e. BEFORE cropping). 3. Apply the computed automorphs to the coordinates and masks of the encoded full molecules. 4. Crop to the relevant sections of the molecules that appear in the crop. 5. Concatenate all automorphs together, padding to the maximum number of automorphs @@ -649,15 +650,15 @@ def handle_nonpoly_automorphisms( two automorphs, but the atom swaps are in a part that is not within the crop) Args: - - pre_nonpoly_array (AtomArray): The pre-cropped non-polymer array to process. - - post_nonpoly_array (AtomArray): The post-cropped non-polymer array to process. - - crop_tmask (np.ndarray): A boolean mask indicating which tokens are in the crop. - - openbabel_data (dict[int, Any]): A dictionary containing Open Babel data for molecules. + pre_nonpoly_array: The pre-cropped non-polymer array to process. + post_nonpoly_array: The post-cropped non-polymer array to process. + crop_tmask: A boolean mask indicating which tokens are in the crop. + openbabel_data: A dictionary containing Open Babel data for molecules. Returns: - - nonpoly_xyzs (torch.Tensor): A tensor containing the coordinates of the non-polymer automorphs. - - nonpoly_masks (torch.Tensor): A tensor containing the masks of the non-polymer automorphs. 
- - symmetry_info (dict[tuple[int, str], int]): A dictionary containing the symmetry information. + nonpoly_xyzs: A tensor containing the coordinates of the non-polymer automorphs. + nonpoly_masks: A tensor containing the masks of the non-polymer automorphs. + symmetry_info: A dictionary containing the symmetry information. """ n_nonpoly_token_in_crop = get_token_count(post_nonpoly_array) diff --git a/src/atomworks/ml/transforms/template.py b/src/atomworks/ml/transforms/template.py index 73970be5..d1b362e0 100644 --- a/src/atomworks/ml/transforms/template.py +++ b/src/atomworks/ml/transforms/template.py @@ -41,27 +41,26 @@ @dataclass class RF2AATemplate: - """ - Data class for holding template information in the RF, RF2 & RF2AA format. + """Data class for holding template information in the RF, RF2 & RF2AA format. NOTE: - - RF templates only exist for proteins - - This is a helper class to cast the templates into a more `readable` format and also - to provide an interface layer that allows us to deal with templates as atom_arrays, if - we ever re-create templates or add templates for non-proteins - - RF-style templates already come encoded in atom14 representation (RFAtom14, not AF2Atom14) + - RF templates only exist for proteins + - This is a helper class to cast the templates into a more readable format and also + to provide an interface layer that allows us to deal with templates as atom_arrays, if + we ever re-create templates or add templates for non-proteins + - RF-style templates already come encoded in atom14 representation (RFAtom14, not AF2Atom14) Keys: - - xyz: Tensor([1, n_templates x n_atoms_per_template, 14, 3]), raw coordinates of all templates - - mask: Tensor([1, n_templates x n_atom_per_template, 14]), mask of all templates - - qmap: Tensor([1, n_templates x n_atom_per_template, 2]), alignment mapping of all templates - - index 0: which index in the query protein this template index matches to - - index 1: which template index this matches to - - f0d: Tensor([1, n_templates, 8?]), [0,:,4] holds sequence identity info - - f1d: Tensor([1, n_templates x n_atoms_per_template, 3]), something in there may be related to template confidence, gaps? - - seq: Tensor([1, 100677]) (tensor, encoded with Chemdata.aa2num encoding) - - ids: list[tuple[str]] # Holds the f"{pdb_id}_{chain_id}" of the template - - label: list[str] # holds the lookup_id for this template + - xyz: Tensor([1, n_templates x n_atoms_per_template, 14, 3]), raw coordinates of all templates + - mask: Tensor([1, n_templates x n_atom_per_template, 14]), mask of all templates + - qmap: Tensor([1, n_templates x n_atom_per_template, 2]), alignment mapping of all templates + - index 0: which index in the query protein this template index matches to + - index 1: which template index this matches to + - f0d: Tensor([1, n_templates, 8?]), [0,:,4] holds sequence identity info + - f1d: Tensor([1, n_templates x n_atoms_per_template, 3]), something in there may be related to template confidence, gaps? + - seq: Tensor([1, 100677]) (tensor, encoded with Chemdata.aa2num encoding) + - ids: list[tuple[str]] # Holds the f"{pdb_id}_{chain_id}" of the template + - label: list[str] # holds the lookup_id for this template """ xyz: torch.Tensor # [1, n_templates x n_atoms_per_template, 14, 3] @@ -735,10 +734,8 @@ def featurize_templates_like_af3( dict: A dictionary containing the template features. 
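    Referring back to the packed RF2AATemplate layout above, one plausible way to unstack it into
    per-template, query-aligned coordinates using qmap (all shapes, sizes, and names here are
    illustrative, not the library's code):

    .. code-block:: python

        import torch

        n_query_res, n_templates, n_rows = 128, 4, 500   # hypothetical sizes
        xyz = torch.randn(1, n_rows, 14, 3)              # packed template coordinates
        q_idx = torch.randint(0, n_query_res, (n_rows,))  # qmap[..., 0]: query position
        t_idx = torch.randint(0, n_templates, (n_rows,))  # qmap[..., 1]: template index

        # NaN-initialized [n_templates, n_query_res, 14, 3] buffer; scatter rows into place
        # (later rows win on duplicate (template, position) pairs)
        aligned = torch.full((n_templates, n_query_res, 14, 3), float("nan"))
        aligned[t_idx, q_idx] = xyz[0]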
References:
-        - Section 2.8 of the AF3 supplementary information
-          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
-        - AF2 supplementary information
-          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf
+        `Section 2.8 of the AF3 supplementary information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
+        `AF2 supplementary information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf>`_

     NOTE: For templates a token is always a residue since we never align ligands, non-canonicals, PTMs, etc.
     """
@@ -907,10 +904,8 @@ class FeaturizeTemplatesLikeAF3(Transform):
         of the CA atom of all residues within the local frame of each residue.

     References:
-        - Section 2.8 of the AF3 supplementary information
-          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
-        - AF2 supplementary information
-          https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf
+        `Section 2.8 of the AF3 supplementary information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
+        `AF2 supplementary information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf>`_
     """

     requires_previous_transforms: ClassVar[list[str | Transform]] = [

diff --git a/src/atomworks/ml/utils/debug.py b/src/atomworks/ml/utils/debug.py
index 73ff5672..28225f74 100644
--- a/src/atomworks/ml/utils/debug.py
+++ b/src/atomworks/ml/utils/debug.py
@@ -1,3 +1,8 @@
+"""Debug utilities for ML components.
+
+Provides functions for saving failed examples and debugging ML pipelines.
+"""
+
 import logging
 import os
 import pickle
@@ -16,6 +21,14 @@


 def _remove_special_characters(s: str) -> str:
+    """Remove special characters from a string.
+
+    Args:
+        s: The string to clean.
+
+    Returns:
+        The cleaned string with only alphanumeric characters and underscores.
+    """
     assert isinstance(s, str)
     # Remove unwanted characters using regex
     clean_s = re.sub(r"[^a-zA-Z0-9_]", "", s)
@@ -30,17 +43,14 @@ def save_failed_example_to_disk(
     rng_state_dict: dict = {},
     error_msg: str = "",
 ) -> None:
-    """
-    Attempts to save a failed example to disk as a pickle file.
+    """Attempts to save a failed example to disk as a pickle file.

     Args:
-        - example_id (str): The ID of the example.
-        - fail_dir (str): The directory where the failed example should be saved. Defaults to a specific path.
-        - rng_state_dict (dict): The random number generator state dictionary.
-        - error_msg (str): The error message associated with the failure.
-
-    Returns:
-        None
+        example_id: The ID of the example.
+        fail_dir: The directory where the failed example should be saved.
+        data: Optional data dictionary to save.
+        rng_state_dict: The random number generator state dictionary.
+        error_msg: The error message associated with the failure.
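    An illustrative way this hook might be called from a data pipeline; the pipeline stand-in,
    example ID, and paths below are hypothetical, and only the call to
    save_failed_example_to_disk reflects the signature shown above:

    .. code-block:: python

        from atomworks.ml.utils.debug import save_failed_example_to_disk

        def pipeline(example):  # hypothetical stand-in for a real featurization pipeline
            raise ValueError("boom")

        try:
            pipeline({"example_id": "7abc_A"})
        except Exception as exc:
            save_failed_example_to_disk(
                example_id="7abc_A",
                fail_dir="/tmp/atomworks_failed_examples",
                rng_state_dict={},  # e.g. captured numpy/torch RNG states
                error_msg=str(exc),
            )
            raise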
""" try: # Get wandb run ID if currently in a wandb run diff --git a/src/atomworks/ml/utils/geometry.py b/src/atomworks/ml/utils/geometry.py index 12a7a13d..b52159c7 100644 --- a/src/atomworks/ml/utils/geometry.py +++ b/src/atomworks/ml/utils/geometry.py @@ -38,8 +38,7 @@ def rigid_from_3_points( t: torch.Tensor of shape [..., 3], translation vector Reference: - - AF2 supplementary, Algorithm 21 - https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf + `AF2 supplementary, Algorithm 21 `_ Example: >>> x1 = torch.tensor([0.0, 0.0, 1.0]) @@ -239,18 +238,18 @@ def get_random_rots(batch_size: int, **tensor_kwargs) -> torch.Tensor: def get_random_rigid(batch_size: int, scale: float = 1.0, **tensor_kwargs) -> tuple[torch.Tensor, torch.Tensor]: - """ - Generate random rigid body transformations (R, t). + """Generate random rigid body transformations (R, t). Args: - - batch_size (int): Number of rigid transformations to generate. - - scale (float, optional): Scale factor for the translation vectors. Defaults to 1.0. - - **tensor_kwargs: Additional keyword arguments to pass to tensor creation functions. + batch_size: Number of rigid transformations to generate. + scale: Scale factor for the translation vectors. Defaults to 1.0. + **tensor_kwargs: Additional keyword arguments to pass to tensor creation functions. Returns: - - tuple[torch.Tensor, torch.Tensor]: A `rigid`tuple containing: - - rots (torch.Tensor): Batch of random rotation matrices with shape (batch_size, 3, 3). - - trans (torch.Tensor): Batch of random translation vectors with shape (batch_size, 3). + A rigid tuple containing: + + - rots: Batch of random rotation matrices with shape (batch_size, 3, 3). + - trans: Batch of random translation vectors with shape (batch_size, 3). Note: If batch_size is 1, the output tensors are squeezed to remove the batch dimension. diff --git a/src/atomworks/ml/utils/io.py b/src/atomworks/ml/utils/io.py index eb38cc0a..a7e91e49 100644 --- a/src/atomworks/ml/utils/io.py +++ b/src/atomworks/ml/utils/io.py @@ -1,39 +1,40 @@ +"""I/O utilities for ML components. + +Provides functions for file operations, directory scanning, and data loading. +""" + import gzip import hashlib +import os import pickle -import warnings +import re from collections.abc import Callable from functools import wraps from os import PathLike from pathlib import Path from typing import Any, TextIO -import biotite.structure as struc -import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq -from atomworks.constants import ( - AA_LIKE_CHEM_TYPES, - ATOMIC_NUMBER_TO_ELEMENT, - DNA_LIKE_CHEM_TYPES, - HYDROGEN_LIKE_SYMBOLS, - POLYPEPTIDE_D_CHEM_TYPES, - POLYPEPTIDE_L_CHEM_TYPES, - RNA_LIKE_CHEM_TYPES, - UNKNOWN_LIGAND, -) -from atomworks.io.utils.ccd import get_chem_comp_type from atomworks.ml.utils.misc import ( - convert_pn_unit_iids_to_pn_unit_ids, - extract_transformation_id_from_pn_unit_iid, logger, ) def open_file(filename: PathLike) -> TextIO: - """Open a file, handling gzipped files if necessary.""" + """Open a file, handling gzipped files if necessary. + + Args: + filename: The path to the file to open. + + Returns: + A file-like object for reading. + + Raises: + AssertionError: If the file does not exist. 
+ """ filename = Path(filename) # ...assert that the file exists assert filename.exists(), f"File {filename} does not exist" @@ -43,32 +44,60 @@ def open_file(filename: PathLike) -> TextIO: return filename.open("r") -def cache_based_on_subset_of_args(cache_keys: list[str], maxsize: int | None = None) -> Callable: +def scan_directory(dir_path: PathLike, max_depth: int) -> list[str]: + """Fast, order-independent directory scan for files up to max_depth levels deep. + + Args: + dir_path: The root directory to scan. + max_depth: The maximum depth to scan. A max_depth of 1 means only the top-level directory. + + Returns: + A list of file paths found within the specified directory and depth. """ - Decorator to cache function results based on a subset of its keyword arguments. - Most helpful when some arguments may be unhashable types (e.g., dictionaries, AtomArray). + file_paths = [] + + for root, dirs, files in os.walk(dir_path): + current_depth = len(Path(root).relative_to(dir_path).parts) + + if current_depth >= max_depth: + dirs.clear() + continue + + for file in files: + file_path = os.path.join(root, file) + file_paths.append(file_path) + + return file_paths + +def cache_based_on_subset_of_args(cache_keys: list[str], maxsize: int | None = None) -> Callable: + """Decorator to cache function results based on a subset of its keyword arguments. + + Most helpful when some arguments may be unhashable types (e.g., dictionaries, AtomArray). If the value of any of the cache keys is None, the function is executed and the result is not cached. Note: - The wrapped function must use keyword arguments for those specified in `cache_keys`. + The wrapped function must use keyword arguments for those specified in cache_keys. Positional arguments are not supported for cache key extraction. Args: - cache_keys (List[str]): The names of the keyword arguments to use as the cache key. - maxsize (Optional[int]): The maximum number of entries to store in the cache. + cache_keys: The names of the keyword arguments to use as the cache key. + maxsize: The maximum number of entries to store in the cache. If None, the cache size is unlimited. Returns: - Callable: A decorator that caches the function results based on the specified keyword arguments. + A decorator that caches the function results based on the specified keyword arguments. Example: - @cache_based_on_subset_of_args(['arg1'], maxsize=2) - def function(*, arg1, arg2): - return arg1 + arg2 + .. code-block:: python - result1 = function(arg1=1, arg2=2) # Caches with key 1 - result2 = function(arg1=1, arg2=3) # Retrieves from cache + @cache_based_on_subset_of_args(["arg1"], maxsize=2) + def function(*, arg1, arg2): + return arg1 + arg2 + + + result1 = function(arg1=1, arg2=2) # Caches with key 1 + result2 = function(arg1=1, arg2=3) # Retrieves from cache """ def decorator(func: Callable) -> Callable: @@ -212,213 +241,6 @@ def get_sharded_file_path( return (nested_path / file_hash).with_suffix(extension) -def convert_af3_model_output_to_atom_array_stack( - atom_to_token_map: np.ndarray[int], - pn_unit_iids: np.ndarray[str], - decoded_restypes: np.ndarray[str], - xyz: np.ndarray, - elements: np.ndarray[int | str], - token_is_atomized: np.ndarray[bool] = None, -) -> struc.AtomArrayStack: - """ - Create an AtomArrayStack from AlphaFold-3-type model outputs. - Specific to AF-3; may not work with other formats. 
- - Parameters: - - atom_to_token_map (np.ndarray): Mapping from atoms to tokens [n_atom] - - pn_unit_iids (np.ndarray): PN unit IID's for each token [n_token] - - decoded_restype (np.ndarray): Decoded residue types for each token [n_token] - - xyz (np.ndarray): Coordinates of atoms [n_atom, 3] or [batch, n_atom, 3], where batch is the number of structures - - elements (np.ndarray): Element types for each atom [n_atom] - - token_is_atomized (np.ndarray, optional): Flags indicating if tokens are atomized [n_token]. If not provided - or None, residues with a single atom are considered atomized. - - Returns: - - AtomArrayStack: Constructed AtomArrayStack. - """ - # Issue a deprecation warning - warnings.warn( - "`convert_af3_model_output_to_atom_array_stack` is deprecated in favor of overwriting the AtomArray coordinates directly and will be removed in future versions.", - DeprecationWarning, - stacklevel=2, - ) - - atom_array = None - chain_iid_residue_counts = {} - - # If dimensions are [n_atom, 3], add a batch dimension - if len(xyz.shape) == 2: - xyz = np.expand_dims(xyz, axis=0) - - # If elements are integers, convert them to strings (since that's what we get from the CCD, and it better matches what CIF files expect) - if np.issubdtype(type(elements[0]), np.integer): - elements = np.array([ATOMIC_NUMBER_TO_ELEMENT[element] for element in elements]) - - ####################################################################################################### - # Iterate over the residues, and create the appropriate atoms for each residue with empty coordinates - # We add the atom type, residue ID, chain ID, and transformation ID to the AtomArray - ####################################################################################################### - - for global_res_idx, res_name in enumerate(decoded_restypes): - # Get atoms corresponding to the residue - atom_indices_in_token = np.where(atom_to_token_map == global_res_idx)[0] - - # ...check if we're dealing with an atomized token - if token_is_atomized is not None: - # If we have the token_is_atomized array, we can use it to determine if the residue is atomized - is_atom = token_is_atomized[global_res_idx] - else: - # Otherwise, we assume that a residue with a single atom is atomized - is_atom = len(atom_indices_in_token) == 1 - - # ...compute the residue ID - pn_unit_iid = pn_unit_iids[global_res_idx] - if pn_unit_iid not in chain_iid_residue_counts: - chain_iid_residue_counts[pn_unit_iid] = 1 - elif not is_atom: - # Only increment the residue count if we're not dealing with an atomized token (we put all atomized tokens in the same residue, like the PDB) - chain_iid_residue_counts[pn_unit_iid] += 1 - res_id = chain_iid_residue_counts[pn_unit_iid] - - if is_atom: - # UNL is "Unknown Ligand" in the CCD - element = elements[atom_indices_in_token].item() - - # ruff: noqa: B023 - def atom_name_exists(atom_name: str) -> bool: - return ( - atom_array[ - (atom_array.pn_unit_iid == pn_unit_iid) - & (atom_array.res_id == res_id) - & (atom_array.atom_name == atom_name) - ].array_length() - > 0 - ) - - # Create the atom name and ensure it's unique within the residue (so that we can give all the atoms the same ID) - atom_name = element - if atom_name_exists(atom_name): - atom_name = next( - f"{element}{atom_count}" - for atom_count in range(2, len(atom_array) + 1) - if not atom_name_exists(f"{element}{atom_count}") - ) - - atom = struc.Atom(np.full((3,), np.nan), res_name=UNKNOWN_LIGAND, element=element, atom_name=atom_name) - residue_atom_array = 
struc.array([atom]) - else: - chem_type = get_chem_comp_type(res_name) - - # Get the atom array of the residue from the CCD - residue_atom_array = struc.info.residue(res_name) - - # Set the elements to uppercase for consistency - residue_atom_array.element = np.array([x.upper() for x in residue_atom_array.element]) - - # If needed, remove type-specific atoms (e.g., OXT in polypeptides, O3' in RNA or DNA) for residues participating in inter-residue bonds - # If we are at a terminal residue, we don't want to remove these leaving groups - residue_atom_array = filter_residue_atoms( - residue_atom_array=residue_atom_array, chem_type=chem_type, elements=elements[atom_indices_in_token] - ) - - # Empty coordinates to avoid unexpected behavior - residue_atom_array.coord = np.full((residue_atom_array.array_length(), 3), np.nan) - - # Wipe the bond information (we are better off letting PyMOL infer the bonds) - residue_atom_array.bonds = None - - # Get the chain_iid, chain_id, and transformation_id - pn_unit_id = convert_pn_unit_iids_to_pn_unit_ids([pn_unit_iid])[0] - transformation_id = extract_transformation_id_from_pn_unit_iid(pn_unit_iid) - - # Set the annotations (for our purposes, chains and pn_units are the same) - residue_atom_array.set_annotation("chain_id", np.full(residue_atom_array.array_length(), pn_unit_id)) - residue_atom_array.set_annotation("pn_unit_id", np.full(residue_atom_array.array_length(), pn_unit_id)) - residue_atom_array.set_annotation("chain_iid", np.full(residue_atom_array.array_length(), pn_unit_iid)) - residue_atom_array.set_annotation("pn_unit_iid", np.full(residue_atom_array.array_length(), pn_unit_iid)) - residue_atom_array.set_annotation( - "transformation_id", np.full(residue_atom_array.array_length(), transformation_id) - ) - - # Everything is full occupancy - residue_atom_array.set_annotation("occupancy", np.full(residue_atom_array.array_length(), 1.0)) - - # Set the residue ID - residue_atom_array.set_annotation("res_id", np.full(residue_atom_array.array_length(), res_id)) - - if atom_array is None: - atom_array = residue_atom_array - else: - atom_array += residue_atom_array - - ####################################################################################################### - # Iterate over the batches of coordinates, and create a new AtomArray for each batch - ####################################################################################################### - atom_arrays = [] - for coords in xyz: - # ...create a new AtomArray for each batch, with new coordinates - batch_atom_array = atom_array.copy() - batch_atom_array.coord = coords - atom_arrays.append(batch_atom_array) - - # Convert to a stack - atom_array_stack = struc.stack(atom_arrays) - - return atom_array_stack - - -def filter_residue_atoms( - residue_atom_array: struc.AtomArray, chem_type: str, elements: np.ndarray[str] -) -> struc.AtomArray: - """ - Filter out unwanted atoms from a residue (e.g.., hydrogens, leaving groups) - - Parameters: - - residue_atom_array (struc.AtomArray): The AtomArray to filter. - - chem_type (str): Type of the chemical chain. - - elements (np.array): Element types (as strings, e.g., "C") for each atom in the residue. - - Returns: - - struc.AtomArray: Filtered AtomArray. 
- """ - # ...capitalize the chemical type - chem_type = chem_type.upper() - - # ...remove hydrogens and deuteriums - residue_atom_array = residue_atom_array[~np.isin(residue_atom_array.element, HYDROGEN_LIKE_SYMBOLS)] - - # If the arrays match, we return the residue as-is - if len(residue_atom_array) == len(elements) and all(elements == residue_atom_array.element): - return residue_atom_array - - # ...otherwise, we will try to remove specific atoms until the arrays match - if ( - chem_type in AA_LIKE_CHEM_TYPES - or chem_type in POLYPEPTIDE_L_CHEM_TYPES - or chem_type in POLYPEPTIDE_D_CHEM_TYPES - ): - # ...try removing OXT in non-terminal polypeptides - candidate_residue_atom_array = residue_atom_array[residue_atom_array.atom_name != "OXT"] - if len(candidate_residue_atom_array) == len(elements) and all(elements == candidate_residue_atom_array.element): - return candidate_residue_atom_array - - elif chem_type in RNA_LIKE_CHEM_TYPES or chem_type in DNA_LIKE_CHEM_TYPES: - # ...try removing OP3 in RNA or DNA - candidate_residue_atom_array = residue_atom_array[residue_atom_array.atom_name != "OP3"] - if len(candidate_residue_atom_array) == len(elements) and all(elements == candidate_residue_atom_array.element): - return candidate_residue_atom_array - - # ...as a last resort, try and match the elements by sliding a window over the residue - for start in range(len(residue_atom_array) - len(elements) + 1): - current_slice = residue_atom_array[start : start + len(elements)] - if all(elements == current_slice.element): - return current_slice - - raise ValueError( - f"Could not find a matching AtomArray for residue {residue_atom_array.res_name[0]} with elements {elements}" - ) - - def to_parquet_with_metadata(df: pd.DataFrame, filepath: PathLike, **kwargs: Any) -> None: """Convenience wrapper around df.to_parquet that saves table-wide metadata (df.attrs) to the parquet file. @@ -473,3 +295,63 @@ def read_parquet_with_metadata(filepath: PathLike, **kwargs: Any) -> pd.DataFram df.attrs = metadata_dict return df + + +def parse_sharding_pattern(sharding_pattern: str) -> list[tuple[int, int]]: + """Parse a sharding pattern string into directory levels. + + Args: + sharding_pattern: String like "/1:2/0:2/" where each /start:end/ defines a directory level + - start:end defines the character range to use for that directory level + - Example: "/1:2/0:2/" means use chars 1-2 for first dir, then chars 0-2 for second dir + + Returns: + List of (start, end) tuples for each directory level + """ + # Find all patterns like /start:end/ using a non-consuming lookahead + pattern = r"/(\d+):(\d+)(?=/)" + matches = [] + for match in re.finditer(pattern, sharding_pattern): + matches.append((int(match.group(1)), int(match.group(2)))) + + if not matches: + raise ValueError(f"Invalid sharding pattern format: {sharding_pattern}. Expected format like '/1:2/0:2/'") + + return matches + + +def apply_sharding_pattern(path: str, sharding_pattern: str | None = None) -> Path: + """Apply a sharding pattern to construct a file path. + + Args: + path: The base path or identifier (e.g., PDB ID) + sharding_pattern: Pattern for organizing files in subdirectories + - "/1:2/": Use characters 1-2 for first directory level + - "/1:2/0:2/": Use chars 1-2 for first dir, then chars 0-2 for second dir + - None: No sharding (default) + + Returns: + Path: The constructed file path with sharding applied + """ + if sharding_pattern and sharding_pattern.startswith("/"): + # General sharding pattern: /start:end/start:end/... 
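+        # For example (illustrative, assuming the Python slice semantics
+        # implemented below): path "1abc" with pattern "/1:2/" maps to
+        # "a/1abc", and "/1:3/0:2/" maps to "ab/1a/1abc" -- each start:end
+        # pair names one directory level, and the full identifier is
+        # appended as the final path component.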
+        try:
+            shard_levels = parse_sharding_pattern(sharding_pattern)
+        except ValueError as e:
+            raise ValueError(f"Invalid sharding pattern: {e}") from e
+
+        # Build the sharded path
+        current_path = Path()
+
+        for start, end in shard_levels:
+            if end > len(path):
+                raise ValueError(f"Sharding range {start}:{end} exceeds path length {len(path)} for path '{path}'")
+            shard_dir = path[start:end]
+            current_path = current_path / shard_dir
+
+        final_path = current_path / path
+    else:
+        # Default behavior: no sharding
+        final_path = Path(path)
+
+    return final_path
diff --git a/src/atomworks/ml/utils/misc.py b/src/atomworks/ml/utils/misc.py
index bddabcf8..0ce3203e 100644
--- a/src/atomworks/ml/utils/misc.py
+++ b/src/atomworks/ml/utils/misc.py
@@ -218,7 +218,7 @@ def masked_mean(
         tensor([3., 5.])  # float32
 
     Reference:
-        - AF2 Multimer Code (https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/model/utils.py#L79)
+        `AF2 Multimer Code <https://github.com/google-deepmind/alphafold/blob/f251de6613cb478207c732bf9627b1e853c99c2f/alphafold/model/utils.py#L79>`_
     """
 
     # Drop the last channel of the mask if specified
diff --git a/src/atomworks/ml/utils/rng.py b/src/atomworks/ml/utils/rng.py
index 8f4d1d2a..7bd00701 100644
--- a/src/atomworks/ml/utils/rng.py
+++ b/src/atomworks/ml/utils/rng.py
@@ -55,52 +55,62 @@ def rng_state(
     rng_state_dict: dict[str, Any] | None = None, include_cuda: bool = True
 ) -> Generator[dict[str, Any], None, None]:
     """A context manager that resets the global random state on exit to what it was before entering.
+
     Within the context manager, the RNG states are set to the provided rng state in the dictionary.
     It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
 
     Args:
-    - rng_state_dict (dict[str, Any] | None): A dictionary of RNG states to set. It can have the following keys:
+        rng_state_dict: A dictionary of RNG states to set. It can have the following keys:
+
         - "torch": The state of the PyTorch RNG.
         - "torch.cuda": The state of the PyTorch CUDA RNG.
         - "numpy": The state of the Numpy RNG.
         - "python": The state of the Python built-in RNG.
+
         If no rng_state_dict is provided, the RNG states are set to the current state of the RNGs.
         If the rng_state_dict only contains a subset of the RNG states, the other RNG states are set to the
         current state of the RNGs.
-    - include_cuda (bool): Whether to allow this function to also control the `torch.cuda` random number generator.
-        Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
+        include_cuda: Whether to allow this function to also control the torch.cuda random number generator.
+            Set this to False when using the function in a forked process where CUDA re-initialization is
             prohibited. Defaults to True.
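+
+    Yields:
+        The dictionary of RNG states in effect inside the context (one entry per
+        RNG listed above), as captured in the example's rng_state_dict below.
+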
Example: - - ``` - # Outside the context manager - print("NumPy:", np.random.random(3)) # [0.04810046 0.99270597 0.70612995] - print("PyTorch:", torch.rand(3)) # tensor([0.1405, 0.4602, 0.4284]) - print("Python random:", [random.random() for _ in range(3)]) # [0.7406435863188185, 0.5632059276194807, 0.8537007637060476] - - # Inside the context manager with fixed seeds - with rng_state(create_rng_state_from_seeds(np_seed=42, torch_seed=42, py_seed=42)) as rng_state_dict: - my_state = serialize_rng_state_dict(rng_state_dict) - print("\nWithin context manager:") - print("NumPy:", np.random.random(3)) # [0.37454012 0.95071431 0.73199394] - print("PyTorch:", torch.rand(3)) # tensor([0.8823, 0.9150, 0.3829]) - print("Python random:", [random.random() for _ in range(3)]) # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926] - - # Back to the original state outside the context manager - print("\nBack outside the context manager:") - print("NumPy:", np.random.random(3)) # [0.75479377 0.99594641 0.70411424] - print("PyTorch:", torch.rand(3)) # tensor([0.2757, 0.5345, 0.1754]) - print("Python random:", [random.random() for _ in range(3)]) # [0.2194923914916147, 0.8731837332486028, 0.47700011905124995] - - # Inside the context manager with fixed seeds - with rng_state(eval(my_state)): - print("\nWithin context manager:") - print("NumPy:", np.random.random(3)) # [0.37454012 0.95071431 0.73199394] - print("PyTorch:", torch.rand(3)) # tensor([0.8823, 0.9150, 0.3829]) - print("Python random:", [random.random() for _ in range(3)]) # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926] - ``` + .. code-block:: python + + # Outside the context manager + print("NumPy:", np.random.random(3)) # [0.04810046 0.99270597 0.70612995] + print("PyTorch:", torch.rand(3)) # tensor([0.1405, 0.4602, 0.4284]) + print( + "Python random:", [random.random() for _ in range(3)] + ) # [0.7406435863188185, 0.5632059276194807, 0.8537007637060476] + + # Inside the context manager with fixed seeds + with rng_state(create_rng_state_from_seeds(np_seed=42, torch_seed=42, py_seed=42)) as rng_state_dict: + my_state = serialize_rng_state_dict(rng_state_dict) + print("\nWithin context manager:") + print("NumPy:", np.random.random(3)) # [0.37454012 0.95071431 0.73199394] + print("PyTorch:", torch.rand(3)) # tensor([0.8823, 0.9150, 0.3829]) + print( + "Python random:", [random.random() for _ in range(3)] + ) # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926] + + # Back to the original state outside the context manager + print("\nBack outside the context manager:") + print("NumPy:", np.random.random(3)) # [0.75479377 0.99594641 0.70411424] + print("PyTorch:", torch.rand(3)) # tensor([0.2757, 0.5345, 0.1754]) + print( + "Python random:", [random.random() for _ in range(3)] + ) # [0.2194923914916147, 0.8731837332486028, 0.47700011905124995] + + # Inside the context manager with fixed seeds + with rng_state(eval(my_state)): + print("\nWithin context manager:") + print("NumPy:", np.random.random(3)) # [0.37454012 0.95071431 0.73199394] + print("PyTorch:", torch.rand(3)) # tensor([0.8823, 0.9150, 0.3829]) + print( + "Python random:", [random.random() for _ in range(3)] + ) # [0.6394267984578837, 0.025010755222666936, 0.27502931836911926] """ # Collect previous states prev_states = capture_rng_states(include_cuda) diff --git a/src/atomworks/ml/utils/timer.py b/src/atomworks/ml/utils/timer.py index fb3c9d2f..66eaae01 100644 --- a/src/atomworks/ml/utils/timer.py +++ b/src/atomworks/ml/utils/timer.py @@ -38,25 
+38,24 @@ def timeout(timeout: float | int | None = None, strategy: Literal["signal", "sub
 
 
 def do_nothing(*args, **kwargs) -> Callable:
-    """
-    A decorator that does nothing and simply returns the original function.
+    """A decorator that does nothing and simply returns the original function.
 
     This decorator can be used as a placeholder or for testing purposes when you
     want to conditionally apply decorators without changing the code structure.
 
     Returns:
-        Callable: A decorator function that returns the original function unchanged.
+        A decorator function that returns the original function unchanged.
 
     Example:
-        ```python
-        @do_nothing_decorator()
-        def my_function():
-            return "Hello, World!"
+        .. code-block:: python
+
+            @do_nothing()
+            def my_function():
+                return "Hello, World!"
 
-        # or:
-        do_nothing(bla=123, blub=456)(my_function)
-        ```
+            # or:
+            do_nothing(bla=123, blub=456)(my_function)
     """
 
     def decorator(func: Callable) -> Callable:
diff --git a/src/atomworks/ml/utils/token.py b/src/atomworks/ml/utils/token.py
index 606924ff..7ca6ecbf 100644
--- a/src/atomworks/ml/utils/token.py
+++ b/src/atomworks/ml/utils/token.py
@@ -329,7 +329,7 @@ def get_af3_token_center_masks(atom_array: AtomArray) -> np.ndarray:
         np.ndarray: A boolean mask indicating the center atoms of the tokens in the atom array.
 
     Reference:
-        - AF3: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
+        `AF3 Supplementary Information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
     """
 
     assert (
@@ -378,7 +378,7 @@ def get_af3_token_center_coords(atom_array: AtomArray) -> np.ndarray:
         np.ndarray: The center coordinates of the tokens in the atom array.
 
     Reference:
-        - AF3: https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
+        `AF3 Supplementary Information <https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf>`_
 
     Example:
         >>> # Contrived example showing only a few tokens and annotations per residue for illustration
diff --git a/src/atomworks_cli/pdb.py b/src/atomworks_cli/pdb.py
index 5e8489f0..8a5908f9 100644
--- a/src/atomworks_cli/pdb.py
+++ b/src/atomworks_cli/pdb.py
@@ -20,7 +20,17 @@
 def _normalize_pdb_id(pdb_id: str) -> str:
-    """Return a normalized, lower-case 4-char PDB id or raise ValueError."""
+    """Return a normalized, lower-case 4-char PDB id or raise ValueError.
+
+    Args:
+        pdb_id: The PDB ID to normalize.
+
+    Returns:
+        Normalized lowercase PDB ID.
+
+    Raises:
+        ValueError: If the PDB ID is invalid.
+    """
     pdb_id = pdb_id.strip().lower()
     if not PDB_ID_REGEX.match(pdb_id):
         raise ValueError(f"Invalid PDB id: {pdb_id}")
@@ -31,6 +41,12 @@ def _pdb_id_to_relpath(pdb_id: str) -> Path:
     """Map a PDB id to its relative mmCIF path under the divided layout.
 
     Example: '1a0i' -> 'a0/1a0i.cif.gz'
+
+    Args:
+        pdb_id: The PDB ID to map.
+
+    Returns:
+        The relative path to the mmCIF file.
     """
     pid = _normalize_pdb_id(pdb_id)
     subdir = pid[1:3]
@@ -38,7 +54,15 @@
 def _run_rsync_list(remote_path: str, port: int | None) -> tuple[bool, str]:
-    """Try to list a remote rsync path and return success and output/error."""
+    """Try to list a remote rsync path and return success and output/error.
+
+    Args:
+        remote_path: The remote rsync path to list.
+        port: The port to use for rsync connection.
+
+    Returns:
+        Tuple of (success, output) where success is a boolean and output is the stdout/stderr.
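+
+    Example:
+        >>> # Illustrative values only; any rsync daemon path and port may be used
+        >>> ok, out = _run_rsync_list("rsync.rcsb.org::ftp_data/structures/divided/mmCIF/", port=33444)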
+    """
     cmd = ["rsync", "--list-only"]
     if port is not None:
         cmd.extend(["--port", str(port)])
diff --git a/src/atomworks_cli/setup.py b/src/atomworks_cli/setup.py
index cdb616f1..a7550181 100644
--- a/src/atomworks_cli/setup.py
+++ b/src/atomworks_cli/setup.py
@@ -21,7 +21,7 @@
 """The URL for the latest AtomWorks test pack. Should be untared in `tests/data/shared`."""
 
 METADATA_URL = f"{IPD_DOWNLOAD_URL}/pdb_metadata_latest.tar.gz"
-"""The URL for the latest AtomWorks PDB metadata. Should be untared at the specifided location."""
+"""The URL for the latest AtomWorks PDB metadata. Should be untarred at the specified location."""
 
 app = typer.Typer(help="Setup utilities for AtomWorks.")
diff --git a/tests/conftest.py b/tests/conftest.py
index 24bc0832..bfaf97e3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,6 +15,11 @@
 # Conditional skip markers ----------------------------------------------------------
 def _is_on_digs() -> bool:
+    """Check if running on DIGS infrastructure.
+
+    Returns:
+        True if running on DIGS infrastructure, False otherwise.
+    """
     return os.path.exists("/software/containers/versions/rf_diffusion_aa/ipd.txt")
 
 
@@ -22,6 +27,11 @@ def _is_on_digs() -> bool:
 
 
 def _is_on_github_runner() -> bool:
+    """Check if running on a GitHub Actions runner.
+
+    Returns:
+        True if running on a GitHub Actions runner, False otherwise.
+    """
     return os.environ.get("GITHUB_ACTIONS", "false") == "true"
 
 
@@ -32,6 +42,11 @@ def _is_on_github_runner() -> bool:
 
 
 def _has_internet_connection() -> bool:
+    """Check if an internet connection is available.
+
+    Returns:
+        True if an internet connection is available, False otherwise.
+    """
     try:
         # Try to connect to a well-known DNS server (Google's)
         socket.create_connection(("8.8.8.8", 53), timeout=2)
@@ -44,6 +59,11 @@ def _has_internet_connection() -> bool:
 
 
 def _has_gpu() -> bool:
+    """Check if a GPU is available.
+
+    Returns:
+        True if a GPU is available, False otherwise.
+ """ import torch return torch.cuda.is_available() diff --git a/tests/data/ml/af3_model_outs_protein_dna.pkl b/tests/data/ml/af3_model_outs_protein_dna.pkl deleted file mode 100644 index af0f1bbc..00000000 Binary files a/tests/data/ml/af3_model_outs_protein_dna.pkl and /dev/null differ diff --git a/tests/data/ml/af3_model_outs_protein_ligand.pkl b/tests/data/ml/af3_model_outs_protein_ligand.pkl deleted file mode 100644 index 3cd8b019..00000000 Binary files a/tests/data/ml/af3_model_outs_protein_ligand.pkl and /dev/null differ diff --git a/tests/data/ml/pdb_interfaces/metadata.parquet b/tests/data/ml/pdb_interfaces/metadata.parquet index 03f06f79..70358777 100644 Binary files a/tests/data/ml/pdb_interfaces/metadata.parquet and b/tests/data/ml/pdb_interfaces/metadata.parquet differ diff --git a/tests/data/ml/pdb_pn_units/metadata.parquet b/tests/data/ml/pdb_pn_units/metadata.parquet index 4f438f7a..f4dbc5b3 100644 Binary files a/tests/data/ml/pdb_pn_units/metadata.parquet and b/tests/data/ml/pdb_pn_units/metadata.parquet differ diff --git a/tests/io/components/test_caching.py b/tests/io/components/test_caching.py index b0f759e8..ef42386a 100644 --- a/tests/io/components/test_caching.py +++ b/tests/io/components/test_caching.py @@ -7,7 +7,7 @@ from tests.io.conftest import get_pdb_path TEST_CASES = [ - "1A7J", # Contains an unusual operation expression for assembly building + "4NDZ", # 29K atoms, large enough to test caching without too much variance ] @@ -93,8 +93,8 @@ def different_args_parse(): normal_result["assemblies"][assembly_id], cached_result["assemblies"][assembly_id], annotations_to_compare ) - # Assert that the cached result is at least 1.5x faster than the normal result - assert cached_elapsed_time < normal_elapsed_time / 1.5 + # Assert that the cached result is at least 3x faster than the normal result + assert cached_elapsed_time < normal_elapsed_time / 3 # Assert that the result with different arguments is similar to the normal elapsed time assert abs(different_args_elapsed_time - normal_elapsed_time) < normal_elapsed_time * 0.5 diff --git a/tests/io/tools/test_inference_processing.py b/tests/io/tools/test_inference_processing.py index 68eb714b..f80de814 100644 --- a/tests/io/tools/test_inference_processing.py +++ b/tests/io/tools/test_inference_processing.py @@ -133,7 +133,6 @@ def custom_residues(): return { "C:0": { "path": f"{TEST_DATA_IO}/example_ncaa.cif", - "chain_type": "polypeptide(l)", } } diff --git a/tests/io/utils/test_io.py b/tests/io/utils/test_io.py index 78ef52ab..cf34f6aa 100644 --- a/tests/io/utils/test_io.py +++ b/tests/io/utils/test_io.py @@ -511,7 +511,6 @@ def custom_residues(): return { "C:0": { "path": f"{TEST_DATA_IO}/example_ncaa.cif", - "chain_type": "polypeptide(l)", } } diff --git a/tests/io/utils/test_selection_utils.py b/tests/io/utils/test_selection_utils.py index 4f7fa94a..888f2685 100644 --- a/tests/io/utils/test_selection_utils.py +++ b/tests/io/utils/test_selection_utils.py @@ -164,7 +164,7 @@ def test_parse_selection_string(selection_string, pymol_string, expected_selecti from_pymol_string = parse_pymol_string(pymol_string) assert from_selection_string == expected_selection assert from_pymol_string == expected_selection - assert from_selection_string == AtomSelection.from_str(selection_string) + assert from_selection_string == AtomSelection.from_selection_str(selection_string) def test_get_mask_from_selection_string(basic_atom_array: struc.AtomArray): @@ -172,13 +172,13 @@ def test_get_mask_from_selection_string(basic_atom_array: 
struc.AtomArray): mask = get_mask_from_selection_string(basic_atom_array, "A/ALA/1/CA") expected_mask = np.array([False, True, False, False, False, False], dtype=bool) assert np.array_equal(mask, expected_mask) - assert np.array_equal(mask, AtomSelection.from_str("A/ALA/1/CA").get_mask(basic_atom_array)) + assert np.array_equal(mask, AtomSelection.from_selection_str("A/ALA/1/CA").get_mask(basic_atom_array)) # Test partial match mask = get_mask_from_selection_string(basic_atom_array, "A/ALA") expected_mask = np.array([True, True, False, False, False, False], dtype=bool) assert np.array_equal(mask, expected_mask) - assert np.array_equal(mask, AtomSelection.from_str("A/ALA").get_mask(basic_atom_array)) + assert np.array_equal(mask, AtomSelection.from_selection_str("A/ALA").get_mask(basic_atom_array)) # Test no match raises ValueError with pytest.raises(ValueError, match="No atoms found for selection: A/VAL/1/CB"): @@ -192,18 +192,18 @@ def test_get_mask_from_selection_string(basic_atom_array: struc.AtomArray): @pytest.mark.parametrize("contig_test_case", CONTIG_TEST_CASES) -def test_get_mask_from_contig_string(contig_test_case: str): +def test_get_mask_from_contig(contig_test_case: str): contig_string, expected_length = contig_test_case - selection_stack = AtomSelectionStack.from_contig_string(contig_string) + selection_stack = AtomSelectionStack.from_contig(contig_string) assert isinstance(selection_stack, AtomSelectionStack) assert len(selection_stack.selections) == expected_length @pytest.mark.parametrize("contig_test_case", CONTIG_TEST_CASES) -def test_get_mask_from_contig_string_with_atom_array(basic_atom_array: struc.AtomArray, contig_test_case: str): +def test_get_mask_from_contig_with_atom_array(basic_atom_array: struc.AtomArray, contig_test_case: str): contig_string, expected_length = contig_test_case - selection_stack = AtomSelectionStack.from_contig_string(contig_string) + selection_stack = AtomSelectionStack.from_contig(contig_string) residue_starts = get_residue_starts(basic_atom_array) mask = selection_stack.get_mask(basic_atom_array) @@ -214,7 +214,7 @@ def test_get_mask_from_contig_string_with_atom_array(basic_atom_array: struc.Ato def test_atom_selection_stack_get_center_of_mass(basic_atom_array: struc.AtomArray): """Test that get_center_of_mass returns the correct center for selected atoms.""" - selection_stack = AtomSelectionStack.from_contig_string("A1-2, B3-3") + selection_stack = AtomSelectionStack.from_contig("A1-2, B3-3") center_of_mass = selection_stack.get_center_of_mass(basic_atom_array) expected_center = np.mean(basic_atom_array[selection_stack.get_mask(basic_atom_array)].coord, axis=0) assert np.allclose(center_of_mass, expected_center) @@ -229,7 +229,7 @@ def test_atom_selection_stack_get_center_of_mass(basic_atom_array: struc.AtomArr def test_atom_selection_stack_get_principle_components(basic_atom_array: struc.AtomArray): """Test that get_principle_components returns correct principal axes for selected atoms.""" - selection_stack = AtomSelectionStack.from_contig_string("A1-2, B3-3") + selection_stack = AtomSelectionStack.from_contig("A1-2, B3-3") # AtomArray case pcs = selection_stack.get_principal_components(basic_atom_array) coords = basic_atom_array[selection_stack.get_mask(basic_atom_array)].coord @@ -252,5 +252,23 @@ def test_atom_selection_stack_get_principle_components(basic_atom_array: struc.A assert np.allclose(np.abs(pcs_stack[i]), np.abs(expected_pcs)) +def test_atom_selection_stack_from_query_ranges(basic_atom_array: struc.AtomArray) -> None: + 
"""Select a range of residue IDs within a chain using extended syntax.""" + selection_stack = AtomSelectionStack.from_query("A/*/1-2") + mask = selection_stack.get_mask(basic_atom_array) + # Expect residues 1 and 2 in chain A (first four atoms) + expected_mask = np.array([True, True, True, True, False, False], dtype=bool) + assert np.array_equal(mask, expected_mask) + + +def test_atom_selection_stack_from_query_multiple_tokens(basic_atom_array: struc.AtomArray) -> None: + """Union of multiple selection tokens.""" + selection_stack = AtomSelectionStack.from_query(["A/ALA", "B/VAL"]) # include ALA in chain A and VAL in chain B + mask = selection_stack.get_mask(basic_atom_array) + # Expect ALA in chain A (first two atoms) and VAL in chain B (last two atoms) + expected_mask = np.array([True, True, False, False, True, True], dtype=bool) + assert np.array_equal(mask, expected_mask) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/ml/conftest.py b/tests/ml/conftest.py index 7005cbe9..48e55947 100644 --- a/tests/ml/conftest.py +++ b/tests/ml/conftest.py @@ -7,16 +7,15 @@ import pytest from dotenv import load_dotenv -from atomworks.constants import AF3_EXCLUDED_LIGANDS_REGEX, _load_env_var +from atomworks.constants import AF3_EXCLUDED_LIGANDS_REGEX, PDB_MIRROR_PATH, _load_env_var from atomworks.io.tools.inference import SequenceComponent -from atomworks.ml.datasets.datasets import ConcatDatasetWithID, PandasDataset, StructuralDatasetWrapper -from atomworks.ml.datasets.parsers import ( - GenericDFParser, - InterfacesDFParser, - PNUnitsDFParser, - ValidationDFParserLikeAF3, +from atomworks.ml.datasets.datasets import ConcatDatasetWithID, PandasDataset +from atomworks.ml.datasets.loaders import ( + create_base_loader, + create_loader_with_interfaces_and_pn_units_to_score, + create_loader_with_query_pn_units, ) -from atomworks.ml.datasets.parsers.base import DEFAULT_CIF_PARSER_ARGS +from atomworks.ml.datasets.parsers.base import DEFAULT_PARSER_ARGS from atomworks.ml.pipelines.af3 import build_af3_transform_pipeline from atomworks.ml.pipelines.rf2aa import build_rf2aa_transform_pipeline from atomworks.ml.preprocessing.constants import TRAINING_SUPPORTED_CHAIN_TYPES_INTS @@ -37,14 +36,20 @@ def pytest_configure(config): dotenv_path = os.path.join(current_dir, "../..", ".env") # Load the environment variables - load_dotenv(dotenv_path) - - -if not os.environ.get("PDB_MIRROR_PATH") or not os.path.exists(os.environ.get("PDB_MIRROR_PATH")): - raise pytest.UsageError( - "ERROR: Required PDB_MIRROR_PATH environment variable not set. " - "Please set this in the .env file or in your shell environment." - ) + load_dotenv(dotenv_path, override=True) + + # We require a PDB mirror (of at least a subset of the PDB) for the AtomWorks.ml tests + pdb_mirror_path = os.environ.get("PDB_MIRROR_PATH") + if not pdb_mirror_path: + raise pytest.UsageError( + "ERROR: Required PDB_MIRROR_PATH environment variable not set. " + "Please set this in the .env file or in your shell environment." + ) + if not os.path.exists(pdb_mirror_path): + raise pytest.UsageError( + f"ERROR: PDB_MIRROR_PATH is set to '{pdb_mirror_path}', but this path does not exist. " + "Please check your .env file or shell environment." 
+ ) ########################################################################################## @@ -115,13 +120,13 @@ def interfaces_df(): # AF2 Distillation Facebook, with and without table-wide metadata (to test metadata handling) @pytest.fixture(scope="session") -def af2_distillation_facebook_df_no_metadata(): +def af2_distillation_df_no_metadata(): path = TEST_DATA_ML / "af2_distillation" / "metadata.parquet" return pd.read_parquet(path) @pytest.fixture(scope="session") -def af2_distillation_facebook_df_with_metadata(): +def af2_distillation_df_with_metadata(): df = read_parquet_with_metadata(TEST_DATA_ML / "af2_distillation" / "metadata.parquet") df.attrs["base_path"] = str(TEST_DATA_ML / "af2_distillation" / "cif") return df @@ -134,7 +139,7 @@ def af3_validation_df(): ########################################################################################## -# + ------------------------------------ Datasets -------------------------------------- + +# + ------------------------------------ Filters -------------------------------------- + ########################################################################################## SHARED_TEST_FILTERS = [ @@ -159,72 +164,19 @@ def af3_validation_df(): TEST_DIFFUSION_BATCH_SIZE = 32 # Set to a value other than default (48) for testing -# +--------------------------------------------------------------------------+ -# Base PandasDataset fixtures -# +--------------------------------------------------------------------------+ - - -@pytest.fixture(scope="session") -def pn_units_pandas_dataset(pn_units_df): - return PandasDataset( - name="pn_units", - id_column="example_id", - data=pn_units_df, - filters=SHARED_TEST_FILTERS + TEST_PN_UNITS_FILTERS, - columns_to_load=None, # Load all columns - ) - - -@pytest.fixture(scope="session") -def interfaces_pandas_dataset(interfaces_df): - return PandasDataset( - name="interfaces", - id_column="example_id", - data=interfaces_df, - filters=SHARED_TEST_FILTERS + TEST_INTERFACES_FILTERS, - columns_to_load=None, # Load all columns - ) - -@pytest.fixture(scope="session") -def validation_pandas_dataset(af3_validation_df): - return PandasDataset( - name="validation", - data=af3_validation_df, - id_column="example_id", - columns_to_load=None, # Load all columns - ) - - -@pytest.fixture(scope="session") -def distillation_pandas_dataset_no_metadata(af2_distillation_facebook_df_no_metadata): - return PandasDataset( - data=af2_distillation_facebook_df_no_metadata, - id_column="example_id", - name="af2fb_distillation", - columns_to_load=["example_id", "sequence_hash", "path"], - ) +########################################################################################## +# + ------------------------------------ Datasets -------------------------------------- + +########################################################################################## @pytest.fixture(scope="session") -def distillation_pandas_dataset_with_metadata(af2_distillation_facebook_df_with_metadata): +def rf2aa_pn_units_dataset(pn_units_df): return PandasDataset( - data=af2_distillation_facebook_df_with_metadata, + data=pn_units_df, + name="rf2aa_pn_units", id_column="example_id", - name="af2fb_distillation", - columns_to_load=["example_id", "sequence_hash", "path"], - ) - - -# +--------------------------------------------------------------------------+ -# RF2AA Dataset fixtures -# +--------------------------------------------------------------------------+ - - -@pytest.fixture(scope="session") -def 
rf2aa_pn_units_dataset(pn_units_pandas_dataset): - return StructuralDatasetWrapper( - dataset_parser=PNUnitsDFParser(), + loader=create_loader_with_query_pn_units(pn_unit_iid_colnames=["q_pn_unit_iid"], base_path=PDB_MIRROR_PATH), transform=build_rf2aa_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=RNA_MSA_DIRS, @@ -237,16 +189,20 @@ def rf2aa_pn_units_dataset(pn_units_pandas_dataset): template_lookup_path=TEMPLATE_LOOKUP, template_base_dir=TEMPLATE_DIR, ), - dataset=pn_units_pandas_dataset, - cif_parser_args={"cache_dir": None}, save_failed_examples_to_dir=None, + filters=SHARED_TEST_FILTERS + TEST_PN_UNITS_FILTERS, ) @pytest.fixture(scope="session") -def rf2aa_interfaces_dataset(interfaces_pandas_dataset): - return StructuralDatasetWrapper( - dataset_parser=InterfacesDFParser(), +def rf2aa_interfaces_dataset(interfaces_df): + return PandasDataset( + data=interfaces_df, + name="rf2aa_interfaces", + id_column="example_id", + loader=create_loader_with_query_pn_units( + pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"], base_path=PDB_MIRROR_PATH + ), transform=build_rf2aa_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=RNA_MSA_DIRS, @@ -259,9 +215,8 @@ def rf2aa_interfaces_dataset(interfaces_pandas_dataset): template_lookup_path=TEMPLATE_LOOKUP, template_base_dir=TEMPLATE_DIR, ), - dataset=interfaces_pandas_dataset, - cif_parser_args={"cache_dir": None}, save_failed_examples_to_dir=None, + filters=SHARED_TEST_FILTERS + TEST_INTERFACES_FILTERS, ) @@ -271,10 +226,17 @@ def rf2aa_pdb_dataset(rf2aa_pn_units_dataset, rf2aa_interfaces_dataset): @pytest.fixture(scope="session") -def rf2aa_validation_dataset(validation_pandas_dataset): - """Create a StructuralDatasetWrapper for RF2AA validation.""" - return StructuralDatasetWrapper( - dataset_parser=ValidationDFParserLikeAF3(), +def rf2aa_validation_dataset(af3_validation_df): + """Create a PandasDataset for RF2AA validation.""" + return PandasDataset( + data=af3_validation_df, + name="rf2aa_validation", + loader=create_loader_with_interfaces_and_pn_units_to_score( + path_colname="pdb_id", + base_path=str(PDB_MIRROR_PATH), + extension=".cif.gz", + sharding_pattern="/1:3/", + ), transform=build_rf2aa_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=RNA_MSA_DIRS, @@ -287,7 +249,6 @@ def rf2aa_validation_dataset(validation_pandas_dataset): template_lookup_path=TEMPLATE_LOOKUP, template_base_dir=TEMPLATE_DIR, ), - dataset=validation_pandas_dataset, save_failed_examples_to_dir=None, ) @@ -298,9 +259,11 @@ def rf2aa_validation_dataset(validation_pandas_dataset): @pytest.fixture(scope="session") -def af3_pn_units_dataset(pn_units_pandas_dataset): - return StructuralDatasetWrapper( - dataset_parser=PNUnitsDFParser(), +def af3_pn_units_dataset(pn_units_df): + return PandasDataset( + data=pn_units_df, + name="af3_pn_units", + loader=create_loader_with_query_pn_units(pn_unit_iid_colnames=["q_pn_unit_iid"], base_path=PDB_MIRROR_PATH), transform=build_af3_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=RNA_MSA_DIRS, @@ -313,15 +276,19 @@ def af3_pn_units_dataset(pn_units_pandas_dataset): template_lookup_path=TEMPLATE_LOOKUP, template_base_dir=TEMPLATE_DIR, ), - dataset=pn_units_pandas_dataset, save_failed_examples_to_dir=None, + filters=SHARED_TEST_FILTERS + TEST_PN_UNITS_FILTERS, ) @pytest.fixture(scope="session") -def af3_interfaces_dataset(interfaces_pandas_dataset): - return StructuralDatasetWrapper( - dataset_parser=InterfacesDFParser(), +def 
af3_interfaces_dataset(interfaces_df): + return PandasDataset( + data=interfaces_df, + name="af3_interfaces", + loader=create_loader_with_query_pn_units( + pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"], base_path=PDB_MIRROR_PATH + ), transform=build_af3_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=RNA_MSA_DIRS, @@ -334,9 +301,8 @@ def af3_interfaces_dataset(interfaces_pandas_dataset): template_lookup_path=TEMPLATE_LOOKUP, template_base_dir=TEMPLATE_DIR, ), - dataset=interfaces_pandas_dataset, - cif_parser_args={"cache_dir": None}, save_failed_examples_to_dir=None, + filters=SHARED_TEST_FILTERS + TEST_INTERFACES_FILTERS, ) @@ -346,9 +312,16 @@ def af3_pdb_dataset(af3_pn_units_dataset, af3_interfaces_dataset): @pytest.fixture(scope="session") -def af3_validation_dataset(validation_pandas_dataset): - return StructuralDatasetWrapper( - dataset_parser=ValidationDFParserLikeAF3(), +def af3_validation_dataset(af3_validation_df): + return PandasDataset( + data=af3_validation_df, + name="af3_validation", + loader=create_loader_with_interfaces_and_pn_units_to_score( + path_colname="pdb_id", + base_path=PDB_MIRROR_PATH, + extension=".cif.gz", + sharding_pattern="/1:3/", + ), transform=build_af3_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=RNA_MSA_DIRS, @@ -360,20 +333,19 @@ def af3_validation_dataset(validation_pandas_dataset): template_lookup_path=TEMPLATE_LOOKUP, template_base_dir=TEMPLATE_DIR, ), - dataset=validation_pandas_dataset, save_failed_examples_to_dir=None, ) @pytest.fixture(scope="session") -def af3_af2fb_distillation_dataset_no_metadata(distillation_pandas_dataset_no_metadata): - return StructuralDatasetWrapper( - dataset=distillation_pandas_dataset_no_metadata, - dataset_parser=GenericDFParser( +def af2_distillation_dataset_no_metadata(af2_distillation_df_no_metadata): + return PandasDataset( + data=af2_distillation_df_no_metadata, + name="af3_af2fb_distillation_no_metadata", + loader=create_base_loader( base_path=str(TEST_DATA_ML / "af2_distillation" / "cif"), extension=".cif", ), - cif_parser_args={}, transform=build_af3_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=[], @@ -387,11 +359,11 @@ def af3_af2fb_distillation_dataset_no_metadata(distillation_pandas_dataset_no_me @pytest.fixture(scope="session") -def af3_af2fb_distillation_dataset_with_metadata(distillation_pandas_dataset_with_metadata): - return StructuralDatasetWrapper( - dataset=distillation_pandas_dataset_with_metadata, - dataset_parser=GenericDFParser(), - cif_parser_args={}, +def af2_distillation_dataset_with_metadata(af2_distillation_df_with_metadata): + return PandasDataset( + data=af2_distillation_df_with_metadata, + name="af3_af2fb_distillation_with_metadata", + loader=create_base_loader(), transform=build_af3_transform_pipeline( protein_msa_dirs=PROTEIN_MSA_DIRS, rna_msa_dirs=[], @@ -405,8 +377,8 @@ def af3_af2fb_distillation_dataset_with_metadata(distillation_pandas_dataset_wit @pytest.fixture(scope="session") -def af3_af2fb_distillation_concat_dataset(af3_af2fb_distillation_dataset_no_metadata): - return ConcatDatasetWithID(datasets=[af3_af2fb_distillation_dataset_no_metadata]) +def af3_af2fb_distillation_concat_dataset(af2_distillation_dataset_no_metadata): + return ConcatDatasetWithID(datasets=[af2_distillation_dataset_no_metadata]) ########################################################################################## @@ -419,16 +391,16 @@ def atom_array(): """ Load a CIF file from somewhere local and return the atom_array """ - 
merged_cif_parser_args = { - **DEFAULT_CIF_PARSER_ARGS, + parser_args = { + **DEFAULT_PARSER_ARGS, **{ "fix_arginines": False, "add_missing_atoms": False, # this is crucial otherwise the annotations are deleted }, } - merged_cif_parser_args.pop("add_bond_types_from_struct_conn") - merged_cif_parser_args.pop("remove_ccds") - data = cached_parse("6lyz", **merged_cif_parser_args) + parser_args.pop("add_bond_types_from_struct_conn") + parser_args.pop("remove_ccds") + data = cached_parse("6lyz", **parser_args) atom_array = data["atom_array"] return atom_array diff --git a/tests/ml/datasets/test_datasets.py b/tests/ml/datasets/test_datasets.py index 295756b2..ef3693d0 100644 --- a/tests/ml/datasets/test_datasets.py +++ b/tests/ml/datasets/test_datasets.py @@ -4,7 +4,11 @@ import torch from torch.utils.data import SequentialSampler, WeightedRandomSampler -from atomworks.ml.datasets.datasets import ConcatDatasetWithID, PandasDataset, get_row_and_index_by_example_id +from atomworks.ml.datasets.datasets import ( + ConcatDatasetWithID, + PandasDataset, + get_row_and_index_by_example_id, +) from atomworks.ml.samplers import ( MixedSampler, calculate_weights_for_pdb_dataset_df, @@ -20,7 +24,7 @@ def create_dummy_dataset(length: int, name: str, dataset_class: PandasDataset = } ) data.attrs = {"base_path": "/example/base/path"} - return dataset_class(data=data, id_column="example_id", name=name) + return dataset_class(data=data, name=name, id_column="example_id") def test_nested_dummy_datasets(): @@ -52,11 +56,11 @@ def test_nested_dummy_datasets(): assert row.attrs["base_path"] is not None -def test_structural_datasets(rf2aa_interfaces_dataset, rf2aa_pn_units_dataset, rf2aa_pdb_dataset): - # +------------------ Structural Dataset (PandasDataset wrapped with a StructuralDatasetWrapper) ------------------+ +def test_nested_datasets_with_weighted_samplers(rf2aa_interfaces_dataset, rf2aa_pn_units_dataset, rf2aa_pdb_dataset): + # +------------------ Sampler ------------------+ num_examples_per_epoch = 100 - # ...calculate the weights based on the AF-3 weighting methodology + # ... calculate the weights based on the AF-3 weighting methodology b_pn_unit = 0.5 # β_chain b_interface = 0.5 # β_interface alphas = { @@ -73,7 +77,7 @@ def test_structural_datasets(rf2aa_interfaces_dataset, rf2aa_pn_units_dataset, r ) pdb_dataset_weights = torch.cat([pn_units_dataset_weights, interfaces_dataset_weights]) # NOTE: Order matters! - # ...and initialize one sampler for all PDB datasets, using the unified weights + # ... and initialize one sampler for all PDB datasets, using the unified weights pdb_sampler = WeightedRandomSampler( weights=pdb_dataset_weights, num_samples=num_examples_per_epoch, # We later override with proportional number of examples @@ -107,7 +111,7 @@ def test_structural_datasets(rf2aa_interfaces_dataset, rf2aa_pn_units_dataset, r n_examples_per_epoch=100, ) - # ...create a dataset including both datasets + # ... 
create a dataset including both datasets concat_dataset = ConcatDatasetWithID(datasets=datasets) # +---------------------------- Tests and assertions ----------------------------+ @@ -133,7 +137,7 @@ def test_structural_datasets(rf2aa_interfaces_dataset, rf2aa_pn_units_dataset, r assert len(indices) == 100 # Check that 80% of the indices are from the (second copy of) the pn_units dataset - # ...all idxs >= len(pdb_dataset) should be from pn_units dataset + # (all idxs >= len(pdb_dataset) should be from pn_units dataset) pn_unit_indices = [idx for idx in indices if idx >= len(rf2aa_pdb_dataset)] assert len(pn_unit_indices) == 80 diff --git a/tests/ml/datasets/test_datasets_with_filters.py b/tests/ml/datasets/test_datasets_with_filters.py index eef30a2e..67079d18 100644 --- a/tests/ml/datasets/test_datasets_with_filters.py +++ b/tests/ml/datasets/test_datasets_with_filters.py @@ -64,6 +64,7 @@ def test_filter_no_impact(caplog, pn_units_df): PandasDataset( data=pn_units_df.copy(), filters=filters, + name="no_impact_test_dataset", ) assert "did not remove any rows" in caplog.text, "Warning for no impact filter not raised" @@ -75,6 +76,7 @@ def test_filter_remove_all_rows(pn_units_df): PandasDataset( data=pn_units_df.copy(), filters=filters, + name="remove_all_rows_test_dataset", ) diff --git a/tests/ml/pipelines/test_data_loading_pipelines.py b/tests/ml/pipelines/test_data_loading_pipelines.py index 148199d6..46ea2656 100644 --- a/tests/ml/pipelines/test_data_loading_pipelines.py +++ b/tests/ml/pipelines/test_data_loading_pipelines.py @@ -14,28 +14,33 @@ @pytest.fixture def datasets_to_test( - af3_af2fb_distillation_dataset_with_metadata, - af3_af2fb_distillation_dataset_no_metadata, + af3_pdb_dataset, af3_validation_dataset, + af2_distillation_dataset_with_metadata, + af2_distillation_dataset_no_metadata, rf2aa_validation_dataset, rf2aa_pdb_dataset, - af3_pdb_dataset, ): """Create the list of datasets to test with actual dataset objects.""" return [ { - "dataset": af3_af2fb_distillation_dataset_with_metadata, + "dataset": af3_pdb_dataset, "type": "train", + "num_examples": 5, + }, + { + "dataset": af3_validation_dataset, + "type": "validation", "num_examples": 1, }, { - "dataset": af3_af2fb_distillation_dataset_no_metadata, + "dataset": af2_distillation_dataset_with_metadata, "type": "train", "num_examples": 1, }, { - "dataset": af3_validation_dataset, - "type": "validation", + "dataset": af2_distillation_dataset_no_metadata, + "type": "train", "num_examples": 1, }, { @@ -48,11 +53,6 @@ def datasets_to_test( "type": "train", "num_examples": 1, }, - { - "dataset": af3_pdb_dataset, - "type": "train", - "num_examples": 5, - }, ] diff --git a/tests/ml/pipelines/test_pipeline_regression.py b/tests/ml/pipelines/test_pipeline_regression.py index b2242091..247429f4 100644 --- a/tests/ml/pipelines/test_pipeline_regression.py +++ b/tests/ml/pipelines/test_pipeline_regression.py @@ -14,7 +14,7 @@ from atomworks.enums import ChainType from atomworks.io import parse from atomworks.io.utils.testing import assert_same_atom_array -from atomworks.ml.datasets.parsers.base import DEFAULT_CIF_PARSER_ARGS +from atomworks.ml.datasets.parsers.base import DEFAULT_PARSER_ARGS from atomworks.ml.pipelines.af3 import build_af3_transform_pipeline from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state @@ -76,7 +76,7 @@ def instantiate_example(example_name: str): result_dict = parse( filename=file, build_assembly=("1",), - **DEFAULT_CIF_PARSER_ARGS, + **DEFAULT_PARSER_ARGS, ) for chain_id in 
result_dict["chain_info"]: result_dict["chain_info"][chain_id]["msa_path"] = test_data_dir / example_name / f"{example_name}.a3m" diff --git a/tests/ml/transforms/msa/test_load_msas.py b/tests/ml/transforms/msa/test_load_msas.py index 76673543..7ebbf04b 100644 --- a/tests/ml/transforms/msa/test_load_msas.py +++ b/tests/ml/transforms/msa/test_load_msas.py @@ -4,7 +4,6 @@ from typing import Any import numpy as np -import pandas as pd import pytest from atomworks.enums import ChainType @@ -15,7 +14,6 @@ from atomworks.ml.transforms.msa._msa_loading_utils import get_msa_path from atomworks.ml.transforms.msa.msa import LoadPolymerMSAs from atomworks.ml.utils.testing import cached_parse -from tests.conftest import skip_if_not_on_digs from tests.ml.conftest import PROTEIN_MSA_DIRS, RNA_MSA_DIRS logging.basicConfig(level=logging.INFO) @@ -202,47 +200,6 @@ def test_msas_with_mse(): ), "All proteins should have MSAs after MSE conversion" -@pytest.mark.slow -@pytest.mark.requires_digs -@skip_if_not_on_digs -def test_msa_coverage(pn_units_df): - """Ensure the MSA coverage for the test data set surpasses a certain threshold.""" - - protein_coverage_threshold = 0.95 - rna_coverage_threshold = 0.40 - - result = _evaluate_coverage_for_df(pn_units_df, PROTEIN_MSA_DIRS, RNA_MSA_DIRS) - - assert ( - result["protein_coverage"] >= protein_coverage_threshold - ), f"Protein MSA coverage of {result['protein_coverage']} is below the threshold of {protein_coverage_threshold}" - assert ( - result["rna_coverage"] >= rna_coverage_threshold - ), f"RNA MSA coverage of {result['rna_coverage']} is below the threshold of {rna_coverage_threshold}" - - -def _evaluate_coverage_for_df(df: pd.DataFrame, protein_msa_dirs: list[str], rna_msa_dirs: list[str]): - """Utility function to evaluate the MSA coverage for a DataFrame path.""" - num_proteins = num_proteins_with_msas = num_rna = num_rna_with_msa = 0 - - for row in df.itertuples(): - chain_type = ChainType(row.q_pn_unit_type) - if chain_type.is_protein(): - num_proteins += 1 - if get_msa_path(row.q_pn_unit_processed_entity_non_canonical_sequence, protein_msa_dirs) is not None: - num_proteins_with_msas += 1 - elif chain_type == ChainType.RNA: - num_rna += 1 - # HACK: Replace U with T to match the RNA MSA file names (legacy issue) - sequence = row.q_pn_unit_processed_entity_non_canonical_sequence.replace("U", "T") - if get_msa_path(sequence, rna_msa_dirs) is not None: - num_rna_with_msa += 1 - return { - "protein_coverage": num_proteins_with_msas / num_proteins, - "rna_coverage": num_rna_with_msa / num_rna, - } - - @pytest.mark.parametrize("test_case", MSA_TEST_CASES) def test_inference_msa_transform(test_case): """Test the LoadPolymerMSAsInference transformation pipeline, where we provide MSAs through the `chain_info` field""" diff --git a/tests/ml/utils/test_io.py b/tests/ml/utils/test_io.py deleted file mode 100644 index c7a31add..00000000 --- a/tests/ml/utils/test_io.py +++ /dev/null @@ -1,48 +0,0 @@ -import pickle - -import numpy as np -import pytest -from biotite.structure import AtomArrayStack - -from atomworks.constants import ATOMIC_NUMBER_TO_ELEMENT -from atomworks.ml.utils.io import convert_af3_model_output_to_atom_array_stack -from tests.ml.conftest import TEST_DATA_ML - -# NOTE: Not the "true" model outputs; slightly pre-processed for storage efficiency -TEST_PICKLED_AF3_MODEL_OUTPUTS = ["af3_model_outs_protein_dna.pkl", "af3_model_outs_protein_ligand.pkl"] - - -@pytest.mark.parametrize("file_path", TEST_PICKLED_AF3_MODEL_OUTPUTS) -def 
test_convert_af3_model_output_to_atom_array_stack(file_path: str): - full_path = TEST_DATA_ML / file_path - - # Load the model outputs - with open(full_path, "rb") as f: - model_outputs = pickle.load(f) - - # Convert the model outputs to an AtomArrayStack - atom_array_stack = convert_af3_model_output_to_atom_array_stack( - atom_to_token_map=model_outputs["atom_to_token_map"], - pn_unit_iids=model_outputs["chain_iids"], - decoded_restypes=model_outputs["decoded_restypes"], - xyz=model_outputs["xyz"], - elements=model_outputs["elements"], - token_is_atomized=model_outputs["token_is_atomized"], - ) - - # Smoke tests - assert isinstance(atom_array_stack, AtomArrayStack) - assert len(atom_array_stack[0]) == len(model_outputs["xyz"]) - - # Assert that the AtomArray has the correct elements - uppercase_elements = np.array( - [ATOMIC_NUMBER_TO_ELEMENT[atomic_number] for atomic_number in model_outputs["elements"]] - ) - assert np.array_equal(atom_array_stack.element, uppercase_elements) - - # Assert that the AtomArray has the correct coordinates for the first (and only) model - assert np.array_equal(atom_array_stack.coord[0], model_outputs["xyz"]) - - -if __name__ == "__main__": - pytest.main(["-s", "-v", "-m not very_slow", __file__])