|
26 | 26 | from abc import ABC, abstractmethod |
27 | 27 | from collections import Counter |
28 | 28 | from dataclasses import dataclass, field |
| 29 | +from typing import Any |
29 | 30 |
|
30 | 31 | import numpy as np |
31 | 32 | from scipy.spatial.distance import cdist, pdist |
@@ -611,7 +612,7 @@ def _kabsch_rmsd(pa: np.ndarray, pb: np.ndarray) -> float: |
611 | 612 | RMSD for enantiomeric pairs. |
612 | 613 | """ |
613 | 614 | U, S, Vt = np.linalg.svd(pb.T @ pa) |
614 | | - d = np.sign(np.linalg.det(U) * np.linalg.det(Vt)) |
| 615 | + d = 1.0 if np.linalg.det(U) * np.linalg.det(Vt) >= 0 else -1.0 |
615 | 616 | E0 = np.sum(pa ** 2) + np.sum(pb ** 2) |
616 | 617 | rmsd_sq = max(0.0, E0 - 2.0 * (S[0] + S[1] + d * S[2])) / len(pa) |
617 | 618 | return float(np.sqrt(rmsd_sq)) |
@@ -724,7 +725,7 @@ class ExplorationTask: |
724 | 725 | xyz_file: str |
725 | 726 | afir_params: list[str] |
726 | 727 | priority: float = 0.0 |
727 | | - metadata: dict = field(default_factory=dict) |
| 728 | + metadata: dict[str, Any] = field(default_factory=dict) |
728 | 729 |
|
729 | 730 |
|
730 | 731 | @dataclass |
@@ -1054,31 +1055,36 @@ def __init__(self, filepath: str, flush_interval: int = 100) -> None: |
1054 | 1055 | # ------------------------------------------------------------------ # |
1055 | 1056 |
|
1056 | 1057 | def _load(self) -> None: |
1057 | | - """Load existing records from the text file (if present).""" |
1058 | 1058 | if not os.path.isfile(self._filepath): |
1059 | 1059 | return |
1060 | | - |
| 1060 | + |
1061 | 1061 | with open(self._filepath, "r", encoding="utf-8") as fh: |
1062 | | - for line in fh: |
| 1062 | + for lineno, line in enumerate(fh, 1): |
1063 | 1063 | line = line.strip() |
1064 | 1064 | if not line or line.startswith("#"): |
1065 | 1065 | continue |
1066 | 1066 | parts = line.split() |
1067 | 1067 | if len(parts) < 4: |
1068 | 1068 | continue |
1069 | 1069 | try: |
1070 | | - # Strip the leading "EQ" prefix before converting to int |
1071 | 1070 | node_id = int(parts[0][2:]) |
1072 | 1071 | atom_i = int(parts[1]) |
1073 | 1072 | atom_j = int(parts[2]) |
1074 | 1073 | gamma_sign = parts[3] |
1075 | | - |
| 1074 | + |
1076 | 1075 | if gamma_sign not in ("+", "-"): |
1077 | 1076 | continue |
1078 | | - |
| 1077 | + |
1079 | 1078 | self._explored.add((node_id, atom_i, atom_j, gamma_sign)) |
1080 | 1079 | self._explored_node_ids.add(node_id) |
1081 | 1080 | except (ValueError, IndexError): |
| 1081 | + logger.warning( |
| 1082 | + "ExploredPairsLog._load: malformed record at %s:%d " |
| 1083 | + "(content=%r) — skipped.", |
| 1084 | + self._filepath, |
| 1085 | + lineno, |
| 1086 | + line, |
| 1087 | + ) |
1082 | 1088 | continue |
1083 | 1089 |
|
1084 | 1090 | logger.info( |
@@ -2551,7 +2557,7 @@ def _collect_task_batch(self, n: int) -> list[tuple[ExplorationTask, str, str, i |
2551 | 2557 | An empty list is returned when no runnable tasks are available, |
2552 | 2558 | which signals :meth:`_run_batch_parallel` to stop the loop. |
2553 | 2559 | """ |
2554 | | - batch: list[tuple[ExplorationTask, str, str, str, str, int]] = [] |
| 2560 | + batch: list[tuple[ExplorationTask, str, str, int, int, int]] = [] |
2555 | 2561 | # Safety limit: if every task remaining in the queue is already |
2556 | 2562 | # explored, stop after at most (current queue size + n) skips rather |
2557 | 2563 | # than draining the entire queue one pop at a time. |
@@ -3005,6 +3011,7 @@ def _save_run_metadata( |
3005 | 3011 | task: ExplorationTask, |
3006 | 3012 | status: str, |
3007 | 3013 | profile_dirs: list[str], |
| 3014 | + iteration=None, |
3008 | 3015 | ) -> None: |
3009 | 3016 | info = { |
3010 | 3017 | "iteration": self._iteration, |
|
0 commit comments