From 4b71c371c01019d33f60ff7e6400fb658549b168 Mon Sep 17 00:00:00 2001 From: t-reents Date: Tue, 12 May 2026 10:18:19 +0200 Subject: [PATCH 1/7] Add uniqueness task and restructure the refinement task Moreover, some small restructuring related to the utils. --- src/xtalpaint/aiida/tasks/tasks.py | 70 +++++-------- src/xtalpaint/eval.py | 40 +------- src/xtalpaint/utils/data_utils.py | 36 +++++++ src/xtalpaint/utils/structure_utils.py | 134 +++++++++++++++++++++++++ tests/test_eval.py | 54 ---------- 5 files changed, 198 insertions(+), 136 deletions(-) create mode 100644 src/xtalpaint/utils/structure_utils.py diff --git a/src/xtalpaint/aiida/tasks/tasks.py b/src/xtalpaint/aiida/tasks/tasks.py index 7811da7..77fb555 100644 --- a/src/xtalpaint/aiida/tasks/tasks.py +++ b/src/xtalpaint/aiida/tasks/tasks.py @@ -6,7 +6,6 @@ from aiida_workgraph import spec, task from aiida_workgraph.socket_spec import meta from pymatgen.core.structure import Structure -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from xtalpaint.aiida.data import ( BatchedStructures, @@ -24,12 +23,32 @@ run_mpi_parallel_inpainting_pipeline, ) from xtalpaint.utils.relaxation_utils import relax_structures +from xtalpaint.utils.structure_utils import ( + filter_unique_structures, + refine_structures, +) + +evaluate_inpainting_task = task.pythonjob( + outputs=spec.namespace( + metric_results=t.Any, + ), +)(evaluate_inpainting) + + +refine_structures_task = task.pythonjob( + outputs=spec.namespace(structures=t.Any), +)(refine_structures) + + +uniqueness_filter_task = task.pythonjob( + outputs=spec.namespace(unique_structures=t.Any), +)(filter_unique_structures) @task.pythonjob( outputs=spec.namespace(candidates=t.Any), ) -def _generate_inpainting_candidates_task( +def generate_inpainting_candidates_task( structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures, n_inp: t.Union[ int, t.Tuple[int, int], t.List[int], t.List[t.Tuple[int, int]] @@ -37,6 +56,7 @@ def 
_generate_inpainting_candidates_task( element: t.Union[str, t.List[str]], num_samples: int = 1, ) -> BatchedStructures: + """Task wrapper for the inpainting candidates generation.""" if isinstance(structures, BatchedStructures): structures = structures.get_structures("pymatgen") candidates = generate_inpainting_candidates( @@ -49,40 +69,6 @@ def _generate_inpainting_candidates_task( return {"candidates": BatchedStructures(candidates)} -@task.pythonjob( - outputs=spec.namespace(structures=t.Any), -) -def _refine_structures_task( - structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures, - refinement_symprec: float, - primitive: bool = False, -) -> BatchedStructures: - """Refine structures to standard conventional cells.""" - if isinstance(structures, BatchedStructures): - structures = structures.get_structures("pymatgen") - - refined_structures = {} - for k, s in structures.items(): - analyzer = SpacegroupAnalyzer(s, symprec=refinement_symprec) - try: - refined_structure = analyzer.get_refined_structure() - except Exception: - refined_structure = s - - if primitive: - analyzer = SpacegroupAnalyzer( - refined_structure, symprec=refinement_symprec - ) - try: - refined_structure = analyzer.get_primitive_structure() - except Exception: - refined_structure = refined_structure - - refined_structures[k] = refined_structure - - return {"structures": BatchedStructures(refined_structures)} - - @task.pythonjob( outputs=spec.namespace( structures=t.Any, @@ -95,24 +81,18 @@ def _refine_structures_task( ], ) ) -def _inpainting_pipeline_task( +def inpainting_pipeline_task( structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures, config: dict, usempi: bool = False, ): + """Task wrapper for the inpainting pipeline.""" if usempi: return run_mpi_parallel_inpainting_pipeline(structures, config) return run_inpainting_pipeline(structures, config) -_evaluate_inpainting_task = task.pythonjob( - outputs=spec.namespace( - metric_results=t.Any, - ), 
-)(evaluate_inpainting) - - @task.pythonjob( outputs=spec.namespace( structures=t.Any, @@ -122,7 +102,7 @@ def _inpainting_pipeline_task( final_forces=spec.socket(t.Any, required=False), ) ) -def _relaxation_task( +def relaxation_task( structures: t.Union[ dict[str, Structure], BatchedStructuresData, BatchedStructures ], diff --git a/src/xtalpaint/eval.py b/src/xtalpaint/eval.py index 45c859b..a7d0282 100644 --- a/src/xtalpaint/eval.py +++ b/src/xtalpaint/eval.py @@ -3,7 +3,6 @@ from concurrent.futures import ProcessPoolExecutor from functools import partial -import numpy as np import pandas as pd from mattergen.evaluation.utils.utils import compute_rmsd_angstrom from pymatgen.analysis.structure_matcher import StructureMatcher @@ -12,12 +11,8 @@ from xtalpaint.data import BatchedStructures from xtalpaint.utils import _is_batched_structure - - -def _check_for_nan(structure: Structure) -> bool: - """Check if a pymatgen Structure has NaN values in its atomic positions.""" - positions = structure.cart_coords - return np.isnan(positions).any() +from xtalpaint.utils.data_utils import get_structure_keys +from xtalpaint.utils.structure_utils import check_for_nan_positions def _rmsd(strct1, strct2, normalization_element: str | None = None) -> float: @@ -59,7 +54,7 @@ def _comparison_per_key( comparisons = [] comp_func = COMPARISON_METHODS[metric] for sample_idx, sample in inpainted_structures_grouped[key]: - if _check_for_nan(sample): + if check_for_nan_positions(sample): comparison = None else: comparison = comp_func(sample, ref, **kwargs) @@ -68,35 +63,6 @@ def _comparison_per_key( return comparisons -def get_structure_keys( - structures: BatchedStructures | dict[str, Structure], -) -> tuple[list[str], list[str | None]]: - """Get the unique keys of the structures with out sample indices. - - This is used to group structures that are samples of the same - base structure. - - Args: - structures (dict | BatchedStructures): - The structures to get the keys from. 
- - Returns: - set[str]: The unique structure keys. - """ - keys = structures.keys() - structure_keys = [] - sample_indices = [] - for key in keys: - if "_sample_" in key: - key, sample_idx = key.split("_sample_") - else: - sample_idx = None - structure_keys.append(key) - sample_indices.append(sample_idx) - - return structure_keys, sample_indices - - def worker_init(ref_structures, inp_structures_grp): """Initialize worker.""" global matcher, reference_structures, inpainted_structures_grouped diff --git a/src/xtalpaint/utils/data_utils.py b/src/xtalpaint/utils/data_utils.py index 2982f2a..9aeb399 100644 --- a/src/xtalpaint/utils/data_utils.py +++ b/src/xtalpaint/utils/data_utils.py @@ -8,8 +8,44 @@ from mattergen.common.data.collate import collate from mattergen.common.data.dataset import CrystalDataset from mattergen.diffusion.data.batched_data import BatchedData +from pymatgen.core.structure import Structure from torch.utils.data import DataLoader +from xtalpaint.data import BatchedStructures + + +def get_structure_keys( + structures: BatchedStructures | dict[str, Structure], +) -> tuple[list[str], list[str | None]]: + """Get the unique keys of the structures with out sample indices. + + This is used to group structures that are samples of the same + base structure. Example keys are ``mp-1234_sample_0``, + ``mp-1234_sample_1``, etc. This function will return ``mp-1234`` as the + unique key for both of these, and the sample indices as ``0`` and ``1`` + respectively. If a key does not have a ``_sample_`` suffix, it is returned + as-is with a sample index of ``None``. + + Args: + structures (dict | BatchedStructures): + The structures to get the keys from. + + Returns: + set[str]: The unique structure keys. 
+ """ + keys = structures.keys() + structure_keys = [] + sample_indices = [] + for key in keys: + if "_sample_" in key: + key, sample_idx = key.split("_sample_") + else: + sample_idx = None + structure_keys.append(key) + sample_indices.append(sample_idx) + + return structure_keys, sample_indices + def create_dataloader( dataset: CrystalDataset, batch_size: int, fix_cell: bool = True diff --git a/src/xtalpaint/utils/structure_utils.py b/src/xtalpaint/utils/structure_utils.py new file mode 100644 index 0000000..c01150c --- /dev/null +++ b/src/xtalpaint/utils/structure_utils.py @@ -0,0 +1,134 @@ +"""Utility functions for structure processing.""" + +import numpy as np +from pymatgen.analysis.structure_matcher import StructureMatcher +from pymatgen.core.structure import Structure +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +from xtalpaint.data import BatchedStructures +from xtalpaint.utils import _is_batched_structure +from xtalpaint.utils.data_utils import get_structure_keys + + +def check_for_nan_positions(structure: Structure) -> bool: + """Check if a pymatgen Structure has NaN values in its atomic positions.""" + positions = structure.cart_coords + return np.isnan(positions).any() + + +def refine_structures( + structures: "BatchedStructures | dict[str, Structure]", + symprec: float, + primitive: bool = False, +) -> BatchedStructures: + """Refine structures to standard conventional (or primitive) cells. + + Args: + structures: Input structures. + symprec: Symmetry precision passed to SpacegroupAnalyzer. + primitive: If ``True``, return the primitive cell instead of the + conventional cell. + + Returns: + BatchedStructures with refined structures. Structures for which + refinement raises an exception are kept as-is. 
+ """ + if _is_batched_structure(structures): + structures_dict: dict[str, Structure] = structures.get_structures( + strct_type="pymatgen" + ) + else: + structures_dict = dict(structures) + + refined: dict[str, Structure] = {} + for k, s in structures_dict.items(): + analyzer = SpacegroupAnalyzer(s, symprec=symprec) + try: + result = analyzer.get_refined_structure() + except Exception: + result = s + + if primitive: + try: + result = SpacegroupAnalyzer( + result, symprec=symprec + ).get_primitive_structure() + except Exception: + pass + + refined[k] = result + + return BatchedStructures(refined) + + +def filter_unique_structures( + structures: "BatchedStructures | dict[str, Structure]", + symprec: float = 0.1, + ltol: float = 0.2, + stol: float = 0.3, + angle_tol: float = 5.0, +) -> BatchedStructures: + """Filter unique structures (per parent key) and space group. + + Groups samples by their parent structure key (splitting on ``_sample_``), + then by space group number, then applies StructureMatcher within each + sub-group to retain one representative per equivalence class. NaN + structures are skipped. The first encountered structure in each equivalence + class is kept as the representative. + + Args: + structures: Inpainting samples, typically keyed as + ``{base_key}_sample_{idx}``. + symprec: Symmetry precision passed to SpacegroupAnalyzer. + ltol: Fractional length tolerance for StructureMatcher. + stol: Site tolerance for StructureMatcher. + angle_tol: Angle tolerance in degrees for StructureMatcher. + + Returns: + BatchedStructures containing one representative per unique structure. 
+ """ + if _is_batched_structure(structures): + structures_dict: dict[str, Structure] = structures.get_structures( + strct_type="pymatgen" + ) + else: + structures_dict = dict(structures) + + base_keys, _ = get_structure_keys(structures_dict) + + groups: dict[str, list[tuple[str, Structure]]] = {} + for full_key, base_key in zip(structures_dict.keys(), base_keys): + groups.setdefault(base_key, []).append( + (full_key, structures_dict[full_key]) + ) + + structure_matcher = StructureMatcher( + ltol=ltol, stol=stol, angle_tol=angle_tol + ) + unique: dict[str, Structure] = {} + + for members in groups.values(): + sg_groups: dict[int, list[tuple[str, Structure]]] = {} + for full_key, structure in members: + if check_for_nan_positions(structure): + continue + try: + sg_num = SpacegroupAnalyzer( + structure, symprec=symprec + ).get_space_group_number() + except Exception: + sg_num = -1 + sg_groups.setdefault(sg_num, []).append((full_key, structure)) + + for sg_members in sg_groups.values(): + representatives: list[tuple[str, Structure]] = [] + for full_key, structure in sg_members: + if not any( + structure_matcher.fit(structure, rep_strct) + for _, rep_strct in representatives + ): + representatives.append((full_key, structure)) + for rep_key, rep_strct in representatives: + unique[rep_key] = rep_strct + + return BatchedStructures(unique) diff --git a/tests/test_eval.py b/tests/test_eval.py index c04253a..4115a9f 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -8,12 +8,9 @@ from pymatgen.io.ase import AseAtomsAdaptor from xtalpaint.eval import ( - _check_for_nan, evaluate_inpainting, - get_structure_keys, ) - @pytest.fixture def reference_structures(): """Load reference structures from extxyz file.""" @@ -67,57 +64,6 @@ def nan_structure(): coords = [[np.nan, 0.0, 0.0], [0.5, 0.5, 0.5]] return Structure(lattice, species, coords) - -class TestCheckForNan: - """Test the _check_for_nan function.""" - - def test_no_nan(self, simple_structure): - """Test structure 
without NaN values.""" - assert not _check_for_nan(simple_structure) - - def test_with_nan(self, nan_structure): - """Test structure with NaN values.""" - assert _check_for_nan(nan_structure) - - def test_real_structures(self, reference_structures): - """Test that real structures have no NaN values.""" - for key, structure in reference_structures.items(): - assert not _check_for_nan(structure), f"Structure {key} has NaN" - -class TestGetStructureKeys: - """Test the get_structure_keys function.""" - - def test_no_samples(self): - """Test with keys without sample indices.""" - structures = { - "structure_1": None, - "structure_2": None, - } - keys, indices = get_structure_keys(structures) - assert keys == ["structure_1", "structure_2"] - assert indices == [None, None] - - def test_with_samples(self): - """Test with keys containing sample indices.""" - structures = { - "structure_1_sample_0": None, - "structure_1_sample_1": None, - "structure_2_sample_0": None, - } - keys, indices = get_structure_keys(structures) - assert keys == ["structure_1", "structure_1", "structure_2"] - assert indices == ["0", "1", "0"] - - def test_real_structure_keys(self, reference_structures): - """Test with real structure keys from extxyz.""" - keys, indices = get_structure_keys(reference_structures) - # All keys should have sample indices - assert len(keys) == len(reference_structures) - # Check that sample indices are extracted correctly - for idx in indices: - assert idx is not None - - class TestEvaluateInpainting: """Test the evaluate_inpainting function.""" From 5cce2210c8b99a215f8746b0e1112fa6615d3b94 Mon Sep 17 00:00:00 2001 From: t-reents Date: Tue, 26 May 2026 15:39:47 +0200 Subject: [PATCH 2/7] Adding a ``RelaxationWorkGraph`` with optional symmetry refinement and uniqueness analysis This simplifies the several stages of relaxation in the big inpainting workflow. 
Moreover, it makes it easier to enable symmetry refinement and uniqueness analysis at different stages of the workflow. --- src/xtalpaint/aiida/workgraphs/__init__.py | 4 + src/xtalpaint/aiida/workgraphs/relaxation.py | 126 ++++++ tests/test_relaxation_graph.py | 397 +++++++++++++++++++ 3 files changed, 527 insertions(+) create mode 100644 src/xtalpaint/aiida/workgraphs/relaxation.py create mode 100644 tests/test_relaxation_graph.py diff --git a/src/xtalpaint/aiida/workgraphs/__init__.py b/src/xtalpaint/aiida/workgraphs/__init__.py index 2d17522..4a3aed5 100644 --- a/src/xtalpaint/aiida/workgraphs/__init__.py +++ b/src/xtalpaint/aiida/workgraphs/__init__.py @@ -1 +1,5 @@ """Modules defining workgraphs for inpainting tasks.""" + +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph + +__all__ = ("relaxation_graph",) diff --git a/src/xtalpaint/aiida/workgraphs/relaxation.py b/src/xtalpaint/aiida/workgraphs/relaxation.py new file mode 100644 index 0000000..8cb9aa9 --- /dev/null +++ b/src/xtalpaint/aiida/workgraphs/relaxation.py @@ -0,0 +1,126 @@ +"""Relaxation WorkGraph with optional refinement and uniqueness filtering.""" + +import typing as t + +from aiida import orm +from aiida_workgraph import spec, task + +from xtalpaint.aiida.tasks import tasks +from xtalpaint.inpainting.config_schema import ( + AiiDATaskOptions, + RelaxationGraphConfig, +) + + +@task.graph( + outputs=spec.namespace( + structures=t.Any, + final_energies=t.Any, + initial_energies=spec.socket(t.Any, required=False), + initial_forces=spec.socket(t.Any, required=False), + final_forces=spec.socket(t.Any, required=False), + ) +) +def relaxation_graph( + structures: t.Any, + relax_config: RelaxationGraphConfig, + aiida_options: AiiDATaskOptions = None, + code_label: str = None, + command_info: dict = None, + constrained: bool = True, +): + """Relaxation WG with optional symmetry refinement and deduplication. + + Runs ``_relaxation_task`` and then optionally: + + 1. 
``_refine_structures_task`` — symmetry-refine the relaxed structures. + 2. ``_uniqueness_filter_task`` — keep one representative per unique + (space-group, StructureMatcher equivalence class) group. + + The ``structures`` output always points to the last active step, so + downstream tasks see a consistent socket name regardless of which optional + steps are enabled. + + ``relax_config.refine`` and ``relax_config.filter_unique`` are evaluated + at graph build-time (when the WorkGraph is materialised), so they must + resolve to plain Python ``bool`` values; passing AiiDA nodes wired from + another task's output is not supported for these flags. + + Args: + structures: Input structures to relax. + relax_config: Relaxation and post-processing configuration. + ``relax_config.params`` is forwarded to ``relax_structures`` as + ``relax_inputs``. ``relax_config.refine`` and + ``relax_config.filter_unique`` control the optional steps. + aiida_options: AiiDA scheduler/resource options forwarded to all inner + tasks. If ``None``, default options are used (no resource limits, + no MPI). + code_label: AiiDA code label for all inner pythonjob tasks. If + ``None``, aiida-pythonjob locates ``python3`` automatically. + command_info: Passed as ``command_info`` to every inner pythonjob task + (e.g. ``{"filepath_executable": "/path/to/python"}``). Overrides + automatic executable detection when set. + constrained: If ``True`` (default), ``elements_to_relax`` from + ``relax_config.params`` is included in the relax call so that only + those elements are relaxed. Pass ``False`` for full relaxation + of all atoms. + + Returns: + dict with ``structures`` (relaxed, and optionally refined/filtered), + ``final_energies``, and optionally ``initial_energies``, + ``initial_forces``, ``final_forces`` when requested via + ``relax_config.params``. 
+ """ + _aiida = aiida_options or AiiDATaskOptions() + _options = _aiida.model_dump(exclude={"withmpi"}, exclude_none=True) + _code = orm.load_code(code_label) if code_label else None + _metadata = {"options": _options} + _command_info = command_info or {} + + relaxed = tasks.relaxation_task( + structures=structures, + relax_inputs=relax_config.relax_inputs(constrained=constrained), + usempi=_aiida.withmpi, + metadata=_metadata, + code=_code, + command_info=_command_info, + ) + + current_structures = relaxed.structures + + if relax_config.refine: + refined = tasks.refine_structures_task( + structures=current_structures, + refinement_symprec=relax_config.refinement_symprec, + primitive=relax_config.refinement_primitive, + metadata=_metadata, + code=_code, + command_info=_command_info, + ) + current_structures = refined.structures + + if relax_config.filter_unique: + filtered = tasks.uniqueness_filter_task( + structures=current_structures, + symprec=relax_config.uniqueness.symprec, + ltol=relax_config.uniqueness.ltol, + stol=relax_config.uniqueness.stol, + angle_tol=relax_config.uniqueness.angle_tol, + metadata=_metadata, + code=_code, + command_info=_command_info, + ) + current_structures = filtered.unique_structures + + outputs = { + "structures": current_structures, + "final_energies": relaxed.final_energies, + } + if relax_config.params.return_initial_energies: + outputs["initial_energies"] = relaxed.initial_energies + if relax_config.params.return_initial_forces: + outputs["initial_forces"] = relaxed.initial_forces + if relax_config.params.return_final_forces: + outputs["final_forces"] = relaxed.final_forces + + return outputs diff --git a/tests/test_relaxation_graph.py b/tests/test_relaxation_graph.py new file mode 100644 index 0000000..24dfcd7 --- /dev/null +++ b/tests/test_relaxation_graph.py @@ -0,0 +1,397 @@ +"""Tests for relaxation_graph and filter_unique_structures.""" + +import sys + +from aiida import orm +import numpy as np +import pandas as pd +import 
pytest +from pymatgen.core import Lattice +from pymatgen.core.structure import Structure + +from xtalpaint.aiida.data import BatchedStructuresData +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph +from xtalpaint.data import BatchedStructures +from xtalpaint.eval import filter_unique_structures + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bcc_si(): + """BCC silicon — space group Im-3m (229).""" + return Structure( + [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]], + ["Si", "Si"], + [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], + ) + + +@pytest.fixture +def fcc_al(): + """FCC aluminium — space group Fm-3m (225).""" + a = 4.05 + return Structure( + [[a, 0, 0], [0, a, 0], [0, 0, a]], + ["Al", "Al", "Al", "Al"], + [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], + ) + + +@pytest.fixture +def nan_si(bcc_si): + """BCC silicon with a NaN fractional coordinate.""" + s = bcc_si.copy() + s[0].frac_coords[0] = np.nan + return s + + +# --------------------------------------------------------------------------- +# relaxation_graph: graph-structure tests (no AiiDA runtime required) +# --------------------------------------------------------------------------- + + +# graph_inputs / graph_outputs / graph_ctx are always present in every built +# WorkGraph; they are infrastructure nodes, not user tasks. +_BUILTIN_TASKS = {"graph_inputs", "graph_outputs", "graph_ctx"} + + +def _user_tasks(wg) -> set: + return {t.name for t in wg.tasks if t.name not in _BUILTIN_TASKS} + + +def _structures_source(wg) -> str: + """Return the task name that the graph-level ``structures`` output links from.""" + return wg.outputs["structures"]._links[0].from_task.name + + +class TestRelaxationGraphStructure: + """Verify that relaxation_graph wires the correct tasks for each flag combination. 
+ + These tests call ``.build()`` to materialise the inner WorkGraph without + running any AiiDA processes. They check which tasks exist, which output + sockets are declared, and that the ``structures`` output socket is wired to + the last active task in the chain. + """ + + def _build(self, structures, **kwargs): + return relaxation_graph.build( + structures=structures, + relax_inputs={}, + **kwargs, + ) + + def test_base_contains_only_relaxation_task(self, bcc_si): + wg = self._build(BatchedStructures({"s": bcc_si})) + assert _user_tasks(wg) == {"relaxation_task"} + + def test_refine_flag_adds_refinement_task(self, bcc_si): + wg = self._build(BatchedStructures({"s": bcc_si}), refine=True) + assert _user_tasks(wg) == {"relaxation_task", "refine_structures_task"} + + def test_filter_unique_flag_adds_uniqueness_task(self, bcc_si): + wg = self._build(BatchedStructures({"s": bcc_si}), filter_unique=True) + assert _user_tasks(wg) == {"relaxation_task", "uniqueness_filter_task"} + + def test_both_flags_produce_full_chain(self, bcc_si): + wg = self._build( + BatchedStructures({"s": bcc_si}), refine=True, filter_unique=True + ) + assert _user_tasks(wg) == { + "relaxation_task", + "refine_structures_task", + "uniqueness_filter_task", + } + + @pytest.mark.parametrize( + "refine,filter_unique", + [(False, False), (True, False), (False, True), (True, True)], + ) + def test_structures_and_energies_always_in_outputs( + self, bcc_si, refine, filter_unique + ): + wg = self._build( + BatchedStructures({"s": bcc_si}), + refine=refine, + filter_unique=filter_unique, + ) + assert "structures" in wg.outputs + assert "final_energies" in wg.outputs + + def test_optional_force_energy_sockets_declared(self, bcc_si): + """initial_energies / initial_forces / final_forces are declared as + optional sockets even though they are only populated at runtime when + requested via relax_inputs.""" + wg = self._build(BatchedStructures({"s": bcc_si})) + for socket in ("initial_energies", 
"initial_forces", "final_forces"): + assert socket in wg.outputs + + @pytest.mark.parametrize( + "refine,filter_unique,expected_src", + [ + (False, False, "relaxation_task"), + (True, False, "refine_structures_task"), + (False, True, "uniqueness_filter_task"), + (True, True, "uniqueness_filter_task"), + ], + ) + def test_structures_output_linked_to_last_active_task( + self, bcc_si, refine, filter_unique, expected_src + ): + """The graph-level ``structures`` output must be wired to the final + step in the active chain, not hardcoded to the relaxation task.""" + wg = self._build( + BatchedStructures({"s": bcc_si}), + refine=refine, + filter_unique=filter_unique, + ) + assert _structures_source(wg) == expected_src + + +# --------------------------------------------------------------------------- +# filter_unique_structures: functional tests +# --------------------------------------------------------------------------- + + +class TestFilterUniqueStructures: + """Tests for the pure-Python filter_unique_structures function.""" + + def test_identical_samples_collapsed_to_one(self, bcc_si): + """Two identical samples of the same parent become one representative.""" + structures = BatchedStructures( + {"s_sample_0": bcc_si, "s_sample_1": bcc_si} + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 1 + + def test_distinct_compositions_both_kept(self, bcc_si, fcc_al): + """Samples of different compositions can never be StructureMatcher-equal.""" + structures = BatchedStructures( + { + "a_sample_0": bcc_si, + "a_sample_1": fcc_al, + } + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 2 + + def test_different_space_groups_both_kept(self, bcc_si, fcc_al): + """Samples assigned to different space groups are never merged, even + if StructureMatcher would call them equal (it won't across SG groups).""" + structures = BatchedStructures( + {"p_sample_0": bcc_si, "p_sample_1": fcc_al} + ) + result = 
filter_unique_structures(structures) + # bcc_si (SG 229) and fcc_al (SG 225) land in different bins → both kept + assert len(result.keys()) == 2 + + def test_nan_structures_are_excluded(self, bcc_si, nan_si): + """NaN-coordinate structures are dropped entirely.""" + structures = BatchedStructures( + {"s_sample_0": nan_si, "s_sample_1": bcc_si} + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 1 + assert "s_sample_0" not in result.keys() + + def test_all_nan_returns_empty(self, nan_si): + structures = BatchedStructures( + {"s_sample_0": nan_si, "s_sample_1": nan_si} + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 0 + + def test_multiple_parents_filtered_independently(self, bcc_si, fcc_al): + """Each parent key group is deduplicated independently.""" + structures = BatchedStructures( + { + "a_sample_0": bcc_si, + "a_sample_1": bcc_si, # duplicate of a_sample_0 + "b_sample_0": fcc_al, + "b_sample_1": fcc_al, # duplicate of b_sample_0 + } + ) + result = filter_unique_structures(structures) + keys = result.keys() + # One unique per parent + assert len(keys) == 2 + a_keys = [k for k in keys if k.startswith("a_")] + b_keys = [k for k in keys if k.startswith("b_")] + assert len(a_keys) == 1 + assert len(b_keys) == 1 + + def test_accepts_plain_dict(self, bcc_si): + """filter_unique_structures works with a plain dict, not just BatchedStructures.""" + result = filter_unique_structures({"s": bcc_si}) + assert len(result.keys()) == 1 + + def test_returns_batched_structures(self, bcc_si): + result = filter_unique_structures({"s": bcc_si}) + assert isinstance(result, BatchedStructures) + + +# --------------------------------------------------------------------------- +# Fixtures for execution tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def diamond_si(): + """Diamond-cubic silicon (Fd-3m, SG 227) — 8-atom conventional cell.""" + a = 
5.43 + return Structure( + Lattice.cubic(a), + ["Si"] * 8, + [ + [0.000, 0.000, 0.000], + [0.250, 0.250, 0.250], + [0.500, 0.500, 0.000], + [0.750, 0.750, 0.250], + [0.500, 0.000, 0.500], + [0.750, 0.250, 0.750], + [0.000, 0.500, 0.500], + [0.250, 0.750, 0.750], + ], + ) + + +@pytest.fixture(scope="module") +def strained_si(diamond_si): + """Diamond silicon compressed to 98 % of equilibrium volume. + + MatterSim will relax this back to the same diamond-cubic minimum as + ``diamond_si``, which lets us verify that the uniqueness filter correctly + identifies the two relaxed structures as equivalent. + """ + s = diamond_si.copy() + s.apply_strain(-0.02) + return s + + +@pytest.fixture(scope="module") +def fcc_al_conventional(): + """FCC aluminium (Fm-3m, SG 225) — 4-atom conventional cell.""" + a = 4.05 + return Structure( + Lattice.cubic(a), + ["Al"] * 4, + [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [0.0, 0.5, 0.5]], + ) + + +# MatterSim relax settings kept minimal so the tests run quickly on CPU. +_MATTERSIM_RELAX_INPUTS = { + "device": "cpu", + "mlip": "mattersim", + "fmax": 0.1, + "max_n_steps": 100, +} + + +# --------------------------------------------------------------------------- +# TestRelaxationGraphExecution +# --------------------------------------------------------------------------- + + +class TestRelaxationGraphExecution: + """Integration tests: actually execute the WorkGraph with MatterSim. + + These tests require a temporary AiiDA profile (``aiida_profile``) and a + configured localhost computer (``aiida_localhost``). The WorkGraph tasks + are pythonjob tasks that run in a subprocess using the current Python + executable so that all installed packages (including MatterSim) are + available. + + Each test method requests both ``aiida_profile`` and ``aiida_localhost`` + as fixtures to ensure the AiiDA backend and transport layer are ready + before the WorkGraph is submitted. 
+ """ + + def _build_and_run(self, structures: dict, **graph_kwargs): + wg = relaxation_graph.build( + structures=BatchedStructures(structures), + relax_inputs=_MATTERSIM_RELAX_INPUTS, + command_info={"filepath_executable": sys.executable}, + **graph_kwargs, + ) + wg.run() + + assert wg.process.is_finished_ok, "WorkGraph did not finish successfully" + + return wg + + def test_basic_relaxation_returns_structures_and_energies( + self, aiida_profile, aiida_localhost, diamond_si, fcc_al_conventional + ): + """WorkGraph runs to completion and returns one relaxed structure and + one energy row per input.""" + wg = self._build_and_run( + {"si": diamond_si, "al": fcc_al_conventional}, + ) + + assert wg.state == "FINISHED" + + # --- structures output --- + out_node = wg.tasks.relaxation_task.outputs.structures.value + assert isinstance(out_node, BatchedStructuresData) + out_keys = set(out_node.value.keys()) + assert out_keys == {"si", "al"} + + # --- final_energies output --- + energies_node = wg.tasks.relaxation_task.outputs.final_energies.value + energies_df: pd.DataFrame = energies_node.value + assert isinstance(energies_df, pd.DataFrame) + assert set(energies_df.index) == {"si", "al"} + assert (energies_df["final_energy"] < 0).all(), ( + "Energies from MatterSim should be negative (eV)" + ) + + def test_filter_unique_deduplicates_identical_relaxed_copies( + self, + aiida_profile, + aiida_localhost, + diamond_si, + strained_si, + fcc_al_conventional, + ): + """Three Si samples (two identical, one 2 %-strained) and two Al + samples are passed through the full relaxation + uniqueness-filter + pipeline. After relaxation all three Si structures converge to the + same diamond-cubic minimum, so the filter should retain exactly one Si + representative. 
The two identical Al copies should likewise collapse + to one, leaving two unique structures in total.""" + structures = { + # Three Si samples with the same parent key "si" — all should + # relax to diamond-cubic Si and be deduplicated to one. + "si_sample_0": diamond_si, + "si_sample_1": diamond_si, # exact copy + "si_sample_2": strained_si, # 2 % compressed, same basin + # Two Al samples with parent key "al" — both relax to FCC Al. + "al_sample_0": fcc_al_conventional, + "al_sample_1": fcc_al_conventional, # exact copy + } + + wg = self._build_and_run(structures, filter_unique=True) + + assert wg.state == "FINISHED" + + unique_node = ( + wg.tasks.uniqueness_filter_task.outputs.unique_structures.value + ) + assert isinstance(unique_node, BatchedStructuresData) + unique_keys = list(unique_node.value.keys()) + + si_keys = [k for k in unique_keys if k.startswith("si_")] + al_keys = [k for k in unique_keys if k.startswith("al_")] + + assert len(si_keys) == 1, ( + f"Expected 1 unique Si structure after deduplication, " + f"got {len(si_keys)}: {si_keys}" + ) + assert len(al_keys) == 1, ( + f"Expected 1 unique Al structure after deduplication, " + f"got {len(al_keys)}: {al_keys}" + ) From fe8c83a266d0730c21c3d80161e4fe0b2ecd82cd Mon Sep 17 00:00:00 2001 From: t-reents Date: Tue, 26 May 2026 16:00:25 +0200 Subject: [PATCH 3/7] Major refactoring of input schema and inpainting workflow The input schema was significantly adjusted and the inpainting WorkGraph was also updated to adopt the `task.graph` approach. 
--- src/xtalpaint/aiida/workgraphs/inpainting.py | 386 +++++------------ src/xtalpaint/inpainting/config_schema.py | 407 ++++++++++++------ .../inpainting/inpainting_process.py | 63 ++- 3 files changed, 435 insertions(+), 421 deletions(-) diff --git a/src/xtalpaint/aiida/workgraphs/inpainting.py b/src/xtalpaint/aiida/workgraphs/inpainting.py index 2e55625..866ecec 100644 --- a/src/xtalpaint/aiida/workgraphs/inpainting.py +++ b/src/xtalpaint/aiida/workgraphs/inpainting.py @@ -1,296 +1,146 @@ """AiiDA WorkGraph for inpainting of crystal structures.""" -from copy import deepcopy - from aiida import orm -from aiida_workgraph import WorkGraph -from pymatgen.core.structure import Structure - -from xtalpaint.aiida.data import ( - BatchedStructuresData, -) -from xtalpaint.aiida.tasks.tasks import ( - _evaluate_inpainting_task, - _generate_inpainting_candidates_task, - _inpainting_pipeline_task, - _refine_structures_task, - _relaxation_task, +from aiida_workgraph import WorkGraph, task + +from xtalpaint.aiida.tasks import tasks +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph +from xtalpaint.inpainting.config_schema import ( + AiiDAOptions, + RelaxationGraphConfig, + XtalPaintConfig, ) -from xtalpaint.data import BatchedStructures -from xtalpaint.inpainting.config_schema import InpaintingWorkflowConfig -def setup_inpainting_wg( - inputs: InpaintingWorkflowConfig, -) -> WorkGraph: - """Create a WorkGraph for inpainting of crystal structures.""" - possible_relaxation_tasks = { - "inpainted_constrained_relaxation": inputs.relax, - "unrelaxed_inpainted_full_relaxation": inputs.full_relax - and inputs.full_relax_wo_pre_relax, - "pre_relaxed_inpainted_full_relaxation": inputs.full_relax - and inputs.relax, +def _relax_outputs(prefix: str, out, relax: RelaxationGraphConfig) -> dict: + """Build a prefixed output dict from a relaxation_graph result.""" + outputs = { + f"{prefix}.structures": out.structures, + f"{prefix}.final_energies": out.final_energies, } + if 
relax.params.return_initial_energies: + outputs[f"{prefix}.initial_energies"] = out.initial_energies + if relax.params.return_initial_forces: + outputs[f"{prefix}.initial_forces"] = out.initial_forces + if relax.params.return_final_forces: + outputs[f"{prefix}.final_forces"] = out.final_forces + return outputs + - wg = WorkGraph() +@task.graph +def InpaintingWorkGraph(inputs: XtalPaintConfig): + """WorkGraph for inpainting of crystal structures.""" + graph_outputs = {} - if not inputs.is_inpainting_structures and inputs.run_inpainting: - _add_inpainting_candidates_generation(wg, inputs) + _aiida: AiiDAOptions = inputs.aiida or AiiDAOptions() + # --- Generate inpainting candidates --- if inputs.run_inpainting: - _add_inpainting_pipeline(wg, inputs) - inpainted_structures = wg.tasks["inpainting"].outputs["structures"] + cand_opts = _aiida.candidate_generation_options + gen_out = tasks.generate_inpainting_candidates_task( + structures=inputs.structures, + **inputs.candidate_generation.model_dump(), + metadata={ + "call_link_label": "generate_inpainting_candidates", + "options": cand_opts.model_dump(exclude={"withmpi"}), + }, + ) + inpainting_candidates = gen_out.candidates + graph_outputs["inpainting_candidates"] = inpainting_candidates else: - inpainted_structures = inputs.structures + inpainting_candidates = inputs.structures - if inputs.refine_structures: - _add_refinement_task( - wg, - structures=inpainted_structures, - refinement_symprec=inputs.refinement_symprec, - primitive=inputs.refinement_primitive, - inputs=inputs, - task_name="refine_structures", + # --- Inpainting pipeline --- + if inputs.run_inpainting: + inp_opts = _aiida.inpainting_options + code_label = _aiida.get_code_label(_aiida.inpainting_code_label) + inp_out = tasks.inpainting_pipeline_task( + structures=inpainting_candidates, + config=inputs.inpainting.model_dump(exclude_none=True), + usempi=inp_opts.withmpi, + metadata={ + "call_link_label": "inpainting", + "options": 
inp_opts.model_dump(exclude={"withmpi"}), + }, + code=orm.load_code(code_label) if code_label else None, ) - inpainted_structures = wg.tasks["refine_structures"].outputs[ - "structures" - ] - - wg.outputs.inpainted_structures = inpainted_structures - - if inputs.relax or inputs.full_relax: - _add_relaxation_tasks(wg, inpainted_structures, inputs) - - if inputs.evaluate: - relaxation_tasks = { - k: k for k, v in possible_relaxation_tasks.items() if v - } - _add_evaluation_tasks(wg, inputs, relaxation_tasks) - - return wg - - -def _add_inpainting_candidates_generation( - wg: WorkGraph, - inputs: InpaintingWorkflowConfig, -) -> None: - """Add inpainting candidates generation task to the workgraph.""" - wg.add_task( - _generate_inpainting_candidates_task, - structures=inputs.structures, - n_inp=inputs.gen_inpainting_candidates_params.n_inp, - element=inputs.gen_inpainting_candidates_params.element, - num_samples=inputs.gen_inpainting_candidates_params.num_samples, - name="generate_inpainting_candidates", - metadata={ - "options": ( - inputs.gen_inpainting_candidates_options or inputs.options - ) - }, - ) - - wg.outputs.inpainting_candidates = wg.tasks[ - "generate_inpainting_candidates" - ].outputs["candidates"] - - -def _add_refinement_task( - wg: WorkGraph, - structures: BatchedStructures | BatchedStructuresData, - refinement_symprec: float, - inputs: InpaintingWorkflowConfig, - task_name: str = "refine_structures", -) -> None: - """Add structure refinement task to the workgraph.""" - wg.add_task( - _refine_structures_task, - structures=structures, - refinement_symprec=refinement_symprec, - name=task_name, - metadata={ - "options": inputs.options or {}, - }, - ) - - -def _add_inpainting_pipeline( - wg: WorkGraph, - inputs: InpaintingWorkflowConfig, -) -> None: - """Add inpainting pipeline task to the workgraph.""" - inpainting_candidates = ( - wg.tasks["generate_inpainting_candidates"].outputs["candidates"] - if not inputs.is_inpainting_structures and 
inputs.run_inpainting - else inputs.structures - ) - - code_label = inputs.inpainting_code_label or inputs.code_label + inpainted_structures = inp_out.structures - wg.add_task( - _inpainting_pipeline_task, - structures=inpainting_candidates, - config=inputs.inpainting_pipeline_params.model_dump(exclude_none=True), - usempi=( - inputs.inpainting_pipeline_options.get("withmpi", False) - if inputs.inpainting_pipeline_options - else False - ), - name="inpainting", - metadata={ - "options": (inputs.inpainting_pipeline_options or inputs.options), - }, - code=orm.load_code(code_label) if code_label else None, - ) - - if inputs.inpainting_pipeline_params.record_trajectories: - wg.outputs.inpainted_trajectories = wg.tasks["inpainting"].outputs[ - "trajectories" - ] - if "mean_trajectories" in wg.tasks["inpainting"].outputs: - wg.outputs.inpainted_mean_trajectories = wg.tasks[ - "inpainting" - ].outputs["mean_trajectories"] - - -def _add_relaxation_tasks( - wg: WorkGraph, - structures: BatchedStructures | BatchedStructuresData, - inputs: InpaintingWorkflowConfig, -) -> None: - """Add relaxation tasks to the workgraph.""" - code_label = inputs.relax_code_label or inputs.code_label - relax_kwargs = deepcopy(inputs.relax_kwargs.model_dump()) + if inputs.inpainting.record_trajectories: + graph_outputs["inpainted_trajectories"] = inp_out.trajectories + else: + inpainted_structures = inputs.structures - if inputs.relax: - wg = _add_full_relax_task( - wg=wg, - structures=structures, - relax_inputs=relax_kwargs, - task_name="inpainted_constrained_relaxation", - options=inputs.relax_options or inputs.options, - code=orm.load_code(code_label) if code_label else None, - as_graph_outputs=True, + # --- Pre-refinement (before relaxation) --- + if inputs.pre_refinement is not None: + ref_out = tasks.refine_structures_task( + structures=inpainted_structures, + refinement_symprec=inputs.pre_refinement.symprec, + primitive=inputs.pre_refinement.primitive, + metadata={ + "call_link_label": 
"refine_structures", + "options": {}, + }, ) - - if inputs.full_relax: - relax_kwargs.pop("elements_to_relax", None) - if inputs.full_relax_wo_pre_relax: - wg = _add_full_relax_task( - wg=wg, - structures=structures, - relax_inputs=relax_kwargs, - task_name="unrelaxed_inpainted_full_relaxation", - options=inputs.relax_options or inputs.options, - code=orm.load_code(code_label) if code_label else None, - as_graph_outputs=True, + inpainted_structures = ref_out.structures + + graph_outputs["inpainted_structures"] = inpainted_structures + + # --- Relaxation --- + if inputs.relaxation is not None: + relax = inputs.relaxation + relax_opts = _aiida.relax_options + relax_code_label = _aiida.get_code_label(_aiida.relax_code_label) + + cr_out = None + if relax.constrained: + cr_out = relaxation_graph( + structures=inpainted_structures, + relax_config=relax, + aiida_options=relax_opts, + code_label=relax_code_label, + constrained=True, + metadata={ + "call_link_label": "inpainted_constrained_relaxation" + }, ) - - if inputs.relax: - wg = _add_full_relax_task( - wg=wg, - structures=wg.tasks[ - "inpainted_constrained_relaxation" - ].outputs["structures"], - relax_inputs=relax_kwargs, - task_name="pre_relaxed_inpainted_full_relaxation", - options=inputs.relax_options or inputs.options, - code=orm.load_code(code_label) if code_label else None, - as_graph_outputs=True, + graph_outputs |= _relax_outputs( + "inpainted_constrained_relaxation", cr_out, relax ) - -def _add_evaluation_tasks( - wg: WorkGraph, - inputs: InpaintingWorkflowConfig, - relaxation_tasks: dict[str, str], -) -> None: - """Add evaluation tasks to the workgraph.""" - code_label = inputs.evaluate_params.code_label or inputs.code_label - - evaluation_results = {} - metrics = ( - inputs.evaluate_params.metrics - if isinstance(inputs.evaluate_params.metrics, list) - else [inputs.evaluate_params.metrics] - ) - tasks_to_evaluate = {} - if inputs.run_inpainting: - tasks_to_evaluate["inpainting"] = "inpainting" - if 
inputs.refine_structures: - tasks_to_evaluate["inpainting"] = "refine_structures" - - tasks_to_evaluate.update(relaxation_tasks) - - for metric in metrics: - for label, task_name in tasks_to_evaluate.items(): - wg.add_task( - _evaluate_inpainting_task, - inpainted_structures=wg.tasks[task_name].outputs["structures"], - reference_structures=inputs.structures, - metric=metric, - max_workers=inputs.evaluate_params.max_workers, - name=f"evaluate_inpainting_{metric}_{label}", + if relax.full_direct: + ufr_out = relaxation_graph( + structures=inpainted_structures, + relax_config=relax, + aiida_options=relax_opts, + code_label=relax_code_label, + constrained=False, metadata={ - "options": inputs.options or {}, + "call_link_label": "unrelaxed_inpainted_full_relaxation" }, - code=orm.load_code(code_label) if code_label else None, ) - evaluation_results.setdefault(label, {}).update( - { - f"{metric}": wg.tasks[ - f"evaluate_inpainting_{metric}_{label}" - ].outputs["metric_results"], - } + graph_outputs |= _relax_outputs( + "unrelaxed_inpainted_full_relaxation", ufr_out, relax ) - wg.outputs.evaluation_results = evaluation_results - -def _add_full_relax_task( - wg: WorkGraph, - structures: ( - dict[str, Structure] | BatchedStructuresData | BatchedStructures - ), - relax_inputs: dict, - task_name: str = "full_relaxation", - options: dict = None, - code: orm.Code | None = None, - as_graph_outputs: bool = False, -): - """Add a full relaxation task to the workgraph.""" - wg.add_task( - _relaxation_task, - structures=structures, - relax_inputs=relax_inputs, - usempi=options.get("withmpi", False), - name=task_name, - metadata={ - "options": options or {}, - }, - code=code, - ) - if as_graph_outputs: - outputs = { - f"{task_name}.structures": wg.tasks[task_name].outputs[ - "structures" - ], - f"{task_name}.final_energies": wg.tasks[task_name].outputs[ - "final_energies" - ], - } + if relax.full and cr_out is not None: + pfr_out = relaxation_graph( + structures=cr_out.structures, + 
relax_config=relax, + aiida_options=relax_opts, + code_label=relax_code_label, + constrained=False, + metadata={ + "call_link_label": "pre_relaxed_inpainted_full_relaxation" + }, + ) + graph_outputs |= _relax_outputs( + "pre_relaxed_inpainted_full_relaxation", pfr_out, relax + ) - if relax_inputs.get("return_initial_energies", False): - outputs[f"{task_name}.initial_energies"] = wg.tasks[ - task_name - ].outputs["initial_energies"] - if relax_inputs.get("return_initial_forces", False): - outputs[f"{task_name}.initial_forces"] = wg.tasks[ - task_name - ].outputs["initial_forces"] - if relax_inputs.get("return_final_forces", False): - outputs[f"{task_name}.final_forces"] = wg.tasks[task_name].outputs[ - "final_forces" - ] + return graph_outputs - wg.outputs = outputs - return wg +def setup_inpainting_wg(inputs: XtalPaintConfig) -> WorkGraph: + """Create a WorkGraph for inpainting of crystal structures.""" + return InpaintingWorkGraph.build(inputs=inputs) diff --git a/src/xtalpaint/inpainting/config_schema.py b/src/xtalpaint/inpainting/config_schema.py index 1c8ec52..620110c 100644 --- a/src/xtalpaint/inpainting/config_schema.py +++ b/src/xtalpaint/inpainting/config_schema.py @@ -24,61 +24,49 @@ def _is_valid_structure_type(obj) -> bool: return False -def _is_inpainting_structure(obj) -> bool: - """Check if object is an InpaintingStructureData (requires AiiDA).""" - if is_aiida_installed(): - from xtalpaint.aiida.data import InpaintingStructureData +# --------------------------------------------------------------------------- +# Stage configs +# --------------------------------------------------------------------------- - return isinstance(obj, InpaintingStructureData) - return False +class CandidateGenerationConfig(BaseModel): + """Configuration for generating inpainting candidates.""" -class RelaxParameters(BaseModel): - """Configuration for the relaxation stage.""" + n_inp: int | dict[str, int] + element: str | dict[str, str] + num_samples: int = 1 - load_path: str 
| None = None - fmax: float = 0.05 - elements_to_relax: Optional[list[str]] = Field( - default=None, - description="List of elements to relax during optimization.", - ) - max_natoms_per_batch: int = 512 - max_n_steps: int = 500 - device: str = "cpu" - filter: Optional[str] = None - optimizer: str - mlip: str - return_initial_energies: bool = False - return_initial_forces: bool = False - return_final_forces: bool = False +class InpaintingConfig(BaseModel): + """Configuration for the diffusion inpainting stage. -class InpaintingModelParams(BaseModel): - """Diffusion sampling parameters for the inpainting model.""" + Sampling parameters are kept flat (no nested ModelParams sub-object) + so they can be specified in a single block and passed directly as a + dict to the inpainting pipeline. + """ + # Model — exactly one of these must be provided + pretrained_name: Optional[str] = None + model_path: Optional[str] = None + + # Diffusion sampling + predictor_corrector: str N_steps: int coordinates_snr: float n_corrector_steps: int batch_size: int - n_resample_steps: Optional[int] = None - jump_length: Optional[int] = None - - -class InpaintingPipelineParams(BaseModel): - """Settings for constructing an inpainting pipeline.""" - - predictor_corrector: str fix_cell: bool = True - inpainting_model_params: InpaintingModelParams - pretrained_name: Optional[str] = None - model_path: Optional[str] = None - record_trajectories: Optional[bool] = False + record_trajectories: bool = False sampling_config_path: Optional[str] = None + # Repaint-specific (required when predictor_corrector contains 'repaint') + n_resample_steps: Optional[int] = None + jump_length: Optional[int] = None + @field_validator("predictor_corrector") @classmethod def validate_predictor_corrector(cls, v): - """Validator to ensure 'predictor_corrector' is a supported key.""" + """Validate that predictor_corrector is one of the allowed options.""" from xtalpaint.inpainting.inpainting_process import ( 
GUIDED_PREDICTOR_CORRECTOR_MAPPING, ) @@ -93,11 +81,7 @@ def validate_predictor_corrector(cls, v): @model_validator(mode="after") @classmethod def check_pretrained_model_exclusive(cls, cfg): - """Validate model specification. - - Ensure that either 'pretrained_name' or 'model_path' is provided, - but not both. - """ + """Validate model specification.""" if ( cfg.pretrained_name is not None and cfg.model_path is not None ) or (cfg.pretrained_name is None and cfg.model_path is None): @@ -110,83 +94,261 @@ def check_pretrained_model_exclusive(cls, cfg): @model_validator(mode="after") @classmethod def check_repaint_requires_resample_and_jump(cls, cfg): - """Validate 'n_resample_steps' and 'jump_length'. - - If 'predictor_corrector' contains 'repaint', both parameters must be - set in 'inpainting_model_params'. - """ + """Validate RePaint-specific parameters.""" if "repaint" in cfg.predictor_corrector.lower(): - params = cfg.inpainting_model_params - if params.n_resample_steps is None or params.jump_length is None: + if cfg.n_resample_steps is None or cfg.jump_length is None: raise ValueError( "When 'predictor_corrector' contains 'repaint', " - "inpainting_model_params must set both 'n_resample_steps' " - "and 'jump_length'." + "both 'n_resample_steps' and 'jump_length' must be set." 
) return cfg -class GenInpaintingCandidatesParams(BaseModel): - """Configuration for generating inpainting candidates.""" +class RefinementConfig(BaseModel): + """Symmetry refinement stage.""" - n_inp: int | dict[str, int] - element: str | dict[str, str] - num_samples: int = 1 + symprec: float = 0.01 + primitive: bool = False -class EvalParameters(BaseModel): - """Evaluation parameters for generated structures.""" +class UniquenessConfig(BaseModel): + """Parameters for post-relaxation uniqueness/deduplication filtering.""" - max_workers: int = 6 - chunksize: int = 50 - metrics: str | list[str] = "match" - code_label: Optional[str] = None + symprec: float = 0.01 + ltol: float = 0.2 + stol: float = 0.3 + angle_tol: float = 5.0 -class InpaintingWorkflowConfig(BaseModel): - """Top-level configuration for a XtalPaint inpainting workflow. +class RelaxationParams(BaseModel): + """Core relaxation parameters forwarded to ``relax_structures()``. - This config can be used for both AiiDA-based workflows (WorkGraphs) - and regular Python-based workflows. + These are the settings that control *how* a single relaxation is run + (MLIP, optimiser, convergence criteria, etc.). They are kept separate + from the inpainting-workflow-level controls in ``RelaxationConfig``. 
""" - structures: BatchedStructures | dict[str, Structure] - run_inpainting: bool = True - inpainting_pipeline_params: InpaintingPipelineParams - gen_inpainting_candidates_params: Optional[ - GenInpaintingCandidatesParams - ] = None - code_label: Optional[str] = None - relax_code_label: Optional[str] = None - inpainting_code_label: Optional[str] = None - relax: Optional[bool] = False - relax_kwargs: Optional[RelaxParameters] = {} - full_relax: Optional[bool] = False - full_relax_wo_pre_relax: Optional[bool] = False - options: Optional[dict] = {} - relax_options: Optional[dict] = {} - gen_inpainting_candidates_options: Optional[dict] = {} - inpainting_pipeline_options: Optional[dict] = {} - evaluate: Optional[bool] = False - evaluate_params: Optional[EvalParameters] = None - refine_structures: bool = False - refine_structures_after_relax: bool = False + mlip: str + optimizer: str + fmax: float = 0.05 + max_n_steps: int = 500 + max_natoms_per_batch: int = 512 + device: str = "cpu" + filter: Optional[str] = None + load_path: Optional[str] = None + elements_to_relax: Optional[list[str]] = None + return_initial_energies: bool = False + return_initial_forces: bool = False + return_final_forces: bool = False + + +class RelaxationGraphConfig(BaseModel): + """Configuration for a single ``relaxation_graph`` call. + + Bundles the core relaxation parameters with the optional post-relaxation + processing steps (symmetry refinement and uniqueness filtering) that + ``relaxation_graph`` can apply after each pass. + + This class is the direct input type for ``relaxation_graph``. + ``RelaxationConfig`` extends it with multi-pass orchestration flags for + use inside the inpainting WorkGraph. 
+ """ + + # Core relaxation parameters (forwarded as relax_inputs + # to relax_structures) + params: RelaxationParams + + # Post-relaxation steps + refine: bool = False refinement_symprec: float = 0.01 refinement_primitive: bool = False + filter_unique: bool = False + uniqueness: UniquenessConfig = Field(default_factory=UniquenessConfig) + + def relax_inputs(self, constrained: bool = True) -> dict: + """Build the ``relax_inputs`` kwargs dict for ``relax_structures()``. + + Args: + constrained: If True, include ``elements_to_relax`` so that only + those elements are relaxed. Pass False for full-relax passes + where all atoms move freely. + """ + if constrained: + return self.params.model_dump() + return self.params.model_dump(exclude={"elements_to_relax"}) + + +class RelaxationConfig(RelaxationGraphConfig): + """Configuration for the relaxation stage in the inpainting workflow. + + Extends ``RelaxationGraphConfig`` with multi-pass orchestration flags + that are specific to the inpainting WorkGraph. The three passes share + the same ``params`` and post-relaxation settings. + + Pass names and their semantics + -------------------------------- + constrained + Relax only the atoms listed in ``params.elements_to_relax``. Requires + ``params.elements_to_relax`` to be set. Labelled + ``inpainted_constrained_relaxation`` in the WorkGraph. + full + Run a full (all-atom) relaxation on the output of the constrained pass. + Requires ``constrained=True``. Labelled + ``pre_relaxed_inpainted_full_relaxation``. + full_direct + Run a full relaxation directly on the inpainted structures, bypassing + the constrained pre-relax step (useful for comparison). Labelled + ``unrelaxed_inpainted_full_relaxation``. 
+ """ + + # Which relaxation passes to run (inpainting-WG-specific) + constrained: bool = True + full: bool = False + full_direct: bool = False + + @model_validator(mode="after") + @classmethod + def validate_passes(cls, cfg): + """Validate relaxation modes.""" + if not any([cfg.constrained, cfg.full, cfg.full_direct]): + raise ValueError( + "At least one of 'constrained', 'full', or 'full_direct' " + "must be True." + ) + if cfg.constrained and cfg.params.elements_to_relax is None: + raise ValueError( + "'params.elements_to_relax' must be set when " + "'constrained=True'." + ) + if cfg.full and not cfg.constrained: + raise ValueError( + "'full=True' requires 'constrained=True': the full-relax pass " + "runs on the output of the constrained pass." + ) + return cfg + + +# --------------------------------------------------------------------------- +# AiiDA-specific options (ignored outside AiiDA execution) +# --------------------------------------------------------------------------- + + +class AiiDATaskOptions(BaseModel): + """AiiDA scheduler and resource options for a single task. + + Replaces the raw ``Optional[dict]`` options fields. ``withmpi`` + controls MPI-parallel execution and is kept here (infrastructure concern) + rather than in the pipeline config. + """ + + resources: dict = Field(default_factory=dict) + max_wallclock_seconds: Optional[int] = None + queue_name: Optional[str] = None + withmpi: bool = False + + +class AiiDAOptions(BaseModel): + """AiiDA-specific settings: code labels and per-task scheduler options. + + Place this in ``XtalPaintConfig.aiida``; it is ignored entirely in + non-AiiDA (plain Python) execution. 
+ """ + + default_code_label: Optional[str] = None + inpainting_code_label: Optional[str] = None + relax_code_label: Optional[str] = None + candidate_generation_code_label: Optional[str] = None + + inpainting_options: AiiDATaskOptions = Field( + default_factory=AiiDATaskOptions + ) + relax_options: AiiDATaskOptions = Field(default_factory=AiiDATaskOptions) + candidate_generation_options: AiiDATaskOptions = Field( + default_factory=AiiDATaskOptions + ) + + def get_code_label(self, specific: Optional[str] = None) -> Optional[str]: + """Return *specific* code label, falling back to the default.""" + return specific or self.default_code_label + + +# --------------------------------------------------------------------------- +# Top-level config +# --------------------------------------------------------------------------- + + +class XtalPaintConfig(BaseModel): + """Complete configuration for an XtalPaint inpainting workflow. + + Works for both AiiDA-based (WorkGraph) and plain-Python execution. + AiiDA-specific settings live in the optional ``aiida`` block and are + ignored in non-AiiDA runs. + + Pipeline stages are controlled by presence/absence of their config + objects — no boolean flags required: + + * ``candidate_generation`` — omit if structures are already + ``InpaintingStructureData`` objects. + * ``pre_refinement`` — symmetry-refine structures before relaxation; + omit to skip. + * ``relaxation`` — geometry optimisation; omit to skip. 
+ + Example (minimal):: + + XtalPaintConfig( + structures={"si_001": structure}, + inpainting=InpaintingConfig( + pretrained_name="mattergen_base", + predictor_corrector="baseline", + N_steps=5, coordinates_snr=0.2, + n_corrector_steps=1, batch_size=1000, + ), + ) + + Example (with relaxation + deduplication on AiiDA):: + + XtalPaintConfig( + structures=..., + candidate_generation=CandidateGenerationConfig( + n_inp={"H": 2}, element="H" + ), + inpainting=InpaintingConfig(...), + pre_refinement=RefinementConfig(symprec=0.01), + relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + elements_to_relax=["H"], + fmax=0.01, + ), + full=True, + filter_unique=True, + ), + aiida=AiiDAOptions( + default_code_label="xtalpaint@localhost", + relax_code_label="relax@hpc", + relax_options=AiiDATaskOptions( + resources={"num_machines": 1}, + withmpi=True, + ), + ), + ) + """ + + structures: BatchedStructures | dict[str, Structure] + run_inpainting: bool = True + candidate_generation: Optional[CandidateGenerationConfig] = None + pre_refinement: Optional[RefinementConfig] = None + inpainting: InpaintingConfig + relaxation: Optional[RelaxationConfig] = None + aiida: Optional[AiiDAOptions] = None model_config = ConfigDict(arbitrary_types_allowed=True) @field_validator("structures") @classmethod def validate_structures(cls, v): - """Validate input structures. - - Ensure 'structures' is a dictionary with string keys and values of - uniform, supported types. - - Raises: - TypeError: If the structure mapping is not valid. 
- """ + """Validate `structures` type.""" structures = v if _is_batched_structure(v): structures = v.get_structures(strct_type="pymatgen") @@ -201,7 +363,6 @@ def validate_structures(cls, v): "All values in the dictionary must be of type StructureData, " "Structure, ase.Atoms, or InpaintingStructureData" ) - types = {type(s) for s in structures.values()} if len(types) > 1: raise TypeError( @@ -209,48 +370,16 @@ def validate_structures(cls, v): ) return v - @model_validator(mode="after") - @classmethod - def check_n_inp_for_structures(cls, cfg): - """Validate inputs for inpainting candidates. - Ensure that 'gen_inpainting_candidates_params' is provided when - structures are not already inpainting structure instances. - """ - values = ( - list(cfg.structures.values()) - if isinstance(cfg.structures, dict) - else cfg.structures.get_structures(strct_type="pymatgen") - ) - if not all( - _is_inpainting_structure(s) or isinstance(s, Structure) - for s in values - ): - if cfg.gen_inpainting_candidates_params is None: - raise ValueError( - "If structures are not InpaintingStructure objects, " - "gen_inpainting_candidates_params must be provided." - ) - return cfg +# --------------------------------------------------------------------------- +# Evaluation parameters (standalone — not part of the inpainting workflow) +# --------------------------------------------------------------------------- - @model_validator(mode="after") - @classmethod - def check_evaluate_inpainting_structures(cls, cfg): - """Check if structures are already InpaintingStructure objects.""" - if cfg.evaluate and cfg.is_inpainting_structures: - raise ValueError( - "If 'evaluate' is True, structures must not be " - "InpaintingStructure objects. We need the original structures " - "to compare against inpainted structures." 
- ) - return cfg - @property - def is_inpainting_structures(self) -> bool: - """Check if structures are already InpaintingStructure objects.""" - structures = ( - self.structures.values() - if isinstance(self.structures, dict) - else self.structures.get_structures(strct_type="pymatgen") - ) - return all(_is_inpainting_structure(s) for s in structures) +class EvalParameters(BaseModel): + """Evaluation parameters for generated structures.""" + + max_workers: int = 6 + chunksize: int = 50 + metrics: str | list[str] = "match" + code_label: Optional[str] = None diff --git a/src/xtalpaint/inpainting/inpainting_process.py b/src/xtalpaint/inpainting/inpainting_process.py index d3401c7..8fed493 100644 --- a/src/xtalpaint/inpainting/inpainting_process.py +++ b/src/xtalpaint/inpainting/inpainting_process.py @@ -18,7 +18,7 @@ from xtalpaint.generate_inpainting import ( generate_reconstructed_structures, ) -from xtalpaint.inpainting.config_schema import InpaintingPipelineParams +from xtalpaint.inpainting.config_schema import InpaintingConfig from xtalpaint.utils.data_utils import create_dataloader XTALPAINT_BASE = "xtalpaint.predictor_corrector" @@ -173,30 +173,51 @@ def _get_overrides( def _run_inpainting( predictor_corrector: str, structures_dl: DataLoader, - inpainting_model_params: dict[str, Any], + N_steps: int, + coordinates_snr: float, + n_corrector_steps: int, + batch_size: int, fix_cell: bool = True, record_trajectories: bool = False, pretrained_name: str | None = None, model_path: str | None = None, sampling_config_path: str | None = None, + n_resample_steps: int | None = None, + jump_length: int | None = None, ) -> tuple[list[Structure], list, list | None]: """Run the inpainting process using MatterGen. Args: predictor_corrector: Type of predictor-corrector to use. structures_dl: DataLoader containing structures to inpaint. - inpainting_model_params: Parameters for the inpainting model. + N_steps: Number of diffusion steps. 
+ coordinates_snr: Signal-to-noise ratio for coordinate corrector. + n_corrector_steps: Number of corrector steps per diffusion step. + batch_size: Batch size for the DataLoader. fix_cell: Whether to fix the cell during sampling. record_trajectories: Whether to record trajectories. pretrained_name: Name of pretrained model, if any. model_path: Path to model checkpoint. - sampling_config_path: Path to the sampling config directory - for mattergen + sampling_config_path: Path to the sampling config directory of + mattergen. + n_resample_steps: Number of resampling steps (repaint variants only). + jump_length: Jump length for resampling (repaint variants only). Returns: Tuple of (inpainted_structures, trajectories, mean_trajectories). mean_trajectories is None if not recorded. """ + inpainting_model_params: dict[str, Any] = { + "N_steps": N_steps, + "coordinates_snr": coordinates_snr, + "n_corrector_steps": n_corrector_steps, + "batch_size": batch_size, + } + if n_resample_steps is not None: + inpainting_model_params["n_resample_steps"] = n_resample_steps + if jump_length is not None: + inpainting_model_params["jump_length"] = jump_length + sampling_config_overrides, config_overrides = _get_overrides( inpainting_model_params, predictor_corrector, fix_cell, pretrained_name ) @@ -305,13 +326,15 @@ def _extract_outputs( def run_inpainting_pipeline( structures: dict[str, Structure], - config: InpaintingPipelineParams | dict[str, Any], + config: InpaintingConfig | dict[str, Any], ) -> dict[str, Any]: """Run the inpainting experiment using MatterGen. Args: structures: Input structures for inpainting. - config: Configuration for the inpainting process. + config: Configuration for the inpainting process. Pass an + ``InpaintingConfig`` instance or an equivalent plain dict + (e.g. from ``InpaintingConfig.model_dump(exclude_none=True)``). Returns: Dictionary containing inpainted structures, trajectories, and scores. 
@@ -321,13 +344,19 @@ def run_inpainting_pipeline( labels, structures = map(list, zip(*structures.items())) + cfg = ( + config + if isinstance(config, dict) + else config.model_dump(exclude_none=True) + ) + prepared_structures = _prepare_structures( structures, - batch_size=config["inpainting_model_params"].get("batch_size", 64), + batch_size=cfg["batch_size"], ) inpainted_structures, trajectories, mean_trajectories = _run_inpainting( - structures_dl=prepared_structures, **config + structures_dl=prepared_structures, **cfg ) return _extract_outputs( @@ -335,13 +364,13 @@ def run_inpainting_pipeline( trajectories, mean_trajectories, labels, - config["record_trajectories"], + cfg["record_trajectories"], ) def run_mpi_parallel_inpainting_pipeline( structures: dict[str, Structure], - config: InpaintingPipelineParams | dict[str, Any], + config: InpaintingConfig | dict[str, Any], ) -> dict[str, Any] | None: """Run the inpainting experiment using MatterGen with MPI parallelization. @@ -360,6 +389,12 @@ def run_mpi_parallel_inpainting_pipeline( labels, structures = map(list, zip(*structures.items())) + cfg = ( + config + if isinstance(config, dict) + else config.model_dump(exclude_none=True) + ) + comm = mpi4py.MPI.COMM_WORLD rank = comm.rank nranks = comm.size @@ -385,10 +420,10 @@ def run_mpi_parallel_inpainting_pipeline( prepared_structures = _prepare_structures( local_structures, - batch_size=config["inpainting_model_params"].get("batch_size", 64), + batch_size=cfg["batch_size"], ) - rank_results = _run_inpainting(structures_dl=prepared_structures, **config) + rank_results = _run_inpainting(structures_dl=prepared_structures, **cfg) # Gather results on rank 0 all_results = comm.gather(rank_results, root=0) @@ -412,7 +447,7 @@ def run_mpi_parallel_inpainting_pipeline( all_trajectories, all_mean_trajectories if all_mean_trajectories else None, labels, - config["record_trajectories"], + cfg["record_trajectories"], ) return None From e5a1411d50763cbd96757661b0d0e61d7a561770 
Mon Sep 17 00:00:00 2001 From: t-reents Date: Tue, 26 May 2026 16:13:00 +0200 Subject: [PATCH 4/7] Add input schema description to docs --- docs/configuration.md | 480 ++++++++++++++++++++++++++++++++++++++++++ docs/index.md | 10 +- mkdocs.yml | 1 + 3 files changed, 486 insertions(+), 5 deletions(-) create mode 100644 docs/configuration.md diff --git a/docs/configuration.md b/docs/configuration.md new file mode 100644 index 0000000..2009593 --- /dev/null +++ b/docs/configuration.md @@ -0,0 +1,480 @@ +# Workflow Configuration + +XtalPaint uses a single configuration object — `XtalPaintConfig` — to drive both plain-Python and AiiDA-based workflows. This page explains how to build that config, what each field does, and how the same config object is shared between the two execution modes. + +--- + +## Design principles + +- **Presence = enabled, `None` = skip.** Each pipeline stage is controlled by its own typed config object. If the field is `None` the stage is omitted — no boolean flags needed. +- **AiiDA options are isolated.** Everything in the optional `aiida` block (code labels, scheduler resources) is invisible to plain-Python execution. +- **Flat, validated inputs.** Pydantic validates every field at construction time so mistakes surface immediately rather than at run time. 
+ +--- + +## Top-level structure + +```python +from xtalpaint.inpainting.config_schema import ( + XtalPaintConfig, + CandidateGenerationConfig, + InpaintingConfig, + RefinementConfig, + RelaxationConfig, + RelaxationParams, + UniquenessConfig, + AiiDAOptions, + AiiDATaskOptions, +) + +config = XtalPaintConfig( + structures=..., # required — dict[str, Structure] or BatchedStructures + run_inpainting=True, # set False to skip the diffusion step + candidate_generation=..., # CandidateGenerationConfig | None + pre_refinement=..., # RefinementConfig | None + inpainting=..., # InpaintingConfig (always required) + relaxation=..., # RelaxationConfig | None + aiida=..., # AiiDAOptions | None (ignored outside AiiDA) +) +``` + +The pipeline runs in this order when a stage is enabled: + +``` +candidate_generation → inpainting → pre_refinement → relaxation + ├─ constrained pass + ├─ full pass (on constrained output) + └─ full_direct pass (on inpainted directly) + each pass: → [refine] → [filter_unique] +``` + +--- + +## Stage reference + +### Input structures + +```python +from pymatgen.core import Structure + +config = XtalPaintConfig( + structures={"host_001": Structure(...), "host_002": Structure(...)}, + inpainting=..., +) +``` + +`structures` accepts: + +- `dict[str, Structure]` — plain pymatgen structures +- `BatchedStructures` — XtalPaint's batched wrapper +- AiiDA `StructureData` / `InpaintingStructureData` — when running inside AiiDA + +--- + +### Candidate generation + +Required when the input structures are plain `Structure` objects (not yet marked as inpainting targets). Omit this block if your structures are already `InpaintingStructureData` instances. 
+ +```python +candidate_generation=CandidateGenerationConfig( + n_inp=2, # int or dict[str, int] — number of sites to inpaint + element="H", # element to place; dict[str, str] for per-structure control + num_samples=1, # how many candidate sets to generate +) +``` + +For per-structure control over the number of sites and element: + +```python +candidate_generation=CandidateGenerationConfig( + n_inp={"host_001": 2, "host_002": 4}, + element={"host_001": "H", "host_002": "Li"}, +) +``` + +--- + +### Inpainting + +The core diffusion stage. All sampling parameters live in one flat block. + +```python +inpainting=InpaintingConfig( + # Model — provide exactly one of these: + pretrained_name="mattergen_base", # use a bundled pretrained checkpoint + # model_path="/path/to/checkpoint", # or point to a local file + + # Sampling + predictor_corrector="baseline", # see supported keys below + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + + # Optional + fix_cell=True, # keep unit cell fixed during sampling + record_trajectories=False, + sampling_config_path=None, # override MatterGen sampling config dir +) +``` + +**Supported `predictor_corrector` values:** + +| Key | Description | +|---|---| +| `baseline` | Standard guided predictor-corrector | +| `baseline-with-noise` | Custom variant with additional noise | +| `baseline-store-scores` | Records score function outputs | +| `repaint-v1` | RePaint resampling (legacy) | +| `repaint-v2` | RePaint resampling (v2) | +| `TD` | Time-dependent (TD-Paint) variant | + +!!! note "Repaint variants" + When using `repaint-v1` or `repaint-v2`, you must also set `n_resample_steps` and `jump_length`: + + ```python + inpainting=InpaintingConfig( + predictor_corrector="repaint-v2", + n_resample_steps=10, + jump_length=5, + # ... other fields ... + ) + ``` + +--- + +### Pre-refinement + +Optional symmetry refinement applied *after* inpainting and *before* relaxation. 
+ +```python +pre_refinement=RefinementConfig( + symprec=0.01, # symmetry precision for SpacegroupAnalyzer + primitive=False, # if True, convert to primitive cell +) +``` + +Omit `pre_refinement` (or set it to `None`) to skip this step. + +--- + +### Relaxation + +The relaxation stage is split into two distinct layers: + +- **`RelaxationParams`** — the inputs forwarded directly to `relax_structures()` (MLIP, optimiser, convergence) +- **`RelaxationConfig`** — workflow-level controls: which passes to run and post-relaxation processing + +#### Relaxation passes + +Three passes can be run independently or in combination: + +| Flag | Behaviour | WorkGraph label | +|---|---|---| +| `constrained=True` *(default)* | Relax only `elements_to_relax` | `inpainted_constrained_relaxation` | +| `full=True` | Full relax on the *constrained* output | `pre_relaxed_inpainted_full_relaxation` | +| `full_direct=True` | Full relax directly on inpainted structures | `unrelaxed_inpainted_full_relaxation` | + +`full` and `full_direct` together give a direct comparison between relaxing from the raw inpainted geometry versus relaxing from an already-constrained geometry. + +#### Post-relaxation steps + +`refine` and `filter_unique` run *after each active pass*, in order: relax → refine → deduplicate. + +```python +relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + fmax=0.05, + max_n_steps=500, + elements_to_relax=["H"], # required when constrained=True + return_initial_energies=False, + return_final_forces=False, + ), + # Which passes to run: + constrained=True, # relax only H atoms + full=True, # then do a full relax on that output + full_direct=False, # skip direct full relax + + # Post-relaxation processing (applied to each pass): + refine=True, + refinement_symprec=0.01, + refinement_primitive=False, + filter_unique=True, + uniqueness=UniquenessConfig( + symprec=0.01, + ltol=0.2, + stol=0.3, + angle_tol=5.0, + ), +) +``` + +!!! 
warning "Constraints for `constrained`" + `constrained=True` requires `params.elements_to_relax` to be set. + `full=True` requires `constrained=True` (the full-relax pass operates on the constrained output). + +--- + +## Running without AiiDA + +Without AiiDA, pass the `inpainting` config directly to the pipeline functions. The `aiida` block is simply not set. + +```python +from xtalpaint.inpainting.config_schema import XtalPaintConfig, CandidateGenerationConfig, InpaintingConfig +from xtalpaint.inpainting.inpainting_process import run_inpainting_pipeline +from xtalpaint.utils.relaxation_utils import relax_structures + +config = XtalPaintConfig( + structures={"host_001": structure}, + candidate_generation=CandidateGenerationConfig(n_inp=2, element="H"), + inpainting=InpaintingConfig( + model_path="/path/to/checkpoint.ckpt", + predictor_corrector="baseline", + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + sampling_config_path="/path/to/sampling_conf", + ), +) + +# Run inpainting +results = run_inpainting_pipeline( + structures=config.structures, + config=config.inpainting, # pass InpaintingConfig directly +) +inpainted = results["structures"] + +# Optional relaxation (using config.relaxation.relax_inputs()) +if config.relaxation is not None: + relaxed = relax_structures( + inpainted, + **config.relaxation.relax_inputs(constrained=True), + ) +``` + +!!! tip + `InpaintingConfig.model_dump(exclude_none=True)` produces a plain dict that the pipeline functions also accept, which is convenient when serialising configs to JSON/YAML. + +--- + +## Running with AiiDA + +Add the `aiida` block to the same `XtalPaintConfig`. Everything else stays identical — the pipeline stages, their parameters, and all validation remain unchanged.
+ +```python +from xtalpaint.inpainting.config_schema import AiiDAOptions, AiiDATaskOptions, XtalPaintConfig +from xtalpaint.aiida.workgraphs.inpainting_graph_task import setup_inpainting_wg + +config = XtalPaintConfig( + structures=..., + candidate_generation=..., + inpainting=..., + relaxation=..., + + # AiiDA-specific block — ignored in plain-Python runs + aiida=AiiDAOptions( + default_code_label="xtalpaint@localhost", # fallback for all tasks + relax_code_label="xtalpaint@hpc", # override for relaxation + inpainting_options=AiiDATaskOptions( + resources={"num_machines": 1, "num_mpiprocs_per_machine": 4}, + max_wallclock_seconds=3600, + withmpi=True, + ), + relax_options=AiiDATaskOptions( + resources={"num_machines": 2, "num_mpiprocs_per_machine": 8}, + withmpi=True, + ), + ), +) + +# Build and submit the WorkGraph +wg = setup_inpainting_wg(config) +wg.submit() +``` + +### Code label resolution + +Each task resolves its code label in order: + +1. Task-specific label (`inpainting_code_label`, `relax_code_label`, `candidate_generation_code_label`) +2. Fall back to `default_code_label` + +This lets you run most tasks on one machine and override just the resource-heavy ones.
+ +### `AiiDATaskOptions` fields + +| Field | Type | Default | Description | +|---|---|---|---| +| `resources` | `dict` | `{}` | AiiDA scheduler resource dict | +| `max_wallclock_seconds` | `int \| None` | `None` | Wall-clock limit | +| `queue_name` | `str \| None` | `None` | Scheduler queue/partition | +| `withmpi` | `bool` | `False` | Enable MPI-parallel execution | + +--- + +## Full examples + +=== "Without AiiDA" + + ```python + from pymatgen.core import Structure + from xtalpaint.inpainting.config_schema import ( + XtalPaintConfig, + CandidateGenerationConfig, + InpaintingConfig, + RefinementConfig, + RelaxationConfig, + RelaxationParams, + ) + from xtalpaint.inpainting.inpainting_process import run_inpainting_pipeline + + structure = Structure.from_file("host.cif") + + config = XtalPaintConfig( + structures={"host": structure}, + candidate_generation=CandidateGenerationConfig( + n_inp=2, + element="H", + ), + inpainting=InpaintingConfig( + pretrained_name="mattergen_base", + predictor_corrector="baseline", + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + ), + pre_refinement=RefinementConfig(symprec=0.01), + relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + elements_to_relax=["H"], + fmax=0.05, + ), + constrained=True, + refine=True, + filter_unique=True, + ), + # no aiida block → plain Python execution + ) + + results = run_inpainting_pipeline( + structures=config.structures, + config=config.inpainting, + ) + print(results["structures"]) + ``` + +=== "With AiiDA" + + ```python + from pymatgen.core import Structure + from xtalpaint.inpainting.config_schema import ( + XtalPaintConfig, + CandidateGenerationConfig, + InpaintingConfig, + RefinementConfig, + RelaxationConfig, + RelaxationParams, + AiiDAOptions, + AiiDATaskOptions, + ) + from xtalpaint.aiida.workgraphs.inpainting_graph_task import setup_inpainting_wg + + structure = Structure.from_file("host.cif") + + config = 
XtalPaintConfig( + structures={"host": structure}, + candidate_generation=CandidateGenerationConfig( + n_inp=2, + element="H", + ), + inpainting=InpaintingConfig( + pretrained_name="mattergen_base", + predictor_corrector="baseline", + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + ), + pre_refinement=RefinementConfig(symprec=0.01), + relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + elements_to_relax=["H"], + fmax=0.05, + ), + constrained=True, + refine=True, + filter_unique=True, + ), + aiida=AiiDAOptions( + default_code_label="xtalpaint@localhost", + relax_code_label="xtalpaint@hpc", + inpainting_options=AiiDATaskOptions( + resources={"num_machines": 1, "num_mpiprocs_per_machine": 4}, + withmpi=True, + ), + relax_options=AiiDATaskOptions( + resources={"num_machines": 2, "num_mpiprocs_per_machine": 8}, + withmpi=True, + ), + ), + ) + + wg = setup_inpainting_wg(config) + wg.submit() + ``` + +The two `XtalPaintConfig` objects are identical except for the `aiida=` block; only the final call that runs or submits the workflow differs. This means you can develop and test workflows locally (without AiiDA) and then promote them to a remote HPC environment by adding the `aiida` block — no other changes to the configuration are needed.
+ +--- + +## Configuration reference summary + +| Class | Required fields | Purpose | +|---|---|---| +| `XtalPaintConfig` | `structures`, `inpainting` | Top-level workflow config | +| `CandidateGenerationConfig` | `n_inp`, `element` | Generate inpainting masks | +| `InpaintingConfig` | `predictor_corrector`, `N_steps`, `coordinates_snr`, `n_corrector_steps`, `batch_size`, one of `pretrained_name`/`model_path` | Diffusion sampling | +| `RefinementConfig` | — | Symmetry refinement before relaxation | +| `RelaxationGraphConfig` | `params` | Single-pass input for `relaxation_graph` | +| `RelaxationConfig` | `params` | Multi-pass relaxation stage (extends `RelaxationGraphConfig`) | +| `RelaxationParams` | `mlip`, `optimizer` | Inputs forwarded to `relax_structures()` | +| `UniquenessConfig` | — | Deduplication tolerances | +| `AiiDAOptions` | — | Code labels + per-task scheduler options | +| `AiiDATaskOptions` | — | Resources, wall-clock, MPI flag | + +### Using `relaxation_graph` directly + +`relaxation_graph` accepts `RelaxationGraphConfig` directly, so you can call it outside the inpainting WorkGraph without needing the full `RelaxationConfig` (which carries inpainting-WG-specific pass flags): + +```python +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph +from xtalpaint.inpainting.config_schema import ( + RelaxationGraphConfig, RelaxationParams, UniquenessConfig, + AiiDATaskOptions, +) + +relax_cfg = RelaxationGraphConfig( + params=RelaxationParams(mlip="mattersim", optimizer="BFGS", elements_to_relax=["H"]), + refine=True, + filter_unique=True, + uniqueness=UniquenessConfig(symprec=0.01), +) + +out = relaxation_graph( + structures=my_structures, + relax_config=relax_cfg, + aiida_options=AiiDATaskOptions(withmpi=True, resources={"num_machines": 1}), + code_label="xtalpaint@hpc", + constrained=True, # True → include elements_to_relax; False → full relax +) +``` + +Since `RelaxationConfig` inherits from `RelaxationGraphConfig`, you can also pass a 
`RelaxationConfig` directly wherever `RelaxationGraphConfig` is expected. diff --git a/docs/index.md b/docs/index.md index 98e15d4..12ca88b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,5 @@ # XtalPaint – A framework for crystal structure inpainting based on diffusion models - Welcome to the `XtalPaint` Documentation. ## Overview @@ -16,10 +15,12 @@ Welcome to the `XtalPaint` Documentation. ## Getting Started -Check out the examples on to run the inpainting pipeline: +Read the [Configuration Guide](configuration.md) to learn how to specify workflows and understand the AiiDA vs. plain-Python execution modes. + +Then check out the worked examples: -* [With AiiDA integration](examples/running-with-AiiDA.ipynb) -* [Without AiiDA integration](examples/running-wo-AiiDA.ipynb) +- [With AiiDA integration](examples/running-with-AiiDA.ipynb) +- [Without AiiDA integration](examples/running-wo-AiiDA.ipynb) ## Installation @@ -40,7 +41,6 @@ uv pip install .[aiida] Model checkpoints for the retrained versions of MatterGen used in our work can be downloaded from [Hugging Face](https://huggingface.co/t-reents/XtalPaint). Currently, the repository contains the `pos-only` and `TD-pos-only` models discussed in the paper. - ## Acknowledgements This project is developed to perform crystal structure inpainting, currently on top of Microsoft's [MatterGen](https://github.com/microsoft/mattergen). Some parts of the codebase are adapted from MatterGen's implementation (as highlighted in the respective files). 
diff --git a/mkdocs.yml b/mkdocs.yml index 74cdeb5..ffa4016 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,6 +49,7 @@ markdown_extensions: nav: - Home: index.md + - Configuration Guide: configuration.md - Examples: - Running with AiiDA: examples/running-with-AiiDA.ipynb - Running without AiiDA: examples/running-wo-AiiDA.ipynb From 7b3940170a9cc016b8b89c77292da36ee0b4c1ef Mon Sep 17 00:00:00 2001 From: t-reents Date: Tue, 26 May 2026 16:20:29 +0200 Subject: [PATCH 5/7] Fix import --- tests/test_relaxation_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_relaxation_graph.py b/tests/test_relaxation_graph.py index 24dfcd7..17ac112 100644 --- a/tests/test_relaxation_graph.py +++ b/tests/test_relaxation_graph.py @@ -12,7 +12,7 @@ from xtalpaint.aiida.data import BatchedStructuresData from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph from xtalpaint.data import BatchedStructures -from xtalpaint.eval import filter_unique_structures +from xtalpaint.utils.structure_utils import filter_unique_structures # --------------------------------------------------------------------------- From e6bbf32762eab16d1749393e2dec3ecf4adef2c7 Mon Sep 17 00:00:00 2001 From: Timo Reents <77727843+t-reents@users.noreply.github.com> Date: Wed, 27 May 2026 09:12:41 +0200 Subject: [PATCH 6/7] Update docs/index.md --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 12ca88b..93a8a5c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,7 +17,7 @@ Welcome to the `XtalPaint` Documentation. Read the [Configuration Guide](configuration.md) to learn how to specify workflows and understand the AiiDA vs. plain-Python execution modes. 
-Then check out the worked examples: +Afterwards, check out the examples: - [With AiiDA integration](examples/running-with-AiiDA.ipynb) - [Without AiiDA integration](examples/running-wo-AiiDA.ipynb) From cc53d511932165e4bad8f529abc1f18fcc5a5465 Mon Sep 17 00:00:00 2001 From: Timo Reents <77727843+t-reents@users.noreply.github.com> Date: Wed, 27 May 2026 09:12:56 +0200 Subject: [PATCH 7/7] Update src/xtalpaint/inpainting/config_schema.py --- src/xtalpaint/inpainting/config_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xtalpaint/inpainting/config_schema.py b/src/xtalpaint/inpainting/config_schema.py index 620110c..a0ecd6a 100644 --- a/src/xtalpaint/inpainting/config_schema.py +++ b/src/xtalpaint/inpainting/config_schema.py @@ -278,7 +278,7 @@ def get_code_label(self, specific: Optional[str] = None) -> Optional[str]: class XtalPaintConfig(BaseModel): - """Complete configuration for an XtalPaint inpainting workflow. + """Complete configuration for the XtalPaint inpainting workflow. Works for both AiiDA-based (WorkGraph) and plain-Python execution. AiiDA-specific settings live in the optional ``aiida`` block and are