Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
480 changes: 480 additions & 0 deletions docs/configuration.md

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# XtalPaint – A framework for crystal structure inpainting based on diffusion models


Welcome to the `XtalPaint` Documentation.

## Overview
Expand All @@ -16,10 +15,12 @@ Welcome to the `XtalPaint` Documentation.

## Getting Started

Check out the examples on to run the inpainting pipeline:
Read the [Configuration Guide](configuration.md) to learn how to specify workflows and understand the AiiDA vs. plain-Python execution modes.

Afterwards, check out the examples:

* [With AiiDA integration](examples/running-with-AiiDA.ipynb)
* [Without AiiDA integration](examples/running-wo-AiiDA.ipynb)
- [With AiiDA integration](examples/running-with-AiiDA.ipynb)
- [Without AiiDA integration](examples/running-wo-AiiDA.ipynb)

## Installation

Expand All @@ -40,7 +41,6 @@ uv pip install .[aiida]

Model checkpoints for the retrained versions of MatterGen used in our work can be downloaded from [Hugging Face](https://huggingface.co/t-reents/XtalPaint). Currently, the repository contains the `pos-only` and `TD-pos-only` models discussed in the paper.


## Acknowledgements

This project is developed to perform crystal structure inpainting, currently on top of Microsoft's [MatterGen](https://github.com/microsoft/mattergen). Some parts of the codebase are adapted from MatterGen's implementation (as highlighted in the respective files).
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ markdown_extensions:

nav:
- Home: index.md
- Configuration Guide: configuration.md
- Examples:
- Running with AiiDA: examples/running-with-AiiDA.ipynb
- Running without AiiDA: examples/running-wo-AiiDA.ipynb
Expand Down
70 changes: 25 additions & 45 deletions src/xtalpaint/aiida/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from aiida_workgraph import spec, task
from aiida_workgraph.socket_spec import meta
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from xtalpaint.aiida.data import (
BatchedStructures,
Expand All @@ -24,19 +23,40 @@
run_mpi_parallel_inpainting_pipeline,
)
from xtalpaint.utils.relaxation_utils import relax_structures
from xtalpaint.utils.structure_utils import (
filter_unique_structures,
refine_structures,
)

evaluate_inpainting_task = task.pythonjob(
outputs=spec.namespace(
metric_results=t.Any,
),
)(evaluate_inpainting)


refine_structures_task = task.pythonjob(
outputs=spec.namespace(structures=t.Any),
)(refine_structures)


uniqueness_filter_task = task.pythonjob(
outputs=spec.namespace(unique_structures=t.Any),
)(filter_unique_structures)


@task.pythonjob(
outputs=spec.namespace(candidates=t.Any),
)
def _generate_inpainting_candidates_task(
def generate_inpainting_candidates_task(
structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures,
n_inp: t.Union[
int, t.Tuple[int, int], t.List[int], t.List[t.Tuple[int, int]]
],
element: t.Union[str, t.List[str]],
num_samples: int = 1,
) -> BatchedStructures:
"""Task wrapper for the inpainting candidates generation."""
if isinstance(structures, BatchedStructures):
structures = structures.get_structures("pymatgen")
candidates = generate_inpainting_candidates(
Expand All @@ -49,40 +69,6 @@ def _generate_inpainting_candidates_task(
return {"candidates": BatchedStructures(candidates)}


@task.pythonjob(
outputs=spec.namespace(structures=t.Any),
)
def _refine_structures_task(
structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures,
refinement_symprec: float,
primitive: bool = False,
) -> BatchedStructures:
"""Refine structures to standard conventional cells."""
if isinstance(structures, BatchedStructures):
structures = structures.get_structures("pymatgen")

refined_structures = {}
for k, s in structures.items():
analyzer = SpacegroupAnalyzer(s, symprec=refinement_symprec)
try:
refined_structure = analyzer.get_refined_structure()
except Exception:
refined_structure = s

if primitive:
analyzer = SpacegroupAnalyzer(
refined_structure, symprec=refinement_symprec
)
try:
refined_structure = analyzer.get_primitive_structure()
except Exception:
refined_structure = refined_structure

refined_structures[k] = refined_structure

return {"structures": BatchedStructures(refined_structures)}


@task.pythonjob(
outputs=spec.namespace(
structures=t.Any,
Expand All @@ -95,24 +81,18 @@ def _refine_structures_task(
],
)
)
def _inpainting_pipeline_task(
def inpainting_pipeline_task(
structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures,
config: dict,
usempi: bool = False,
):
"""Task wrapper for the inpainting pipeline."""
if usempi:
return run_mpi_parallel_inpainting_pipeline(structures, config)

return run_inpainting_pipeline(structures, config)


_evaluate_inpainting_task = task.pythonjob(
outputs=spec.namespace(
metric_results=t.Any,
),
)(evaluate_inpainting)


@task.pythonjob(
outputs=spec.namespace(
structures=t.Any,
Expand All @@ -122,7 +102,7 @@ def _inpainting_pipeline_task(
final_forces=spec.socket(t.Any, required=False),
)
)
def _relaxation_task(
def relaxation_task(
structures: t.Union[
dict[str, Structure], BatchedStructuresData, BatchedStructures
],
Expand Down
4 changes: 4 additions & 0 deletions src/xtalpaint/aiida/workgraphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Modules defining workgraphs for inpainting tasks."""

from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph

__all__ = ("relaxation_graph",)
Loading
Loading