diff --git a/docs/configuration.md b/docs/configuration.md new file mode 100644 index 0000000..2009593 --- /dev/null +++ b/docs/configuration.md @@ -0,0 +1,480 @@ +# Workflow Configuration + +XtalPaint uses a single configuration object — `XtalPaintConfig` — to drive both plain-Python and AiiDA-based workflows. This page explains how to build that config, what each field does, and how the same config object is shared between the two execution modes. + +--- + +## Design principles + +- **Presence = enabled, `None` = skip.** Each pipeline stage is controlled by its own typed config object. If the field is `None` the stage is omitted — no boolean flags needed. +- **AiiDA options are isolated.** Everything in the optional `aiida` block (code labels, scheduler resources) is invisible to plain-Python execution. +- **Flat, validated inputs.** Pydantic validates every field at construction time so mistakes surface immediately rather than at run time. + +--- + +## Top-level structure + +```python +from xtalpaint.inpainting.config_schema import ( + XtalPaintConfig, + CandidateGenerationConfig, + InpaintingConfig, + RefinementConfig, + RelaxationConfig, + RelaxationParams, + UniquenessConfig, + AiiDAOptions, + AiiDATaskOptions, +) + +config = XtalPaintConfig( + structures=..., # required — dict[str, Structure] or BatchedStructures + run_inpainting=True, # set False to skip the diffusion step + candidate_generation=..., # CandidateGenerationConfig | None + pre_refinement=..., # RefinementConfig | None + inpainting=..., # InpaintingConfig (always required) + relaxation=..., # RelaxationConfig | None + aiida=..., # AiiDAOptions | None (ignored outside AiiDA) +) +``` + +The pipeline runs in this order when a stage is enabled: + +``` +candidate_generation → inpainting → pre_refinement → relaxation + ├─ constrained pass + ├─ full pass (on constrained output) + └─ full_direct pass (on inpainted directly) + each pass: → [refine] → [filter_unique] +``` + +--- + +## Stage reference + 
+### Input structures + +```python +from pymatgen.core import Structure + +config = XtalPaintConfig( + structures={"host_001": Structure(...), "host_002": Structure(...)}, + inpainting=..., +) +``` + +`structures` accepts: + +- `dict[str, Structure]` — plain pymatgen structures +- `BatchedStructures` — XtalPaint's batched wrapper +- AiiDA `StructureData` / `InpaintingStructureData` — when running inside AiiDA + +--- + +### Candidate generation + +Required when the input structures are plain `Structure` objects (not yet marked as inpainting targets). Omit this block if your structures are already `InpaintingStructureData` instances. + +```python +candidate_generation=CandidateGenerationConfig( + n_inp=2, # int or dict[str, int] — number of sites to inpaint + element="H", # element to place; dict[str, str] for per-structure control + num_samples=1, # how many candidate sets to generate +) +``` + +For per-structure control over the number of sites and element: + +```python +candidate_generation=CandidateGenerationConfig( + n_inp={"host_001": 2, "host_002": 4}, + element={"host_001": "H", "host_002": "Li"}, +) +``` + +--- + +### Inpainting + +The core diffusion stage. All sampling parameters live in one flat block. 
+ +```python +inpainting=InpaintingConfig( + # Model — provide exactly one of these: + pretrained_name="mattergen_base", # use a bundled pretrained checkpoint + # model_path="/path/to/checkpoint", # or point to a local file + + # Sampling + predictor_corrector="baseline", # see supported keys below + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + + # Optional + fix_cell=True, # keep unit cell fixed during sampling + record_trajectories=False, + sampling_config_path=None, # override MatterGen sampling config dir +) +``` + +**Supported `predictor_corrector` values:** + +| Key | Description | +|---|---| +| `baseline` | Standard guided predictor-corrector | +| `baseline-with-noise` | Custom variant with additional noise | +| `baseline-store-scores` | Records score function outputs | +| `repaint-v1` | RePaint resampling (legacy) | +| `repaint-v2` | RePaint resampling (v2) | +| `TD` | Time-dependent (TD-Paint) variant | + +!!! note "Repaint variants" + When using `repaint-v1` or `repaint-v2`, you must also set `n_resample_steps` and `jump_length`: + + ```python + inpainting=InpaintingConfig( + predictor_corrector="repaint-v2", + n_resample_steps=10, + jump_length=5, + # ... other fields ... + ) + ``` + +--- + +### Pre-refinement + +Optional symmetry refinement applied *after* inpainting and *before* relaxation. + +```python +pre_refinement=RefinementConfig( + symprec=0.01, # symmetry precision for SpacegroupAnalyzer + primitive=False, # if True, convert to primitive cell +) +``` + +Omit `pre_refinement` (or set it to `None`) to skip this step. 
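The "omit it or set it to `None`" convention used for `pre_refinement` can be illustrated with a minimal standard-library sketch. The dataclasses `RefinementSketch` and `ConfigSketch` and the helper `planned_stages` are hypothetical stand-ins, not the real Pydantic models in `xtalpaint.inpainting.config_schema`, and the default values are assumptions. They only show how an optional stage field doubles as its on/off switch.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RefinementSketch:
    """Hypothetical stand-in for RefinementConfig (defaults are assumed)."""

    symprec: float = 0.01
    primitive: bool = False


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for XtalPaintConfig, reduced to one optional stage."""

    pre_refinement: Optional[RefinementSketch] = None


def planned_stages(config: ConfigSketch) -> list[str]:
    """List the stages that would run, in pipeline order."""
    stages = ["inpainting"]
    # Presence of the stage config enables the stage; None skips it.
    if config.pre_refinement is not None:
        stages.append("pre_refinement")
    return stages


print(planned_stages(ConfigSketch()))
print(planned_stages(ConfigSketch(pre_refinement=RefinementSketch(symprec=0.1))))
```

Because the stage object carries its own parameters, there is no separate boolean flag that could drift out of sync with them.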
+ +--- + +### Relaxation + +The relaxation stage is split into two distinct layers: + +- **`RelaxationParams`**: the inputs forwarded directly to `relax_structures()` (MLIP, optimiser, convergence) +- **`RelaxationConfig`**: workflow-level controls, i.e. which passes to run and which post-relaxation processing to apply + +#### Relaxation passes + +Three passes are available. `constrained` and `full_direct` can be enabled independently; `full` operates on the constrained output and therefore needs `constrained=True`: + +| Flag | Behaviour | WorkGraph label | +|---|---|---| +| `constrained=True` *(default)* | Relax only `elements_to_relax` | `inpainted_constrained_relaxation` | +| `full=True` | Full relax on the *constrained* output | `pre_relaxed_inpainted_full_relaxation` | +| `full_direct=True` | Full relax directly on inpainted structures | `unrelaxed_inpainted_full_relaxation` | + +Enabling `full` and `full_direct` together lets you compare a full relaxation that starts from the constrained-relaxed geometry with one that starts from the raw inpainted geometry. + +#### Post-relaxation steps + +`refine` and `filter_unique` run *after each active pass*, in order: relax → refine → deduplicate. + +```python +relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + fmax=0.05, + max_n_steps=500, + elements_to_relax=["H"], # required when constrained=True + return_initial_energies=False, + return_final_forces=False, + ), + # Which passes to run: + constrained=True, # relax only H atoms + full=True, # then do a full relax on that output + full_direct=False, # skip direct full relax + + # Post-relaxation processing (applied to each pass): + refine=True, + refinement_symprec=0.01, + refinement_primitive=False, + filter_unique=True, + uniqueness=UniquenessConfig( + symprec=0.01, + ltol=0.2, + stol=0.3, + angle_tol=5.0, + ), +) +``` + +!!! warning "Constraints on relaxation passes" + `constrained=True` requires `params.elements_to_relax` to be set. + `full=True` requires `constrained=True` (the full-relax pass operates on the constrained output). 
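The relationship between the three pass flags and the WorkGraph labels in the table above can be sketched as a small pure-Python function. `active_relaxation_passes` is a hypothetical helper written for illustration, not part of XtalPaint. The `False` defaults for `full` and `full_direct`, and the choice to raise a `ValueError` for an invalid combination, are assumptions. The label order follows the order in which the passes are added in the inpainting WorkGraph.

```python
def active_relaxation_passes(
    constrained: bool = True,
    full: bool = False,  # assumed default
    full_direct: bool = False,  # assumed default
) -> list[str]:
    """Return the WorkGraph labels of the passes that would run."""
    # The full pass starts from the constrained output, so it cannot run alone.
    if full and not constrained:
        raise ValueError("full=True requires constrained=True")

    labels = []
    if constrained:
        labels.append("inpainted_constrained_relaxation")
    if full_direct:
        labels.append("unrelaxed_inpainted_full_relaxation")
    if full:
        labels.append("pre_relaxed_inpainted_full_relaxation")
    return labels


print(active_relaxation_passes())
print(active_relaxation_passes(full=True, full_direct=True))
```

Each returned label corresponds to one group of outputs (`<label>.structures`, `<label>.final_energies`, and so on) on the resulting WorkGraph.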
+ +--- + +## Running without AiiDA + +Without AiiDA, pass the `inpainting` config directly to the pipeline functions and leave the `aiida` block unset. + +```python +from xtalpaint.inpainting.config_schema import CandidateGenerationConfig, InpaintingConfig, XtalPaintConfig +from xtalpaint.inpainting.inpainting_process import run_inpainting_pipeline +from xtalpaint.utils.relaxation_utils import relax_structures + +config = XtalPaintConfig( + structures={"host_001": structure}, # `structure` is a pymatgen Structure loaded beforehand + candidate_generation=CandidateGenerationConfig(n_inp=2, element="H"), + inpainting=InpaintingConfig( + model_path="/path/to/checkpoint.ckpt", + predictor_corrector="baseline", + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + sampling_config_path="/path/to/sampling_conf", + ), +) + +# Run inpainting +results = run_inpainting_pipeline( + structures=config.structures, + config=config.inpainting, # pass InpaintingConfig directly +) +inpainted = results["structures"] + +# Optional relaxation (using config.relaxation.relax_inputs()) +if config.relaxation is not None: + relaxed = relax_structures( + inpainted, + **config.relaxation.relax_inputs(constrained=True), + ) +``` + +!!! tip + Calling `config.inpainting.model_dump(exclude_none=True)` produces a plain dict that the pipeline functions also accept, which is convenient when serialising configs to JSON/YAML. + +--- + +## Running with AiiDA + +Add the `aiida` block to the same `XtalPaintConfig`. Everything else stays identical: the pipeline stages, their parameters, and all validation remain unchanged. 
+ +```python +from xtalpaint.inpainting.config_schema import AiiDAOptions, AiiDATaskOptions +from xtalpaint.aiida.workgraphs.inpainting_graph_task import setup_inpainting_wg + +config = XtalPaintConfig( + structures=..., + candidate_generation=..., + inpainting=..., + relaxation=..., + + # AiiDA-specific block — ignored in plain-Python runs + aiida=AiiDAOptions( + default_code_label="xtalpaint@localhost", # fallback for all tasks + relax_code_label="xtalpaint@hpc", # override for relaxation + inpainting_options=AiiDATaskOptions( + resources={"num_machines": 1, "num_mpiprocs_per_machine": 4}, + max_wallclock_seconds=3600, + withmpi=True, + ), + relax_options=AiiDATaskOptions( + resources={"num_machines": 2, "num_mpiprocs_per_machine": 8}, + withmpi=True, + ), + ), +) + +# Build and submit the WorkGraph +wg = setup_inpainting_wg(config) +wg.submit() +``` + +### Code label resolution + +Each task resolves its code label in order: + +1. Task-specific label (`inpainting_code_label`, `relax_code_label`, `candidate_generation_code_label`) +2. Fall back to `default_code_label` + +This lets you run most tasks on one machine and override just the resource-heavy ones. 
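The two-step lookup above can be sketched in a few lines. `resolve_code_label` is a hypothetical helper for illustration only; the real logic lives in `AiiDAOptions.get_code_label`, whose implementation is not shown on this page and may differ in detail.

```python
from typing import Optional


def resolve_code_label(
    task_label: Optional[str],
    default_label: Optional[str],
) -> Optional[str]:
    """Prefer the task-specific label; otherwise fall back to the default."""
    return task_label if task_label is not None else default_label


# Mirrors the example config: only the relaxation task overrides the default.
default = "xtalpaint@localhost"
task_labels = {"inpainting": None, "relax": "xtalpaint@hpc"}

resolved = {
    name: resolve_code_label(label, default)
    for name, label in task_labels.items()
}
print(resolved)
```

If neither label is set the result is `None`, in which case the WorkGraph code shown in this change passes `code=None` to the task instead of loading a code.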
+ +### `AiiDATaskOptions` fields + +| Field | Type | Default | Description | +|---|---|---|---| +| `resources` | `dict` | `{}` | AiiDA scheduler resource dict | +| `max_wallclock_seconds` | `int \| None` | `None` | Wall-clock limit | +| `queue_name` | `str \| None` | `None` | Scheduler queue/partition | +| `withmpi` | `bool` | `False` | Enable MPI-parallel execution | + +--- + +## Full examples + +=== "Without AiiDA" + + ```python + from pymatgen.core import Structure + from xtalpaint.inpainting.config_schema import ( + XtalPaintConfig, + CandidateGenerationConfig, + InpaintingConfig, + RefinementConfig, + RelaxationConfig, + RelaxationParams, + ) + from xtalpaint.inpainting.inpainting_process import run_inpainting_pipeline + + structure = Structure.from_file("host.cif") + + config = XtalPaintConfig( + structures={"host": structure}, + candidate_generation=CandidateGenerationConfig( + n_inp=2, + element="H", + ), + inpainting=InpaintingConfig( + pretrained_name="mattergen_base", + predictor_corrector="baseline", + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + ), + pre_refinement=RefinementConfig(symprec=0.01), + relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + elements_to_relax=["H"], + fmax=0.05, + ), + constrained=True, + refine=True, + filter_unique=True, + ), + # no aiida block → plain Python execution + ) + + results = run_inpainting_pipeline( + structures=config.structures, + config=config.inpainting, + ) + print(results["structures"]) + ``` + +=== "With AiiDA" + + ```python + from pymatgen.core import Structure + from xtalpaint.inpainting.config_schema import ( + XtalPaintConfig, + CandidateGenerationConfig, + InpaintingConfig, + RefinementConfig, + RelaxationConfig, + RelaxationParams, + AiiDAOptions, + AiiDATaskOptions, + ) + from xtalpaint.aiida.workgraphs.inpainting_graph_task import setup_inpainting_wg + + structure = Structure.from_file("host.cif") + + config = 
XtalPaintConfig( + structures={"host": structure}, + candidate_generation=CandidateGenerationConfig( + n_inp=2, + element="H", + ), + inpainting=InpaintingConfig( + pretrained_name="mattergen_base", + predictor_corrector="baseline", + N_steps=5, + coordinates_snr=0.2, + n_corrector_steps=1, + batch_size=1000, + ), + pre_refinement=RefinementConfig(symprec=0.01), + relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + elements_to_relax=["H"], + fmax=0.05, + ), + constrained=True, + refine=True, + filter_unique=True, + ), + aiida=AiiDAOptions( + default_code_label="xtalpaint@localhost", + relax_code_label="xtalpaint@hpc", + inpainting_options=AiiDATaskOptions( + resources={"num_machines": 1, "num_mpiprocs_per_machine": 4}, + withmpi=True, + ), + relax_options=AiiDATaskOptions( + resources={"num_machines": 2, "num_mpiprocs_per_machine": 8}, + withmpi=True, + ), + ), + ) + + wg = setup_inpainting_wg(config) + wg.submit() + ``` + +The two snippets are identical except for the `aiida=` block. This means you can develop and test workflows locally (without AiiDA) and then promote them to a remote HPC environment by adding the `aiida` block — no other changes needed. 
+ +--- + +## Configuration reference summary + +| Class | Required fields | Purpose | +|---|---|---| +| `XtalPaintConfig` | `structures`, `inpainting` | Top-level workflow config | +| `CandidateGenerationConfig` | `n_inp`, `element` | Generate inpainting masks | +| `InpaintingConfig` | `predictor_corrector`, `N_steps`, `coordinates_snr`, `n_corrector_steps`, `batch_size`, one of `pretrained_name`/`model_path` | Diffusion sampling | +| `RefinementConfig` | — | Symmetry refinement before relaxation | +| `RelaxationGraphConfig` | `params` | Single-pass input for `relaxation_graph` | +| `RelaxationConfig` | `params` | Multi-pass relaxation stage (extends `RelaxationGraphConfig`) | +| `RelaxationParams` | `mlip`, `optimizer` | Inputs forwarded to `relax_structures()` | +| `UniquenessConfig` | — | Deduplication tolerances | +| `AiiDAOptions` | — | Code labels + per-task scheduler options | +| `AiiDATaskOptions` | — | Resources, wall-clock, MPI flag | + +### Using `relaxation_graph` directly + +`relaxation_graph` accepts `RelaxationGraphConfig` directly, so you can call it outside the inpainting WorkGraph without needing the full `RelaxationConfig` (which carries inpainting-WG-specific pass flags): + +```python +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph +from xtalpaint.inpainting.config_schema import ( + RelaxationGraphConfig, RelaxationParams, UniquenessConfig, + AiiDATaskOptions, +) + +relax_cfg = RelaxationGraphConfig( + params=RelaxationParams(mlip="mattersim", optimizer="BFGS", elements_to_relax=["H"]), + refine=True, + filter_unique=True, + uniqueness=UniquenessConfig(symprec=0.01), +) + +out = relaxation_graph( + structures=my_structures, + relax_config=relax_cfg, + aiida_options=AiiDATaskOptions(withmpi=True, resources={"num_machines": 1}), + code_label="xtalpaint@hpc", + constrained=True, # True → include elements_to_relax; False → full relax +) +``` + +Since `RelaxationConfig` inherits from `RelaxationGraphConfig`, you can also pass a 
`RelaxationConfig` directly wherever `RelaxationGraphConfig` is expected. diff --git a/docs/index.md b/docs/index.md index 98e15d4..93a8a5c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,5 @@ # XtalPaint – A framework for crystal structure inpainting based on diffusion models - Welcome to the `XtalPaint` Documentation. ## Overview @@ -16,10 +15,12 @@ Welcome to the `XtalPaint` Documentation. ## Getting Started -Check out the examples on to run the inpainting pipeline: +Read the [Configuration Guide](configuration.md) to learn how to specify workflows and understand the AiiDA vs. plain-Python execution modes. + +Afterwards, check out the examples: -* [With AiiDA integration](examples/running-with-AiiDA.ipynb) -* [Without AiiDA integration](examples/running-wo-AiiDA.ipynb) +- [With AiiDA integration](examples/running-with-AiiDA.ipynb) +- [Without AiiDA integration](examples/running-wo-AiiDA.ipynb) ## Installation @@ -40,7 +41,6 @@ uv pip install .[aiida] Model checkpoints for the retrained versions of MatterGen used in our work can be downloaded from [Hugging Face](https://huggingface.co/t-reents/XtalPaint). Currently, the repository contains the `pos-only` and `TD-pos-only` models discussed in the paper. - ## Acknowledgements This project is developed to perform crystal structure inpainting, currently on top of Microsoft's [MatterGen](https://github.com/microsoft/mattergen). Some parts of the codebase are adapted from MatterGen's implementation (as highlighted in the respective files). 
diff --git a/mkdocs.yml b/mkdocs.yml index 74cdeb5..ffa4016 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,6 +49,7 @@ markdown_extensions: nav: - Home: index.md + - Configuration Guide: configuration.md - Examples: - Running with AiiDA: examples/running-with-AiiDA.ipynb - Running without AiiDA: examples/running-wo-AiiDA.ipynb diff --git a/src/xtalpaint/aiida/tasks/tasks.py b/src/xtalpaint/aiida/tasks/tasks.py index 7811da7..77fb555 100644 --- a/src/xtalpaint/aiida/tasks/tasks.py +++ b/src/xtalpaint/aiida/tasks/tasks.py @@ -6,7 +6,6 @@ from aiida_workgraph import spec, task from aiida_workgraph.socket_spec import meta from pymatgen.core.structure import Structure -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from xtalpaint.aiida.data import ( BatchedStructures, @@ -24,12 +23,32 @@ run_mpi_parallel_inpainting_pipeline, ) from xtalpaint.utils.relaxation_utils import relax_structures +from xtalpaint.utils.structure_utils import ( + filter_unique_structures, + refine_structures, +) + +evaluate_inpainting_task = task.pythonjob( + outputs=spec.namespace( + metric_results=t.Any, + ), +)(evaluate_inpainting) + + +refine_structures_task = task.pythonjob( + outputs=spec.namespace(structures=t.Any), +)(refine_structures) + + +uniqueness_filter_task = task.pythonjob( + outputs=spec.namespace(unique_structures=t.Any), +)(filter_unique_structures) @task.pythonjob( outputs=spec.namespace(candidates=t.Any), ) -def _generate_inpainting_candidates_task( +def generate_inpainting_candidates_task( structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures, n_inp: t.Union[ int, t.Tuple[int, int], t.List[int], t.List[t.Tuple[int, int]] @@ -37,6 +56,7 @@ def _generate_inpainting_candidates_task( element: t.Union[str, t.List[str]], num_samples: int = 1, ) -> BatchedStructures: + """Task wrapper for the inpainting candidates generation.""" if isinstance(structures, BatchedStructures): structures = structures.get_structures("pymatgen") candidates = 
generate_inpainting_candidates( @@ -49,40 +69,6 @@ def _generate_inpainting_candidates_task( return {"candidates": BatchedStructures(candidates)} -@task.pythonjob( - outputs=spec.namespace(structures=t.Any), -) -def _refine_structures_task( - structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures, - refinement_symprec: float, - primitive: bool = False, -) -> BatchedStructures: - """Refine structures to standard conventional cells.""" - if isinstance(structures, BatchedStructures): - structures = structures.get_structures("pymatgen") - - refined_structures = {} - for k, s in structures.items(): - analyzer = SpacegroupAnalyzer(s, symprec=refinement_symprec) - try: - refined_structure = analyzer.get_refined_structure() - except Exception: - refined_structure = s - - if primitive: - analyzer = SpacegroupAnalyzer( - refined_structure, symprec=refinement_symprec - ) - try: - refined_structure = analyzer.get_primitive_structure() - except Exception: - refined_structure = refined_structure - - refined_structures[k] = refined_structure - - return {"structures": BatchedStructures(refined_structures)} - - @task.pythonjob( outputs=spec.namespace( structures=t.Any, @@ -95,24 +81,18 @@ def _refine_structures_task( ], ) ) -def _inpainting_pipeline_task( +def inpainting_pipeline_task( structures: t.Union[Structure, t.Iterable[Structure]] | BatchedStructures, config: dict, usempi: bool = False, ): + """Task wrapper for the inpainting pipeline.""" if usempi: return run_mpi_parallel_inpainting_pipeline(structures, config) return run_inpainting_pipeline(structures, config) -_evaluate_inpainting_task = task.pythonjob( - outputs=spec.namespace( - metric_results=t.Any, - ), -)(evaluate_inpainting) - - @task.pythonjob( outputs=spec.namespace( structures=t.Any, @@ -122,7 +102,7 @@ def _inpainting_pipeline_task( final_forces=spec.socket(t.Any, required=False), ) ) -def _relaxation_task( +def relaxation_task( structures: t.Union[ dict[str, Structure], 
BatchedStructuresData, BatchedStructures ], diff --git a/src/xtalpaint/aiida/workgraphs/__init__.py b/src/xtalpaint/aiida/workgraphs/__init__.py index 2d17522..4a3aed5 100644 --- a/src/xtalpaint/aiida/workgraphs/__init__.py +++ b/src/xtalpaint/aiida/workgraphs/__init__.py @@ -1 +1,5 @@ """Modules defining workgraphs for inpainting tasks.""" + +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph + +__all__ = ("relaxation_graph",) diff --git a/src/xtalpaint/aiida/workgraphs/inpainting.py b/src/xtalpaint/aiida/workgraphs/inpainting.py index 2e55625..866ecec 100644 --- a/src/xtalpaint/aiida/workgraphs/inpainting.py +++ b/src/xtalpaint/aiida/workgraphs/inpainting.py @@ -1,296 +1,146 @@ """AiiDA WorkGraph for inpainting of crystal structures.""" -from copy import deepcopy - from aiida import orm -from aiida_workgraph import WorkGraph -from pymatgen.core.structure import Structure - -from xtalpaint.aiida.data import ( - BatchedStructuresData, -) -from xtalpaint.aiida.tasks.tasks import ( - _evaluate_inpainting_task, - _generate_inpainting_candidates_task, - _inpainting_pipeline_task, - _refine_structures_task, - _relaxation_task, +from aiida_workgraph import WorkGraph, task + +from xtalpaint.aiida.tasks import tasks +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph +from xtalpaint.inpainting.config_schema import ( + AiiDAOptions, + RelaxationGraphConfig, + XtalPaintConfig, ) -from xtalpaint.data import BatchedStructures -from xtalpaint.inpainting.config_schema import InpaintingWorkflowConfig -def setup_inpainting_wg( - inputs: InpaintingWorkflowConfig, -) -> WorkGraph: - """Create a WorkGraph for inpainting of crystal structures.""" - possible_relaxation_tasks = { - "inpainted_constrained_relaxation": inputs.relax, - "unrelaxed_inpainted_full_relaxation": inputs.full_relax - and inputs.full_relax_wo_pre_relax, - "pre_relaxed_inpainted_full_relaxation": inputs.full_relax - and inputs.relax, +def _relax_outputs(prefix: str, out, relax: 
RelaxationGraphConfig) -> dict: + """Build a prefixed output dict from a relaxation_graph result.""" + outputs = { + f"{prefix}.structures": out.structures, + f"{prefix}.final_energies": out.final_energies, } + if relax.params.return_initial_energies: + outputs[f"{prefix}.initial_energies"] = out.initial_energies + if relax.params.return_initial_forces: + outputs[f"{prefix}.initial_forces"] = out.initial_forces + if relax.params.return_final_forces: + outputs[f"{prefix}.final_forces"] = out.final_forces + return outputs + - wg = WorkGraph() +@task.graph +def InpaintingWorkGraph(inputs: XtalPaintConfig): + """WorkGraph for inpainting of crystal structures.""" + graph_outputs = {} - if not inputs.is_inpainting_structures and inputs.run_inpainting: - _add_inpainting_candidates_generation(wg, inputs) + _aiida: AiiDAOptions = inputs.aiida or AiiDAOptions() + # --- Generate inpainting candidates --- if inputs.run_inpainting: - _add_inpainting_pipeline(wg, inputs) - inpainted_structures = wg.tasks["inpainting"].outputs["structures"] + cand_opts = _aiida.candidate_generation_options + gen_out = tasks.generate_inpainting_candidates_task( + structures=inputs.structures, + **inputs.candidate_generation.model_dump(), + metadata={ + "call_link_label": "generate_inpainting_candidates", + "options": cand_opts.model_dump(exclude={"withmpi"}), + }, + ) + inpainting_candidates = gen_out.candidates + graph_outputs["inpainting_candidates"] = inpainting_candidates else: - inpainted_structures = inputs.structures + inpainting_candidates = inputs.structures - if inputs.refine_structures: - _add_refinement_task( - wg, - structures=inpainted_structures, - refinement_symprec=inputs.refinement_symprec, - primitive=inputs.refinement_primitive, - inputs=inputs, - task_name="refine_structures", + # --- Inpainting pipeline --- + if inputs.run_inpainting: + inp_opts = _aiida.inpainting_options + code_label = _aiida.get_code_label(_aiida.inpainting_code_label) + inp_out = 
tasks.inpainting_pipeline_task( + structures=inpainting_candidates, + config=inputs.inpainting.model_dump(exclude_none=True), + usempi=inp_opts.withmpi, + metadata={ + "call_link_label": "inpainting", + "options": inp_opts.model_dump(exclude={"withmpi"}), + }, + code=orm.load_code(code_label) if code_label else None, ) - inpainted_structures = wg.tasks["refine_structures"].outputs[ - "structures" - ] - - wg.outputs.inpainted_structures = inpainted_structures - - if inputs.relax or inputs.full_relax: - _add_relaxation_tasks(wg, inpainted_structures, inputs) - - if inputs.evaluate: - relaxation_tasks = { - k: k for k, v in possible_relaxation_tasks.items() if v - } - _add_evaluation_tasks(wg, inputs, relaxation_tasks) - - return wg - - -def _add_inpainting_candidates_generation( - wg: WorkGraph, - inputs: InpaintingWorkflowConfig, -) -> None: - """Add inpainting candidates generation task to the workgraph.""" - wg.add_task( - _generate_inpainting_candidates_task, - structures=inputs.structures, - n_inp=inputs.gen_inpainting_candidates_params.n_inp, - element=inputs.gen_inpainting_candidates_params.element, - num_samples=inputs.gen_inpainting_candidates_params.num_samples, - name="generate_inpainting_candidates", - metadata={ - "options": ( - inputs.gen_inpainting_candidates_options or inputs.options - ) - }, - ) - - wg.outputs.inpainting_candidates = wg.tasks[ - "generate_inpainting_candidates" - ].outputs["candidates"] - - -def _add_refinement_task( - wg: WorkGraph, - structures: BatchedStructures | BatchedStructuresData, - refinement_symprec: float, - inputs: InpaintingWorkflowConfig, - task_name: str = "refine_structures", -) -> None: - """Add structure refinement task to the workgraph.""" - wg.add_task( - _refine_structures_task, - structures=structures, - refinement_symprec=refinement_symprec, - name=task_name, - metadata={ - "options": inputs.options or {}, - }, - ) - - -def _add_inpainting_pipeline( - wg: WorkGraph, - inputs: InpaintingWorkflowConfig, -) -> 
None: - """Add inpainting pipeline task to the workgraph.""" - inpainting_candidates = ( - wg.tasks["generate_inpainting_candidates"].outputs["candidates"] - if not inputs.is_inpainting_structures and inputs.run_inpainting - else inputs.structures - ) - - code_label = inputs.inpainting_code_label or inputs.code_label + inpainted_structures = inp_out.structures - wg.add_task( - _inpainting_pipeline_task, - structures=inpainting_candidates, - config=inputs.inpainting_pipeline_params.model_dump(exclude_none=True), - usempi=( - inputs.inpainting_pipeline_options.get("withmpi", False) - if inputs.inpainting_pipeline_options - else False - ), - name="inpainting", - metadata={ - "options": (inputs.inpainting_pipeline_options or inputs.options), - }, - code=orm.load_code(code_label) if code_label else None, - ) - - if inputs.inpainting_pipeline_params.record_trajectories: - wg.outputs.inpainted_trajectories = wg.tasks["inpainting"].outputs[ - "trajectories" - ] - if "mean_trajectories" in wg.tasks["inpainting"].outputs: - wg.outputs.inpainted_mean_trajectories = wg.tasks[ - "inpainting" - ].outputs["mean_trajectories"] - - -def _add_relaxation_tasks( - wg: WorkGraph, - structures: BatchedStructures | BatchedStructuresData, - inputs: InpaintingWorkflowConfig, -) -> None: - """Add relaxation tasks to the workgraph.""" - code_label = inputs.relax_code_label or inputs.code_label - relax_kwargs = deepcopy(inputs.relax_kwargs.model_dump()) + if inputs.inpainting.record_trajectories: + graph_outputs["inpainted_trajectories"] = inp_out.trajectories + else: + inpainted_structures = inputs.structures - if inputs.relax: - wg = _add_full_relax_task( - wg=wg, - structures=structures, - relax_inputs=relax_kwargs, - task_name="inpainted_constrained_relaxation", - options=inputs.relax_options or inputs.options, - code=orm.load_code(code_label) if code_label else None, - as_graph_outputs=True, + # --- Pre-refinement (before relaxation) --- + if inputs.pre_refinement is not None: + ref_out 
= tasks.refine_structures_task( + structures=inpainted_structures, + refinement_symprec=inputs.pre_refinement.symprec, + primitive=inputs.pre_refinement.primitive, + metadata={ + "call_link_label": "refine_structures", + "options": {}, + }, ) - - if inputs.full_relax: - relax_kwargs.pop("elements_to_relax", None) - if inputs.full_relax_wo_pre_relax: - wg = _add_full_relax_task( - wg=wg, - structures=structures, - relax_inputs=relax_kwargs, - task_name="unrelaxed_inpainted_full_relaxation", - options=inputs.relax_options or inputs.options, - code=orm.load_code(code_label) if code_label else None, - as_graph_outputs=True, + inpainted_structures = ref_out.structures + + graph_outputs["inpainted_structures"] = inpainted_structures + + # --- Relaxation --- + if inputs.relaxation is not None: + relax = inputs.relaxation + relax_opts = _aiida.relax_options + relax_code_label = _aiida.get_code_label(_aiida.relax_code_label) + + cr_out = None + if relax.constrained: + cr_out = relaxation_graph( + structures=inpainted_structures, + relax_config=relax, + aiida_options=relax_opts, + code_label=relax_code_label, + constrained=True, + metadata={ + "call_link_label": "inpainted_constrained_relaxation" + }, ) - - if inputs.relax: - wg = _add_full_relax_task( - wg=wg, - structures=wg.tasks[ - "inpainted_constrained_relaxation" - ].outputs["structures"], - relax_inputs=relax_kwargs, - task_name="pre_relaxed_inpainted_full_relaxation", - options=inputs.relax_options or inputs.options, - code=orm.load_code(code_label) if code_label else None, - as_graph_outputs=True, + graph_outputs |= _relax_outputs( + "inpainted_constrained_relaxation", cr_out, relax ) - -def _add_evaluation_tasks( - wg: WorkGraph, - inputs: InpaintingWorkflowConfig, - relaxation_tasks: dict[str, str], -) -> None: - """Add evaluation tasks to the workgraph.""" - code_label = inputs.evaluate_params.code_label or inputs.code_label - - evaluation_results = {} - metrics = ( - inputs.evaluate_params.metrics - if 
isinstance(inputs.evaluate_params.metrics, list) - else [inputs.evaluate_params.metrics] - ) - tasks_to_evaluate = {} - if inputs.run_inpainting: - tasks_to_evaluate["inpainting"] = "inpainting" - if inputs.refine_structures: - tasks_to_evaluate["inpainting"] = "refine_structures" - - tasks_to_evaluate.update(relaxation_tasks) - - for metric in metrics: - for label, task_name in tasks_to_evaluate.items(): - wg.add_task( - _evaluate_inpainting_task, - inpainted_structures=wg.tasks[task_name].outputs["structures"], - reference_structures=inputs.structures, - metric=metric, - max_workers=inputs.evaluate_params.max_workers, - name=f"evaluate_inpainting_{metric}_{label}", + if relax.full_direct: + ufr_out = relaxation_graph( + structures=inpainted_structures, + relax_config=relax, + aiida_options=relax_opts, + code_label=relax_code_label, + constrained=False, metadata={ - "options": inputs.options or {}, + "call_link_label": "unrelaxed_inpainted_full_relaxation" }, - code=orm.load_code(code_label) if code_label else None, ) - evaluation_results.setdefault(label, {}).update( - { - f"{metric}": wg.tasks[ - f"evaluate_inpainting_{metric}_{label}" - ].outputs["metric_results"], - } + graph_outputs |= _relax_outputs( + "unrelaxed_inpainted_full_relaxation", ufr_out, relax ) - wg.outputs.evaluation_results = evaluation_results - -def _add_full_relax_task( - wg: WorkGraph, - structures: ( - dict[str, Structure] | BatchedStructuresData | BatchedStructures - ), - relax_inputs: dict, - task_name: str = "full_relaxation", - options: dict = None, - code: orm.Code | None = None, - as_graph_outputs: bool = False, -): - """Add a full relaxation task to the workgraph.""" - wg.add_task( - _relaxation_task, - structures=structures, - relax_inputs=relax_inputs, - usempi=options.get("withmpi", False), - name=task_name, - metadata={ - "options": options or {}, - }, - code=code, - ) - if as_graph_outputs: - outputs = { - f"{task_name}.structures": wg.tasks[task_name].outputs[ - "structures" 
- ], - f"{task_name}.final_energies": wg.tasks[task_name].outputs[ - "final_energies" - ], - } + if relax.full and cr_out is not None: + pfr_out = relaxation_graph( + structures=cr_out.structures, + relax_config=relax, + aiida_options=relax_opts, + code_label=relax_code_label, + constrained=False, + metadata={ + "call_link_label": "pre_relaxed_inpainted_full_relaxation" + }, + ) + graph_outputs |= _relax_outputs( + "pre_relaxed_inpainted_full_relaxation", pfr_out, relax + ) - if relax_inputs.get("return_initial_energies", False): - outputs[f"{task_name}.initial_energies"] = wg.tasks[ - task_name - ].outputs["initial_energies"] - if relax_inputs.get("return_initial_forces", False): - outputs[f"{task_name}.initial_forces"] = wg.tasks[ - task_name - ].outputs["initial_forces"] - if relax_inputs.get("return_final_forces", False): - outputs[f"{task_name}.final_forces"] = wg.tasks[task_name].outputs[ - "final_forces" - ] + return graph_outputs - wg.outputs = outputs - return wg +def setup_inpainting_wg(inputs: XtalPaintConfig) -> WorkGraph: + """Create a WorkGraph for inpainting of crystal structures.""" + return InpaintingWorkGraph.build(inputs=inputs) diff --git a/src/xtalpaint/aiida/workgraphs/relaxation.py b/src/xtalpaint/aiida/workgraphs/relaxation.py new file mode 100644 index 0000000..8cb9aa9 --- /dev/null +++ b/src/xtalpaint/aiida/workgraphs/relaxation.py @@ -0,0 +1,126 @@ +"""Relaxation WorkGraph with optional refinement and uniqueness filtering.""" + +import typing as t + +from aiida import orm +from aiida_workgraph import spec, task + +from xtalpaint.aiida.tasks import tasks +from xtalpaint.inpainting.config_schema import ( + AiiDATaskOptions, + RelaxationGraphConfig, +) + + +@task.graph( + outputs=spec.namespace( + structures=t.Any, + final_energies=t.Any, + initial_energies=spec.socket(t.Any, required=False), + initial_forces=spec.socket(t.Any, required=False), + final_forces=spec.socket(t.Any, required=False), + ) +) +def relaxation_graph( + structures: 
t.Any, + relax_config: RelaxationGraphConfig, + aiida_options: AiiDATaskOptions = None, + code_label: str = None, + command_info: dict = None, + constrained: bool = True, +): + """Relaxation WG with optional symmetry refinement and deduplication. + + Runs ``_relaxation_task`` and then optionally: + + 1. ``_refine_structures_task`` — symmetry-refine the relaxed structures. + 2. ``_uniqueness_filter_task`` — keep one representative per unique + (space-group, StructureMatcher equivalence class) group. + + The ``structures`` output always points to the last active step, so + downstream tasks see a consistent socket name regardless of which optional + steps are enabled. + + ``relax_config.refine`` and ``relax_config.filter_unique`` are evaluated + at graph build-time (when the WorkGraph is materialised), so they must + resolve to plain Python ``bool`` values; passing AiiDA nodes wired from + another task's output is not supported for these flags. + + Args: + structures: Input structures to relax. + relax_config: Relaxation and post-processing configuration. + ``relax_config.params`` is forwarded to ``relax_structures`` as + ``relax_inputs``. ``relax_config.refine`` and + ``relax_config.filter_unique`` control the optional steps. + aiida_options: AiiDA scheduler/resource options forwarded to all inner + tasks. If ``None``, default options are used (no resource limits, + no MPI). + code_label: AiiDA code label for all inner pythonjob tasks. If + ``None``, aiida-pythonjob locates ``python3`` automatically. + command_info: Passed as ``command_info`` to every inner pythonjob task + (e.g. ``{"filepath_executable": "/path/to/python"}``). Overrides + automatic executable detection when set. + constrained: If ``True`` (default), ``elements_to_relax`` from + ``relax_config.params`` is included in the relax call so that only + those elements are relaxed. Pass ``False`` for full relaxation + of all atoms. 
+ + Returns: + dict with ``structures`` (relaxed, and optionally refined/filtered), + ``final_energies``, and optionally ``initial_energies``, + ``initial_forces``, ``final_forces`` when requested via + ``relax_config.params``. + """ + _aiida = aiida_options or AiiDATaskOptions() + _options = _aiida.model_dump(exclude={"withmpi"}, exclude_none=True) + _code = orm.load_code(code_label) if code_label else None + _metadata = {"options": _options} + _command_info = command_info or {} + + relaxed = tasks.relaxation_task( + structures=structures, + relax_inputs=relax_config.relax_inputs(constrained=constrained), + usempi=_aiida.withmpi, + metadata=_metadata, + code=_code, + command_info=_command_info, + ) + + current_structures = relaxed.structures + + if relax_config.refine: + refined = tasks.refine_structures_task( + structures=current_structures, + refinement_symprec=relax_config.refinement_symprec, + primitive=relax_config.refinement_primitive, + metadata=_metadata, + code=_code, + command_info=_command_info, + ) + current_structures = refined.structures + + if relax_config.filter_unique: + filtered = tasks.uniqueness_filter_task( + structures=current_structures, + symprec=relax_config.uniqueness.symprec, + ltol=relax_config.uniqueness.ltol, + stol=relax_config.uniqueness.stol, + angle_tol=relax_config.uniqueness.angle_tol, + metadata=_metadata, + code=_code, + command_info=_command_info, + ) + current_structures = filtered.unique_structures + + outputs = { + "structures": current_structures, + "final_energies": relaxed.final_energies, + } + if relax_config.params.return_initial_energies: + outputs["initial_energies"] = relaxed.initial_energies + if relax_config.params.return_initial_forces: + outputs["initial_forces"] = relaxed.initial_forces + if relax_config.params.return_final_forces: + outputs["final_forces"] = relaxed.final_forces + + return outputs diff --git a/src/xtalpaint/eval.py b/src/xtalpaint/eval.py index 45c859b..a7d0282 100644 --- 
a/src/xtalpaint/eval.py +++ b/src/xtalpaint/eval.py @@ -3,7 +3,6 @@ from concurrent.futures import ProcessPoolExecutor from functools import partial -import numpy as np import pandas as pd from mattergen.evaluation.utils.utils import compute_rmsd_angstrom from pymatgen.analysis.structure_matcher import StructureMatcher @@ -12,12 +11,8 @@ from xtalpaint.data import BatchedStructures from xtalpaint.utils import _is_batched_structure - - -def _check_for_nan(structure: Structure) -> bool: - """Check if a pymatgen Structure has NaN values in its atomic positions.""" - positions = structure.cart_coords - return np.isnan(positions).any() +from xtalpaint.utils.data_utils import get_structure_keys +from xtalpaint.utils.structure_utils import check_for_nan_positions def _rmsd(strct1, strct2, normalization_element: str | None = None) -> float: @@ -59,7 +54,7 @@ def _comparison_per_key( comparisons = [] comp_func = COMPARISON_METHODS[metric] for sample_idx, sample in inpainted_structures_grouped[key]: - if _check_for_nan(sample): + if check_for_nan_positions(sample): comparison = None else: comparison = comp_func(sample, ref, **kwargs) @@ -68,35 +63,6 @@ def _comparison_per_key( return comparisons -def get_structure_keys( - structures: BatchedStructures | dict[str, Structure], -) -> tuple[list[str], list[str | None]]: - """Get the unique keys of the structures with out sample indices. - - This is used to group structures that are samples of the same - base structure. - - Args: - structures (dict | BatchedStructures): - The structures to get the keys from. - - Returns: - set[str]: The unique structure keys. 
- """ - keys = structures.keys() - structure_keys = [] - sample_indices = [] - for key in keys: - if "_sample_" in key: - key, sample_idx = key.split("_sample_") - else: - sample_idx = None - structure_keys.append(key) - sample_indices.append(sample_idx) - - return structure_keys, sample_indices - - def worker_init(ref_structures, inp_structures_grp): """Initialize worker.""" global matcher, reference_structures, inpainted_structures_grouped diff --git a/src/xtalpaint/inpainting/config_schema.py b/src/xtalpaint/inpainting/config_schema.py index 1c8ec52..a0ecd6a 100644 --- a/src/xtalpaint/inpainting/config_schema.py +++ b/src/xtalpaint/inpainting/config_schema.py @@ -24,61 +24,49 @@ def _is_valid_structure_type(obj) -> bool: return False -def _is_inpainting_structure(obj) -> bool: - """Check if object is an InpaintingStructureData (requires AiiDA).""" - if is_aiida_installed(): - from xtalpaint.aiida.data import InpaintingStructureData +# --------------------------------------------------------------------------- +# Stage configs +# --------------------------------------------------------------------------- - return isinstance(obj, InpaintingStructureData) - return False +class CandidateGenerationConfig(BaseModel): + """Configuration for generating inpainting candidates.""" -class RelaxParameters(BaseModel): - """Configuration for the relaxation stage.""" + n_inp: int | dict[str, int] + element: str | dict[str, str] + num_samples: int = 1 - load_path: str | None = None - fmax: float = 0.05 - elements_to_relax: Optional[list[str]] = Field( - default=None, - description="List of elements to relax during optimization.", - ) - max_natoms_per_batch: int = 512 - max_n_steps: int = 500 - device: str = "cpu" - filter: Optional[str] = None - optimizer: str - mlip: str - return_initial_energies: bool = False - return_initial_forces: bool = False - return_final_forces: bool = False +class InpaintingConfig(BaseModel): + """Configuration for the diffusion inpainting stage. 
-class InpaintingModelParams(BaseModel): - """Diffusion sampling parameters for the inpainting model.""" + Sampling parameters are kept flat (no nested ModelParams sub-object) + so they can be specified in a single block and passed directly as a + dict to the inpainting pipeline. + """ + # Model — exactly one of these must be provided + pretrained_name: Optional[str] = None + model_path: Optional[str] = None + + # Diffusion sampling + predictor_corrector: str N_steps: int coordinates_snr: float n_corrector_steps: int batch_size: int - n_resample_steps: Optional[int] = None - jump_length: Optional[int] = None - - -class InpaintingPipelineParams(BaseModel): - """Settings for constructing an inpainting pipeline.""" - - predictor_corrector: str fix_cell: bool = True - inpainting_model_params: InpaintingModelParams - pretrained_name: Optional[str] = None - model_path: Optional[str] = None - record_trajectories: Optional[bool] = False + record_trajectories: bool = False sampling_config_path: Optional[str] = None + # Repaint-specific (required when predictor_corrector contains 'repaint') + n_resample_steps: Optional[int] = None + jump_length: Optional[int] = None + @field_validator("predictor_corrector") @classmethod def validate_predictor_corrector(cls, v): - """Validator to ensure 'predictor_corrector' is a supported key.""" + """Validate that predictor_corrector is one of the allowed options.""" from xtalpaint.inpainting.inpainting_process import ( GUIDED_PREDICTOR_CORRECTOR_MAPPING, ) @@ -93,11 +81,7 @@ def validate_predictor_corrector(cls, v): @model_validator(mode="after") @classmethod def check_pretrained_model_exclusive(cls, cfg): - """Validate model specification. - - Ensure that either 'pretrained_name' or 'model_path' is provided, - but not both. 
- """ + """Validate model specification.""" if ( cfg.pretrained_name is not None and cfg.model_path is not None ) or (cfg.pretrained_name is None and cfg.model_path is None): @@ -110,83 +94,261 @@ def check_pretrained_model_exclusive(cls, cfg): @model_validator(mode="after") @classmethod def check_repaint_requires_resample_and_jump(cls, cfg): - """Validate 'n_resample_steps' and 'jump_length'. - - If 'predictor_corrector' contains 'repaint', both parameters must be - set in 'inpainting_model_params'. - """ + """Validate RePaint-specific parameters.""" if "repaint" in cfg.predictor_corrector.lower(): - params = cfg.inpainting_model_params - if params.n_resample_steps is None or params.jump_length is None: + if cfg.n_resample_steps is None or cfg.jump_length is None: raise ValueError( "When 'predictor_corrector' contains 'repaint', " - "inpainting_model_params must set both 'n_resample_steps' " - "and 'jump_length'." + "both 'n_resample_steps' and 'jump_length' must be set." ) return cfg -class GenInpaintingCandidatesParams(BaseModel): - """Configuration for generating inpainting candidates.""" +class RefinementConfig(BaseModel): + """Symmetry refinement stage.""" - n_inp: int | dict[str, int] - element: str | dict[str, str] - num_samples: int = 1 + symprec: float = 0.01 + primitive: bool = False -class EvalParameters(BaseModel): - """Evaluation parameters for generated structures.""" +class UniquenessConfig(BaseModel): + """Parameters for post-relaxation uniqueness/deduplication filtering.""" - max_workers: int = 6 - chunksize: int = 50 - metrics: str | list[str] = "match" - code_label: Optional[str] = None + symprec: float = 0.01 + ltol: float = 0.2 + stol: float = 0.3 + angle_tol: float = 5.0 -class InpaintingWorkflowConfig(BaseModel): - """Top-level configuration for a XtalPaint inpainting workflow. +class RelaxationParams(BaseModel): + """Core relaxation parameters forwarded to ``relax_structures()``. 
- This config can be used for both AiiDA-based workflows (WorkGraphs) - and regular Python-based workflows. + These are the settings that control *how* a single relaxation is run + (MLIP, optimiser, convergence criteria, etc.). They are kept separate + from the inpainting-workflow-level controls in ``RelaxationConfig``. """ - structures: BatchedStructures | dict[str, Structure] - run_inpainting: bool = True - inpainting_pipeline_params: InpaintingPipelineParams - gen_inpainting_candidates_params: Optional[ - GenInpaintingCandidatesParams - ] = None - code_label: Optional[str] = None - relax_code_label: Optional[str] = None - inpainting_code_label: Optional[str] = None - relax: Optional[bool] = False - relax_kwargs: Optional[RelaxParameters] = {} - full_relax: Optional[bool] = False - full_relax_wo_pre_relax: Optional[bool] = False - options: Optional[dict] = {} - relax_options: Optional[dict] = {} - gen_inpainting_candidates_options: Optional[dict] = {} - inpainting_pipeline_options: Optional[dict] = {} - evaluate: Optional[bool] = False - evaluate_params: Optional[EvalParameters] = None - refine_structures: bool = False - refine_structures_after_relax: bool = False + mlip: str + optimizer: str + fmax: float = 0.05 + max_n_steps: int = 500 + max_natoms_per_batch: int = 512 + device: str = "cpu" + filter: Optional[str] = None + load_path: Optional[str] = None + elements_to_relax: Optional[list[str]] = None + return_initial_energies: bool = False + return_initial_forces: bool = False + return_final_forces: bool = False + + +class RelaxationGraphConfig(BaseModel): + """Configuration for a single ``relaxation_graph`` call. + + Bundles the core relaxation parameters with the optional post-relaxation + processing steps (symmetry refinement and uniqueness filtering) that + ``relaxation_graph`` can apply after each pass. + + This class is the direct input type for ``relaxation_graph``. 
+ ``RelaxationConfig`` extends it with multi-pass orchestration flags for + use inside the inpainting WorkGraph. + """ + + # Core relaxation parameters (forwarded as relax_inputs + # to relax_structures) + params: RelaxationParams + + # Post-relaxation steps + refine: bool = False refinement_symprec: float = 0.01 refinement_primitive: bool = False + filter_unique: bool = False + uniqueness: UniquenessConfig = Field(default_factory=UniquenessConfig) + + def relax_inputs(self, constrained: bool = True) -> dict: + """Build the ``relax_inputs`` kwargs dict for ``relax_structures()``. + + Args: + constrained: If True, include ``elements_to_relax`` so that only + those elements are relaxed. Pass False for full-relax passes + where all atoms move freely. + """ + if constrained: + return self.params.model_dump() + return self.params.model_dump(exclude={"elements_to_relax"}) + + +class RelaxationConfig(RelaxationGraphConfig): + """Configuration for the relaxation stage in the inpainting workflow. + + Extends ``RelaxationGraphConfig`` with multi-pass orchestration flags + that are specific to the inpainting WorkGraph. The three passes share + the same ``params`` and post-relaxation settings. + + Pass names and their semantics + -------------------------------- + constrained + Relax only the atoms listed in ``params.elements_to_relax``. Requires + ``params.elements_to_relax`` to be set. Labelled + ``inpainted_constrained_relaxation`` in the WorkGraph. + full + Run a full (all-atom) relaxation on the output of the constrained pass. + Requires ``constrained=True``. Labelled + ``pre_relaxed_inpainted_full_relaxation``. + full_direct + Run a full relaxation directly on the inpainted structures, bypassing + the constrained pre-relax step (useful for comparison). Labelled + ``unrelaxed_inpainted_full_relaxation``. 
+ """ + + # Which relaxation passes to run (inpainting-WG-specific) + constrained: bool = True + full: bool = False + full_direct: bool = False + + @model_validator(mode="after") + @classmethod + def validate_passes(cls, cfg): + """Validate relaxation modes.""" + if not any([cfg.constrained, cfg.full, cfg.full_direct]): + raise ValueError( + "At least one of 'constrained', 'full', or 'full_direct' " + "must be True." + ) + if cfg.constrained and cfg.params.elements_to_relax is None: + raise ValueError( + "'params.elements_to_relax' must be set when " + "'constrained=True'." + ) + if cfg.full and not cfg.constrained: + raise ValueError( + "'full=True' requires 'constrained=True': the full-relax pass " + "runs on the output of the constrained pass." + ) + return cfg + + +# --------------------------------------------------------------------------- +# AiiDA-specific options (ignored outside AiiDA execution) +# --------------------------------------------------------------------------- + + +class AiiDATaskOptions(BaseModel): + """AiiDA scheduler and resource options for a single task. + + Replaces the raw ``Optional[dict]`` options fields. ``withmpi`` + controls MPI-parallel execution and is kept here (infrastructure concern) + rather than in the pipeline config. + """ + + resources: dict = Field(default_factory=dict) + max_wallclock_seconds: Optional[int] = None + queue_name: Optional[str] = None + withmpi: bool = False + + +class AiiDAOptions(BaseModel): + """AiiDA-specific settings: code labels and per-task scheduler options. + + Place this in ``XtalPaintConfig.aiida``; it is ignored entirely in + non-AiiDA (plain Python) execution. 
+ """ + + default_code_label: Optional[str] = None + inpainting_code_label: Optional[str] = None + relax_code_label: Optional[str] = None + candidate_generation_code_label: Optional[str] = None + + inpainting_options: AiiDATaskOptions = Field( + default_factory=AiiDATaskOptions + ) + relax_options: AiiDATaskOptions = Field(default_factory=AiiDATaskOptions) + candidate_generation_options: AiiDATaskOptions = Field( + default_factory=AiiDATaskOptions + ) + + def get_code_label(self, specific: Optional[str] = None) -> Optional[str]: + """Return *specific* code label, falling back to the default.""" + return specific or self.default_code_label + + +# --------------------------------------------------------------------------- +# Top-level config +# --------------------------------------------------------------------------- + + +class XtalPaintConfig(BaseModel): + """Complete configuration for the XtalPaint inpainting workflow. + + Works for both AiiDA-based (WorkGraph) and plain-Python execution. + AiiDA-specific settings live in the optional ``aiida`` block and are + ignored in non-AiiDA runs. + + Pipeline stages are controlled by presence/absence of their config + objects — no boolean flags required: + + * ``candidate_generation`` — omit if structures are already + ``InpaintingStructureData`` objects. + * ``pre_refinement`` — symmetry-refine structures before relaxation; + omit to skip. + * ``relaxation`` — geometry optimisation; omit to skip. 
+ + Example (minimal):: + + XtalPaintConfig( + structures={"si_001": structure}, + inpainting=InpaintingConfig( + pretrained_name="mattergen_base", + predictor_corrector="baseline", + N_steps=5, coordinates_snr=0.2, + n_corrector_steps=1, batch_size=1000, + ), + ) + + Example (with relaxation + deduplication on AiiDA):: + + XtalPaintConfig( + structures=..., + candidate_generation=CandidateGenerationConfig( + n_inp={"H": 2}, element="H" + ), + inpainting=InpaintingConfig(...), + pre_refinement=RefinementConfig(symprec=0.01), + relaxation=RelaxationConfig( + params=RelaxationParams( + mlip="mattersim", + optimizer="BFGS", + elements_to_relax=["H"], + fmax=0.01, + ), + full=True, + filter_unique=True, + ), + aiida=AiiDAOptions( + default_code_label="xtalpaint@localhost", + relax_code_label="relax@hpc", + relax_options=AiiDATaskOptions( + resources={"num_machines": 1}, + withmpi=True, + ), + ), + ) + """ + + structures: BatchedStructures | dict[str, Structure] + run_inpainting: bool = True + candidate_generation: Optional[CandidateGenerationConfig] = None + pre_refinement: Optional[RefinementConfig] = None + inpainting: InpaintingConfig + relaxation: Optional[RelaxationConfig] = None + aiida: Optional[AiiDAOptions] = None model_config = ConfigDict(arbitrary_types_allowed=True) @field_validator("structures") @classmethod def validate_structures(cls, v): - """Validate input structures. - - Ensure 'structures' is a dictionary with string keys and values of - uniform, supported types. - - Raises: - TypeError: If the structure mapping is not valid. 
- """ + """Validate `structures` type.""" structures = v if _is_batched_structure(v): structures = v.get_structures(strct_type="pymatgen") @@ -201,7 +363,6 @@ def validate_structures(cls, v): "All values in the dictionary must be of type StructureData, " "Structure, ase.Atoms, or InpaintingStructureData" ) - types = {type(s) for s in structures.values()} if len(types) > 1: raise TypeError( @@ -209,48 +370,16 @@ def validate_structures(cls, v): ) return v - @model_validator(mode="after") - @classmethod - def check_n_inp_for_structures(cls, cfg): - """Validate inputs for inpainting candidates. - Ensure that 'gen_inpainting_candidates_params' is provided when - structures are not already inpainting structure instances. - """ - values = ( - list(cfg.structures.values()) - if isinstance(cfg.structures, dict) - else cfg.structures.get_structures(strct_type="pymatgen") - ) - if not all( - _is_inpainting_structure(s) or isinstance(s, Structure) - for s in values - ): - if cfg.gen_inpainting_candidates_params is None: - raise ValueError( - "If structures are not InpaintingStructure objects, " - "gen_inpainting_candidates_params must be provided." - ) - return cfg +# --------------------------------------------------------------------------- +# Evaluation parameters (standalone — not part of the inpainting workflow) +# --------------------------------------------------------------------------- - @model_validator(mode="after") - @classmethod - def check_evaluate_inpainting_structures(cls, cfg): - """Check if structures are already InpaintingStructure objects.""" - if cfg.evaluate and cfg.is_inpainting_structures: - raise ValueError( - "If 'evaluate' is True, structures must not be " - "InpaintingStructure objects. We need the original structures " - "to compare against inpainted structures." 
- ) - return cfg - @property - def is_inpainting_structures(self) -> bool: - """Check if structures are already InpaintingStructure objects.""" - structures = ( - self.structures.values() - if isinstance(self.structures, dict) - else self.structures.get_structures(strct_type="pymatgen") - ) - return all(_is_inpainting_structure(s) for s in structures) +class EvalParameters(BaseModel): + """Evaluation parameters for generated structures.""" + + max_workers: int = 6 + chunksize: int = 50 + metrics: str | list[str] = "match" + code_label: Optional[str] = None diff --git a/src/xtalpaint/inpainting/inpainting_process.py b/src/xtalpaint/inpainting/inpainting_process.py index d3401c7..8fed493 100644 --- a/src/xtalpaint/inpainting/inpainting_process.py +++ b/src/xtalpaint/inpainting/inpainting_process.py @@ -18,7 +18,7 @@ from xtalpaint.generate_inpainting import ( generate_reconstructed_structures, ) -from xtalpaint.inpainting.config_schema import InpaintingPipelineParams +from xtalpaint.inpainting.config_schema import InpaintingConfig from xtalpaint.utils.data_utils import create_dataloader XTALPAINT_BASE = "xtalpaint.predictor_corrector" @@ -173,30 +173,51 @@ def _get_overrides( def _run_inpainting( predictor_corrector: str, structures_dl: DataLoader, - inpainting_model_params: dict[str, Any], + N_steps: int, + coordinates_snr: float, + n_corrector_steps: int, + batch_size: int, fix_cell: bool = True, record_trajectories: bool = False, pretrained_name: str | None = None, model_path: str | None = None, sampling_config_path: str | None = None, + n_resample_steps: int | None = None, + jump_length: int | None = None, ) -> tuple[list[Structure], list, list | None]: """Run the inpainting process using MatterGen. Args: predictor_corrector: Type of predictor-corrector to use. structures_dl: DataLoader containing structures to inpaint. - inpainting_model_params: Parameters for the inpainting model. + N_steps: Number of diffusion steps. 
+ coordinates_snr: Signal-to-noise ratio for coordinate corrector. + n_corrector_steps: Number of corrector steps per diffusion step. + batch_size: Batch size for the DataLoader. fix_cell: Whether to fix the cell during sampling. record_trajectories: Whether to record trajectories. pretrained_name: Name of pretrained model, if any. model_path: Path to model checkpoint. - sampling_config_path: Path to the sampling config directory - for mattergen + sampling_config_path: Path to the sampling config directory of + mattergen. + n_resample_steps: Number of resampling steps (repaint variants only). + jump_length: Jump length for resampling (repaint variants only). Returns: Tuple of (inpainted_structures, trajectories, mean_trajectories). mean_trajectories is None if not recorded. """ + inpainting_model_params: dict[str, Any] = { + "N_steps": N_steps, + "coordinates_snr": coordinates_snr, + "n_corrector_steps": n_corrector_steps, + "batch_size": batch_size, + } + if n_resample_steps is not None: + inpainting_model_params["n_resample_steps"] = n_resample_steps + if jump_length is not None: + inpainting_model_params["jump_length"] = jump_length + sampling_config_overrides, config_overrides = _get_overrides( inpainting_model_params, predictor_corrector, fix_cell, pretrained_name ) @@ -305,13 +326,15 @@ def _extract_outputs( def run_inpainting_pipeline( structures: dict[str, Structure], - config: InpaintingPipelineParams | dict[str, Any], + config: InpaintingConfig | dict[str, Any], ) -> dict[str, Any]: """Run the inpainting experiment using MatterGen. Args: structures: Input structures for inpainting. - config: Configuration for the inpainting process. + config: Configuration for the inpainting process. Pass an + ``InpaintingConfig`` instance or an equivalent plain dict + (e.g. from ``InpaintingConfig.model_dump(exclude_none=True)``). Returns: Dictionary containing inpainted structures, trajectories, and scores. 
@@ -321,13 +344,19 @@ def run_inpainting_pipeline( labels, structures = map(list, zip(*structures.items())) + cfg = ( + config + if isinstance(config, dict) + else config.model_dump(exclude_none=True) + ) + prepared_structures = _prepare_structures( structures, - batch_size=config["inpainting_model_params"].get("batch_size", 64), + batch_size=cfg["batch_size"], ) inpainted_structures, trajectories, mean_trajectories = _run_inpainting( - structures_dl=prepared_structures, **config + structures_dl=prepared_structures, **cfg ) return _extract_outputs( @@ -335,13 +364,13 @@ def run_inpainting_pipeline( trajectories, mean_trajectories, labels, - config["record_trajectories"], + cfg["record_trajectories"], ) def run_mpi_parallel_inpainting_pipeline( structures: dict[str, Structure], - config: InpaintingPipelineParams | dict[str, Any], + config: InpaintingConfig | dict[str, Any], ) -> dict[str, Any] | None: """Run the inpainting experiment using MatterGen with MPI parallelization. @@ -360,6 +389,12 @@ def run_mpi_parallel_inpainting_pipeline( labels, structures = map(list, zip(*structures.items())) + cfg = ( + config + if isinstance(config, dict) + else config.model_dump(exclude_none=True) + ) + comm = mpi4py.MPI.COMM_WORLD rank = comm.rank nranks = comm.size @@ -385,10 +420,10 @@ def run_mpi_parallel_inpainting_pipeline( prepared_structures = _prepare_structures( local_structures, - batch_size=config["inpainting_model_params"].get("batch_size", 64), + batch_size=cfg["batch_size"], ) - rank_results = _run_inpainting(structures_dl=prepared_structures, **config) + rank_results = _run_inpainting(structures_dl=prepared_structures, **cfg) # Gather results on rank 0 all_results = comm.gather(rank_results, root=0) @@ -412,7 +447,7 @@ def run_mpi_parallel_inpainting_pipeline( all_trajectories, all_mean_trajectories if all_mean_trajectories else None, labels, - config["record_trajectories"], + cfg["record_trajectories"], ) return None diff --git a/src/xtalpaint/utils/data_utils.py 
b/src/xtalpaint/utils/data_utils.py index 2982f2a..9aeb399 100644 --- a/src/xtalpaint/utils/data_utils.py +++ b/src/xtalpaint/utils/data_utils.py @@ -8,8 +8,44 @@ from mattergen.common.data.collate import collate from mattergen.common.data.dataset import CrystalDataset from mattergen.diffusion.data.batched_data import BatchedData +from pymatgen.core.structure import Structure from torch.utils.data import DataLoader +from xtalpaint.data import BatchedStructures + + +def get_structure_keys( + structures: BatchedStructures | dict[str, Structure], +) -> tuple[list[str], list[str | None]]: + """Get the unique keys of the structures without sample indices. + + This is used to group structures that are samples of the same + base structure. Example keys are ``mp-1234_sample_0``, + ``mp-1234_sample_1``, etc. This function will return ``mp-1234`` as the + unique key for both of these, and the sample indices as ``0`` and ``1`` + respectively. If a key does not have a ``_sample_`` suffix, it is returned + as-is with a sample index of ``None``. + + Args: + structures (dict | BatchedStructures): + The structures to get the keys from. + + Returns: + tuple[list[str], list[str | None]]: Base keys and sample indices. 
+ """ + keys = structures.keys() + structure_keys = [] + sample_indices = [] + for key in keys: + if "_sample_" in key: + key, sample_idx = key.split("_sample_") + else: + sample_idx = None + structure_keys.append(key) + sample_indices.append(sample_idx) + + return structure_keys, sample_indices + def create_dataloader( dataset: CrystalDataset, batch_size: int, fix_cell: bool = True diff --git a/src/xtalpaint/utils/structure_utils.py b/src/xtalpaint/utils/structure_utils.py new file mode 100644 index 0000000..c01150c --- /dev/null +++ b/src/xtalpaint/utils/structure_utils.py @@ -0,0 +1,134 @@ +"""Utility functions for structure processing.""" + +import numpy as np +from pymatgen.analysis.structure_matcher import StructureMatcher +from pymatgen.core.structure import Structure +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +from xtalpaint.data import BatchedStructures +from xtalpaint.utils import _is_batched_structure +from xtalpaint.utils.data_utils import get_structure_keys + + +def check_for_nan_positions(structure: Structure) -> bool: + """Check if a pymatgen Structure has NaN values in its atomic positions.""" + positions = structure.cart_coords + return np.isnan(positions).any() + + +def refine_structures( + structures: "BatchedStructures | dict[str, Structure]", + symprec: float, + primitive: bool = False, +) -> BatchedStructures: + """Refine structures to standard conventional (or primitive) cells. + + Args: + structures: Input structures. + symprec: Symmetry precision passed to SpacegroupAnalyzer. + primitive: If ``True``, return the primitive cell instead of the + conventional cell. + + Returns: + BatchedStructures with refined structures. Structures for which + refinement raises an exception are kept as-is. 
+ """ + if _is_batched_structure(structures): + structures_dict: dict[str, Structure] = structures.get_structures( + strct_type="pymatgen" + ) + else: + structures_dict = dict(structures) + + refined: dict[str, Structure] = {} + for k, s in structures_dict.items(): + analyzer = SpacegroupAnalyzer(s, symprec=symprec) + try: + result = analyzer.get_refined_structure() + except Exception: + result = s + + if primitive: + try: + result = SpacegroupAnalyzer( + result, symprec=symprec + ).get_primitive_structure() + except Exception: + pass + + refined[k] = result + + return BatchedStructures(refined) + + +def filter_unique_structures( + structures: "BatchedStructures | dict[str, Structure]", + symprec: float = 0.1, + ltol: float = 0.2, + stol: float = 0.3, + angle_tol: float = 5.0, +) -> BatchedStructures: + """Filter unique structures per parent key and space group. + + Groups samples by their parent structure key (splitting on ``_sample_``), + then by space group number, then applies StructureMatcher within each + sub-group to retain one representative per equivalence class. NaN + structures are skipped. The first encountered structure in each equivalence + class is kept as the representative. + + Args: + structures: Inpainting samples, typically keyed as + ``{base_key}_sample_{idx}``. + symprec: Symmetry precision passed to SpacegroupAnalyzer. + ltol: Fractional length tolerance for StructureMatcher. + stol: Site tolerance for StructureMatcher. + angle_tol: Angle tolerance in degrees for StructureMatcher. + + Returns: + BatchedStructures containing one representative per unique structure.
+ """ + if _is_batched_structure(structures): + structures_dict: dict[str, Structure] = structures.get_structures( + strct_type="pymatgen" + ) + else: + structures_dict = dict(structures) + + base_keys, _ = get_structure_keys(structures_dict) + + groups: dict[str, list[tuple[str, Structure]]] = {} + for full_key, base_key in zip(structures_dict.keys(), base_keys): + groups.setdefault(base_key, []).append( + (full_key, structures_dict[full_key]) + ) + + structure_matcher = StructureMatcher( + ltol=ltol, stol=stol, angle_tol=angle_tol + ) + unique: dict[str, Structure] = {} + + for members in groups.values(): + sg_groups: dict[int, list[tuple[str, Structure]]] = {} + for full_key, structure in members: + if check_for_nan_positions(structure): + continue + try: + sg_num = SpacegroupAnalyzer( + structure, symprec=symprec + ).get_space_group_number() + except Exception: + sg_num = -1 + sg_groups.setdefault(sg_num, []).append((full_key, structure)) + + for sg_members in sg_groups.values(): + representatives: list[tuple[str, Structure]] = [] + for full_key, structure in sg_members: + if not any( + structure_matcher.fit(structure, rep_strct) + for _, rep_strct in representatives + ): + representatives.append((full_key, structure)) + for rep_key, rep_strct in representatives: + unique[rep_key] = rep_strct + + return BatchedStructures(unique) diff --git a/tests/test_eval.py b/tests/test_eval.py index c04253a..4115a9f 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -8,12 +8,9 @@ from pymatgen.io.ase import AseAtomsAdaptor from xtalpaint.eval import ( - _check_for_nan, evaluate_inpainting, - get_structure_keys, ) - @pytest.fixture def reference_structures(): """Load reference structures from extxyz file.""" @@ -67,57 +64,6 @@ def nan_structure(): coords = [[np.nan, 0.0, 0.0], [0.5, 0.5, 0.5]] return Structure(lattice, species, coords) - -class TestCheckForNan: - """Test the _check_for_nan function.""" - - def test_no_nan(self, simple_structure): - """Test structure 
without NaN values.""" - assert not _check_for_nan(simple_structure) - - def test_with_nan(self, nan_structure): - """Test structure with NaN values.""" - assert _check_for_nan(nan_structure) - - def test_real_structures(self, reference_structures): - """Test that real structures have no NaN values.""" - for key, structure in reference_structures.items(): - assert not _check_for_nan(structure), f"Structure {key} has NaN" - -class TestGetStructureKeys: - """Test the get_structure_keys function.""" - - def test_no_samples(self): - """Test with keys without sample indices.""" - structures = { - "structure_1": None, - "structure_2": None, - } - keys, indices = get_structure_keys(structures) - assert keys == ["structure_1", "structure_2"] - assert indices == [None, None] - - def test_with_samples(self): - """Test with keys containing sample indices.""" - structures = { - "structure_1_sample_0": None, - "structure_1_sample_1": None, - "structure_2_sample_0": None, - } - keys, indices = get_structure_keys(structures) - assert keys == ["structure_1", "structure_1", "structure_2"] - assert indices == ["0", "1", "0"] - - def test_real_structure_keys(self, reference_structures): - """Test with real structure keys from extxyz.""" - keys, indices = get_structure_keys(reference_structures) - # All keys should have sample indices - assert len(keys) == len(reference_structures) - # Check that sample indices are extracted correctly - for idx in indices: - assert idx is not None - - class TestEvaluateInpainting: """Test the evaluate_inpainting function.""" diff --git a/tests/test_relaxation_graph.py b/tests/test_relaxation_graph.py new file mode 100644 index 0000000..17ac112 --- /dev/null +++ b/tests/test_relaxation_graph.py @@ -0,0 +1,397 @@ +"""Tests for relaxation_graph and filter_unique_structures.""" + +import sys + +from aiida import orm +import numpy as np +import pandas as pd +import pytest +from pymatgen.core import Lattice +from pymatgen.core.structure import Structure + 
+from xtalpaint.aiida.data import BatchedStructuresData +from xtalpaint.aiida.workgraphs.relaxation import relaxation_graph +from xtalpaint.data import BatchedStructures +from xtalpaint.utils.structure_utils import filter_unique_structures + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bcc_si(): + """BCC silicon — space group Im-3m (229).""" + return Structure( + [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]], + ["Si", "Si"], + [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], + ) + + +@pytest.fixture +def fcc_al(): + """FCC aluminium — space group Fm-3m (225).""" + a = 4.05 + return Structure( + [[a, 0, 0], [0, a, 0], [0, 0, a]], + ["Al", "Al", "Al", "Al"], + [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], + ) + + +@pytest.fixture +def nan_si(bcc_si): + """BCC silicon with a NaN fractional coordinate.""" + s = bcc_si.copy() + s[0].frac_coords[0] = np.nan + return s + + +# --------------------------------------------------------------------------- +# relaxation_graph: graph-structure tests (no AiiDA runtime required) +# --------------------------------------------------------------------------- + + +# graph_inputs / graph_outputs / graph_ctx are always present in every built +# WorkGraph; they are infrastructure nodes, not user tasks. +_BUILTIN_TASKS = {"graph_inputs", "graph_outputs", "graph_ctx"} + + +def _user_tasks(wg) -> set: + return {t.name for t in wg.tasks if t.name not in _BUILTIN_TASKS} + + +def _structures_source(wg) -> str: + """Return the task name that the graph-level ``structures`` output links from.""" + return wg.outputs["structures"]._links[0].from_task.name + + +class TestRelaxationGraphStructure: + """Verify that relaxation_graph wires the correct tasks for each flag combination. + + These tests call ``.build()`` to materialise the inner WorkGraph without + running any AiiDA processes. 
They check which tasks exist, which output + sockets are declared, and that the ``structures`` output socket is wired to + the last active task in the chain. + """ + + def _build(self, structures, **kwargs): + return relaxation_graph.build( + structures=structures, + relax_inputs={}, + **kwargs, + ) + + def test_base_contains_only_relaxation_task(self, bcc_si): + wg = self._build(BatchedStructures({"s": bcc_si})) + assert _user_tasks(wg) == {"relaxation_task"} + + def test_refine_flag_adds_refinement_task(self, bcc_si): + wg = self._build(BatchedStructures({"s": bcc_si}), refine=True) + assert _user_tasks(wg) == {"relaxation_task", "refine_structures_task"} + + def test_filter_unique_flag_adds_uniqueness_task(self, bcc_si): + wg = self._build(BatchedStructures({"s": bcc_si}), filter_unique=True) + assert _user_tasks(wg) == {"relaxation_task", "uniqueness_filter_task"} + + def test_both_flags_produce_full_chain(self, bcc_si): + wg = self._build( + BatchedStructures({"s": bcc_si}), refine=True, filter_unique=True + ) + assert _user_tasks(wg) == { + "relaxation_task", + "refine_structures_task", + "uniqueness_filter_task", + } + + @pytest.mark.parametrize( + "refine,filter_unique", + [(False, False), (True, False), (False, True), (True, True)], + ) + def test_structures_and_energies_always_in_outputs( + self, bcc_si, refine, filter_unique + ): + wg = self._build( + BatchedStructures({"s": bcc_si}), + refine=refine, + filter_unique=filter_unique, + ) + assert "structures" in wg.outputs + assert "final_energies" in wg.outputs + + def test_optional_force_energy_sockets_declared(self, bcc_si): + """initial_energies / initial_forces / final_forces are declared as + optional sockets even though they are only populated at runtime when + requested via relax_inputs.""" + wg = self._build(BatchedStructures({"s": bcc_si})) + for socket in ("initial_energies", "initial_forces", "final_forces"): + assert socket in wg.outputs + + @pytest.mark.parametrize( + 
"refine,filter_unique,expected_src", + [ + (False, False, "relaxation_task"), + (True, False, "refine_structures_task"), + (False, True, "uniqueness_filter_task"), + (True, True, "uniqueness_filter_task"), + ], + ) + def test_structures_output_linked_to_last_active_task( + self, bcc_si, refine, filter_unique, expected_src + ): + """The graph-level ``structures`` output must be wired to the final + step in the active chain, not hardcoded to the relaxation task.""" + wg = self._build( + BatchedStructures({"s": bcc_si}), + refine=refine, + filter_unique=filter_unique, + ) + assert _structures_source(wg) == expected_src + + +# --------------------------------------------------------------------------- +# filter_unique_structures: functional tests +# --------------------------------------------------------------------------- + + +class TestFilterUniqueStructures: + """Tests for the pure-Python filter_unique_structures function.""" + + def test_identical_samples_collapsed_to_one(self, bcc_si): + """Two identical samples of the same parent become one representative.""" + structures = BatchedStructures( + {"s_sample_0": bcc_si, "s_sample_1": bcc_si} + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 1 + + def test_distinct_compositions_both_kept(self, bcc_si, fcc_al): + """Samples of different compositions can never be StructureMatcher-equal.""" + structures = BatchedStructures( + { + "a_sample_0": bcc_si, + "a_sample_1": fcc_al, + } + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 2 + + def test_different_space_groups_both_kept(self, bcc_si, fcc_al): + """Samples assigned to different space groups are never merged, even + if StructureMatcher would call them equal (it won't across SG groups).""" + structures = BatchedStructures( + {"p_sample_0": bcc_si, "p_sample_1": fcc_al} + ) + result = filter_unique_structures(structures) + # bcc_si (SG 229) and fcc_al (SG 225) land in different bins → both kept + assert 
len(result.keys()) == 2 + + def test_nan_structures_are_excluded(self, bcc_si, nan_si): + """NaN-coordinate structures are dropped entirely.""" + structures = BatchedStructures( + {"s_sample_0": nan_si, "s_sample_1": bcc_si} + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 1 + assert "s_sample_0" not in result.keys() + + def test_all_nan_returns_empty(self, nan_si): + structures = BatchedStructures( + {"s_sample_0": nan_si, "s_sample_1": nan_si} + ) + result = filter_unique_structures(structures) + assert len(result.keys()) == 0 + + def test_multiple_parents_filtered_independently(self, bcc_si, fcc_al): + """Each parent key group is deduplicated independently.""" + structures = BatchedStructures( + { + "a_sample_0": bcc_si, + "a_sample_1": bcc_si, # duplicate of a_sample_0 + "b_sample_0": fcc_al, + "b_sample_1": fcc_al, # duplicate of b_sample_0 + } + ) + result = filter_unique_structures(structures) + keys = result.keys() + # One unique per parent + assert len(keys) == 2 + a_keys = [k for k in keys if k.startswith("a_")] + b_keys = [k for k in keys if k.startswith("b_")] + assert len(a_keys) == 1 + assert len(b_keys) == 1 + + def test_accepts_plain_dict(self, bcc_si): + """filter_unique_structures works with a plain dict, not just BatchedStructures.""" + result = filter_unique_structures({"s": bcc_si}) + assert len(result.keys()) == 1 + + def test_returns_batched_structures(self, bcc_si): + result = filter_unique_structures({"s": bcc_si}) + assert isinstance(result, BatchedStructures) + + +# --------------------------------------------------------------------------- +# Fixtures for execution tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def diamond_si(): + """Diamond-cubic silicon (Fd-3m, SG 227) — 8-atom conventional cell.""" + a = 5.43 + return Structure( + Lattice.cubic(a), + ["Si"] * 8, + [ + [0.000, 0.000, 0.000], + [0.250, 0.250, 0.250], + [0.500, 
0.500, 0.000], + [0.750, 0.750, 0.250], + [0.500, 0.000, 0.500], + [0.750, 0.250, 0.750], + [0.000, 0.500, 0.500], + [0.250, 0.750, 0.750], + ], + ) + + +@pytest.fixture(scope="module") +def strained_si(diamond_si): + """Diamond silicon with every lattice vector compressed by 2 %. + + MatterSim will relax this back to the same diamond-cubic minimum as + ``diamond_si``, which lets us verify that the uniqueness filter correctly + identifies the two relaxed structures as equivalent. + """ + s = diamond_si.copy() + s.apply_strain(-0.02) + return s + + +@pytest.fixture(scope="module") +def fcc_al_conventional(): + """FCC aluminium (Fm-3m, SG 225) — 4-atom conventional cell.""" + a = 4.05 + return Structure( + Lattice.cubic(a), + ["Al"] * 4, + [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [0.0, 0.5, 0.5]], + ) + + +# MatterSim relax settings kept minimal so the tests run quickly on CPU. +_MATTERSIM_RELAX_INPUTS = { + "device": "cpu", + "mlip": "mattersim", + "fmax": 0.1, + "max_n_steps": 100, +} + + +# --------------------------------------------------------------------------- +# TestRelaxationGraphExecution +# --------------------------------------------------------------------------- + + +class TestRelaxationGraphExecution: + """Integration tests: actually execute the WorkGraph with MatterSim. + + These tests require a temporary AiiDA profile (``aiida_profile``) and a + configured localhost computer (``aiida_localhost``). The WorkGraph tasks + are pythonjob tasks that run in a subprocess using the current Python + executable so that all installed packages (including MatterSim) are + available. + + Each test method requests both ``aiida_profile`` and ``aiida_localhost`` + as fixtures to ensure the AiiDA backend and transport layer are ready + before the WorkGraph is submitted.
+ """ + + def _build_and_run(self, structures: dict, **graph_kwargs): + wg = relaxation_graph.build( + structures=BatchedStructures(structures), + relax_inputs=_MATTERSIM_RELAX_INPUTS, + command_info={"filepath_executable": sys.executable}, + **graph_kwargs, + ) + wg.run() + + assert wg.process.is_finished_ok, "WorkGraph did not finish successfully" + + return wg + + def test_basic_relaxation_returns_structures_and_energies( + self, aiida_profile, aiida_localhost, diamond_si, fcc_al_conventional + ): + """WorkGraph runs to completion and returns one relaxed structure and + one energy row per input.""" + wg = self._build_and_run( + {"si": diamond_si, "al": fcc_al_conventional}, + ) + + assert wg.state == "FINISHED" + + # --- structures output --- + out_node = wg.tasks.relaxation_task.outputs.structures.value + assert isinstance(out_node, BatchedStructuresData) + out_keys = set(out_node.value.keys()) + assert out_keys == {"si", "al"} + + # --- final_energies output --- + energies_node = wg.tasks.relaxation_task.outputs.final_energies.value + energies_df: pd.DataFrame = energies_node.value + assert isinstance(energies_df, pd.DataFrame) + assert set(energies_df.index) == {"si", "al"} + assert (energies_df["final_energy"] < 0).all(), ( + "Energies from MatterSim should be negative (eV)" + ) + + def test_filter_unique_deduplicates_identical_relaxed_copies( + self, + aiida_profile, + aiida_localhost, + diamond_si, + strained_si, + fcc_al_conventional, + ): + """Three Si samples (two identical, one 2 %-strained) and two Al + samples are passed through the full relaxation + uniqueness-filter + pipeline. After relaxation all three Si structures converge to the + same diamond-cubic minimum, so the filter should retain exactly one Si + representative. 
The two identical Al copies should likewise collapse + to one, leaving two unique structures in total.""" + structures = { + # Three Si samples with the same parent key "si" — all should + # relax to diamond-cubic Si and be deduplicated to one. + "si_sample_0": diamond_si, + "si_sample_1": diamond_si, # exact copy + "si_sample_2": strained_si, # 2 % compressed, same basin + # Two Al samples with parent key "al" — both relax to FCC Al. + "al_sample_0": fcc_al_conventional, + "al_sample_1": fcc_al_conventional, # exact copy + } + + wg = self._build_and_run(structures, filter_unique=True) + + assert wg.state == "FINISHED" + + unique_node = ( + wg.tasks.uniqueness_filter_task.outputs.unique_structures.value + ) + assert isinstance(unique_node, BatchedStructuresData) + unique_keys = list(unique_node.value.keys()) + + si_keys = [k for k in unique_keys if k.startswith("si_")] + al_keys = [k for k in unique_keys if k.startswith("al_")] + + assert len(si_keys) == 1, ( + f"Expected 1 unique Si structure after deduplication, " + f"got {len(si_keys)}: {si_keys}" + ) + assert len(al_keys) == 1, ( + f"Expected 1 unique Al structure after deduplication, " + f"got {len(al_keys)}: {al_keys}" + )