Skip to content

Commit d91c214

Browse files
committed
Added ND scan functionalities to plotter
1 parent 6d7fe8d commit d91c214

2 files changed

Lines changed: 339 additions & 3 deletions

File tree

arc/plotter.py

Lines changed: 210 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,12 @@ def plot_2d_rotor_scan(results: dict,
12931293
if len(results['scans']) != 2:
12941294
raise InputError(f'results must represent a 2D rotor, got {len(results["scans"])}D')
12951295

1296+
# Dispatch to sparse plotting for adaptive scans
1297+
if is_sparse_2d_scan(results):
1298+
_plot_2d_rotor_scan_sparse(results, path=path, label=label, cmap=cmap,
1299+
resolution=resolution, original_dihedrals=original_dihedrals)
1300+
return
1301+
12961302
results['directed_scan'] = clean_scan_results(results['directed_scan'])
12971303

12981304
# phis0 and phis1 correspond to columns and rows in energies, respectively
@@ -1357,9 +1363,9 @@ def plot_2d_rotor_scan(results: dict,
13571363
label = ' for ' + label if label else ''
13581364
plt.title(f'2D scan energies (kJ/mol){label}')
13591365
min_x = min_y = -180
1360-
plt.xlim = (min_x, min_x + 360)
1366+
plt.gca().set_xlim(min_x, min_x + 360)
13611367
plt.xticks(np.arange(min_x, min_x + 361, step=60))
1362-
plt.ylim = (min_y, min_y + 360)
1368+
plt.gca().set_ylim(min_y, min_y + 360)
13631369
plt.yticks(np.arange(min_y, min_y + 361, step=60))
13641370

13651371
if mark_lowest_conformations:
@@ -1379,6 +1385,207 @@ def plot_2d_rotor_scan(results: dict,
13791385
plt.close(fig=fig)
13801386

13811387

1388+
def is_sparse_2d_scan(results: dict) -> bool:
1389+
"""
1390+
Detect whether a 2D scan results dict represents a sparse/adaptive scan.
1391+
1392+
A scan is considered sparse if the results contain
1393+
``sampling_policy == 'adaptive'``.
1394+
1395+
Args:
1396+
results (dict): The results dictionary from a 2D directed scan.
1397+
1398+
Returns:
1399+
bool: ``True`` if the scan is sparse/adaptive.
1400+
"""
1401+
return results.get('sampling_policy') == 'adaptive'
1402+
1403+
1404+
def extract_sparse_2d_points(results: dict) -> dict:
1405+
"""
1406+
Extract sampled point coordinates and energies from a sparse 2D scan result.
1407+
1408+
Args:
1409+
results (dict): The results dictionary from a 2D directed scan.
1410+
1411+
Returns:
1412+
dict: A dictionary with keys ``'x'``, ``'y'``, ``'energy'`` (lists of floats for
1413+
completed points with non-None energy), plus ``'failed_points'`` and
1414+
``'invalid_points'`` (lists of ``[x, y]`` pairs).
1415+
"""
1416+
xs, ys, energies = [], [], []
1417+
for key, entry in results.get('directed_scan', {}).items():
1418+
e = entry.get('energy')
1419+
if e is not None:
1420+
xs.append(float(key[0]))
1421+
ys.append(float(key[1]))
1422+
energies.append(float(e))
1423+
summary = results.get('adaptive_scan_summary', {})
1424+
return {
1425+
'x': xs,
1426+
'y': ys,
1427+
'energy': energies,
1428+
'failed_points': summary.get('failed_points', []),
1429+
'invalid_points': summary.get('invalid_points', []),
1430+
}
1431+
1432+
1433+
def interpolate_sparse_2d_scan(points_x: list,
1434+
points_y: list,
1435+
energies: list,
1436+
grid_resolution: float = 2.0,
1437+
) -> tuple:
1438+
"""
1439+
Interpolate sparse 2D scan data onto a dense grid for contour plotting.
1440+
1441+
Uses ``scipy.interpolate.griddata`` with periodic boundary augmentation
1442+
to reduce artifacts at the -180/+180 wrap boundary.
1443+
1444+
Args:
1445+
points_x (list): Sampled dihedral angles for dimension 0 (degrees).
1446+
points_y (list): Sampled dihedral angles for dimension 1 (degrees).
1447+
energies (list): Energy values at sampled points (kJ/mol).
1448+
grid_resolution (float): Spacing of the dense output grid in degrees.
1449+
1450+
Returns:
1451+
tuple: ``(grid_x, grid_y, grid_energies)`` where each is a 2D numpy array
1452+
suitable for ``plt.contourf``.
1453+
"""
1454+
from scipy.interpolate import griddata
1455+
1456+
px = np.array(points_x, dtype=np.float64)
1457+
py = np.array(points_y, dtype=np.float64)
1458+
pe = np.array(energies, dtype=np.float64)
1459+
1460+
# Augment with periodic image points for wrap-around
1461+
aug_x, aug_y, aug_e = list(px), list(py), list(pe)
1462+
for dx in (-360.0, 0.0, 360.0):
1463+
for dy in (-360.0, 0.0, 360.0):
1464+
if dx == 0.0 and dy == 0.0:
1465+
continue
1466+
aug_x.extend(px + dx)
1467+
aug_y.extend(py + dy)
1468+
aug_e.extend(pe)
1469+
aug_x = np.array(aug_x)
1470+
aug_y = np.array(aug_y)
1471+
aug_e = np.array(aug_e)
1472+
1473+
# Dense grid from -180 to 180
1474+
n_pts = int(360.0 / grid_resolution) + 1
1475+
gx = np.linspace(-180.0, 180.0, n_pts)
1476+
gy = np.linspace(-180.0, 180.0, n_pts)
1477+
grid_x, grid_y = np.meshgrid(gx, gy, indexing='ij')
1478+
1479+
# Interpolate: try cubic, fall back to linear, then nearest
1480+
pts = np.column_stack([aug_x, aug_y])
1481+
grid_e = None
1482+
for method in ('cubic', 'linear'):
1483+
try:
1484+
grid_e = griddata(pts, aug_e, (grid_x, grid_y), method=method)
1485+
if not np.all(np.isnan(grid_e)):
1486+
break
1487+
except (ValueError, Exception):
1488+
grid_e = None
1489+
if grid_e is None or np.all(np.isnan(grid_e)):
1490+
grid_e = griddata(pts, aug_e, (grid_x, grid_y), method='nearest')
1491+
# Fill any remaining NaN with nearest-neighbor
1492+
mask = np.isnan(grid_e)
1493+
if mask.any():
1494+
grid_nearest = griddata(pts, aug_e, (grid_x, grid_y), method='nearest')
1495+
grid_e[mask] = grid_nearest[mask]
1496+
1497+
return grid_x, grid_y, grid_e
1498+
1499+
1500+
def _plot_2d_rotor_scan_sparse(results: dict,
1501+
path: Optional[str] = None,
1502+
label: str = '',
1503+
cmap: str = 'Blues',
1504+
resolution: int = 90,
1505+
original_dihedrals: Optional[List[float]] = None,
1506+
):
1507+
"""
1508+
Plot a sparse/adaptive 2D rotor scan using interpolation for contours
1509+
and overlaying sampled, failed, and invalid points.
1510+
1511+
This is called internally by :func:`plot_2d_rotor_scan` when the results
1512+
are detected as sparse.
1513+
1514+
Args:
1515+
results (dict): The results dictionary from a 2D directed scan.
1516+
path (str, optional): Folder path to save the plot image.
1517+
label (str, optional): Species label.
1518+
cmap (str, optional): Matplotlib colormap name.
1519+
resolution (int, optional): Image DPI.
1520+
original_dihedrals (list, optional): Original dihedral angles for marker.
1521+
"""
1522+
data = extract_sparse_2d_points(results)
1523+
xs, ys, energies = data['x'], data['y'], data['energy']
1524+
1525+
if len(xs) < 3:
1526+
logger.warning(f'Not enough sparse points to plot 2D scan ({len(xs)} points)')
1527+
return
1528+
1529+
# Normalize energies to min = 0
1530+
e_min = min(energies)
1531+
energies_norm = [e - e_min for e in energies]
1532+
1533+
# Interpolate to dense grid
1534+
grid_x, grid_y, grid_e = interpolate_sparse_2d_scan(xs, ys, energies_norm, grid_resolution=2.0)
1535+
1536+
fig = plt.figure(num=None, figsize=(12, 8), dpi=resolution, facecolor='w', edgecolor='k')
1537+
1538+
plt.contourf(grid_x, grid_y, grid_e, 20, cmap=cmap)
1539+
plt.colorbar()
1540+
contours = plt.contour(grid_x, grid_y, grid_e, 4, colors='black')
1541+
plt.clabel(contours, inline=True, fontsize=8)
1542+
1543+
# Overlay sampled points
1544+
plt.scatter(xs, ys, c='black', s=12, zorder=5, label='sampled')
1545+
1546+
# Overlay failed points
1547+
failed = data.get('failed_points', [])
1548+
if failed:
1549+
fx = [p[0] for p in failed]
1550+
fy = [p[1] for p in failed]
1551+
plt.scatter(fx, fy, c='red', marker='x', s=40, zorder=6, label='failed')
1552+
1553+
# Overlay invalid points
1554+
invalid = data.get('invalid_points', [])
1555+
if invalid:
1556+
ix = [p[0] for p in invalid]
1557+
iy = [p[1] for p in invalid]
1558+
plt.scatter(ix, iy, edgecolors='orange', marker='s', facecolors='none',
1559+
s=40, zorder=6, label='invalid')
1560+
1561+
# Mark original dihedral
1562+
if original_dihedrals is not None and len(original_dihedrals) >= 2:
1563+
plt.plot(original_dihedrals[0], original_dihedrals[1], color='r',
1564+
marker='.', markersize=15, linewidth=0, label='original')
1565+
1566+
plt.xlabel(f'Dihedral 1 for {results["scans"][0]} (degrees)')
1567+
plt.ylabel(f'Dihedral 2 for {results["scans"][1]} (degrees)')
1568+
label_str = ' for ' + label if label else ''
1569+
summary = results.get('adaptive_scan_summary', {})
1570+
n_pts = summary.get('completed_count', len(xs))
1571+
plt.title(f'2D scan energies (kJ/mol){label_str} [adaptive, {n_pts} pts]')
1572+
plt.gca().set_xlim(-180, 180)
1573+
plt.xticks(np.arange(-180, 181, step=60))
1574+
plt.gca().set_ylim(-180, 180)
1575+
plt.yticks(np.arange(-180, 181, step=60))
1576+
1577+
plt.legend(loc='upper right', fontsize=8)
1578+
1579+
if path is not None:
1580+
fig_name = f'{results["directed_scan_type"]}_{results["scans"]}_adaptive.png'
1581+
fig_path = os.path.join(path, fig_name)
1582+
plt.savefig(fig_path, dpi=resolution, facecolor='w', edgecolor='w', orientation='portrait',
1583+
format='png', transparent=False, bbox_inches=None, pad_inches=0.1, metadata=None)
1584+
1585+
plt.show()
1586+
plt.close(fig=fig)
1587+
1588+
13821589
def plot_2d_scan_bond_dihedral(results: dict,
13831590
path: Optional[str] = None,
13841591
label: str = '',
@@ -1486,7 +1693,7 @@ def plot_2d_scan_bond_dihedral(results: dict,
14861693
label = ' for ' + label if label else ''
14871694
plt.title(f'2D scan energies (kJ/mol){label}')
14881695
min_x = -180
1489-
plt.xlim = (min_x, min_x + 360)
1696+
plt.gca().set_xlim(min_x, min_x + 360)
14901697
plt.xticks(np.arange(min_x, min_x + 361, step=60))
14911698

14921699
if original_dihedrals is not None:

arc/plotter_test.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import shutil
1010
import unittest
1111

12+
import numpy as np
13+
1214
import arc.plotter as plotter
1315
from arc.common import ARC_PATH, ARC_TESTING_PATH, read_yaml_file, safe_copy_file
1416
from arc.species.converter import str_to_xyz
@@ -236,5 +238,132 @@ def tearDownClass(cls):
236238
os.remove(file_path)
237239

238240

241+
class TestSparse2DPlotting(unittest.TestCase):
242+
"""
243+
Contains unit tests for sparse 2D rotor scan plotting helpers.
244+
"""
245+
246+
def _make_dense_results(self):
247+
"""Helper: build a small dense 2D result dict (4x4 grid, increment=120)."""
248+
directed_scan = {}
249+
for a0 in [0.0, 120.0, -120.0, 0.0]:
250+
for a1 in [0.0, 120.0, -120.0, 0.0]:
251+
key = (f'{a0:.2f}', f'{a1:.2f}')
252+
if key not in directed_scan:
253+
directed_scan[key] = {
254+
'energy': float(abs(a0) + abs(a1)) / 10.0,
255+
'xyz': {},
256+
'is_isomorphic': True,
257+
'trsh': [],
258+
}
259+
return {
260+
'directed_scan_type': 'brute_force_opt',
261+
'scans': [[1, 2, 3, 4], [5, 6, 7, 8]],
262+
'directed_scan': directed_scan,
263+
}
264+
265+
def _make_sparse_results(self, n_points=20):
266+
"""Helper: build a sparse adaptive 2D result dict."""
267+
import random
268+
random.seed(42)
269+
directed_scan = {}
270+
for _ in range(n_points):
271+
a0 = round(random.uniform(-180, 180), 2)
272+
a1 = round(random.uniform(-180, 180), 2)
273+
key = (f'{a0:.2f}', f'{a1:.2f}')
274+
directed_scan[key] = {
275+
'energy': float(abs(a0) + abs(a1)) / 10.0,
276+
'xyz': {},
277+
'is_isomorphic': True,
278+
'trsh': [],
279+
}
280+
return {
281+
'directed_scan_type': 'brute_force_opt',
282+
'scans': [[1, 2, 3, 4], [5, 6, 7, 8]],
283+
'directed_scan': directed_scan,
284+
'sampling_policy': 'adaptive',
285+
'adaptive_scan_summary': {
286+
'completed_count': n_points,
287+
'failed_count': 2,
288+
'invalid_count': 1,
289+
'stopping_reason': 'max_points_reached',
290+
'failed_points': [[45.0, -90.0], [120.0, 60.0]],
291+
'invalid_points': [[-30.0, 150.0]],
292+
},
293+
}
294+
295+
def test_is_sparse_2d_scan_dense(self):
296+
"""Test that dense results are not detected as sparse."""
297+
results = self._make_dense_results()
298+
self.assertFalse(plotter.is_sparse_2d_scan(results))
299+
300+
def test_is_sparse_2d_scan_adaptive(self):
301+
"""Test that adaptive results are detected as sparse."""
302+
results = self._make_sparse_results()
303+
self.assertTrue(plotter.is_sparse_2d_scan(results))
304+
305+
def test_extract_sparse_2d_points(self):
306+
"""Test extraction of sparse point data."""
307+
results = self._make_sparse_results(15)
308+
data = plotter.extract_sparse_2d_points(results)
309+
self.assertEqual(len(data['x']), len(data['y']))
310+
self.assertEqual(len(data['x']), len(data['energy']))
311+
self.assertGreater(len(data['x']), 0)
312+
self.assertEqual(len(data['failed_points']), 2)
313+
self.assertEqual(len(data['invalid_points']), 1)
314+
315+
def test_extract_sparse_2d_points_dense(self):
316+
"""Test extraction from dense results (no adaptive summary)."""
317+
results = self._make_dense_results()
318+
data = plotter.extract_sparse_2d_points(results)
319+
self.assertGreater(len(data['x']), 0)
320+
self.assertEqual(data['failed_points'], [])
321+
self.assertEqual(data['invalid_points'], [])
322+
323+
def test_interpolate_sparse_2d_scan(self):
324+
"""Test interpolation produces a dense grid."""
325+
xs = [0.0, 90.0, -90.0, 180.0, -180.0, 45.0, -45.0]
326+
ys = [0.0, 90.0, -90.0, 180.0, -180.0, 45.0, -45.0]
327+
es = [0.0, 5.0, 5.0, 10.0, 10.0, 3.0, 3.0]
328+
gx, gy, ge = plotter.interpolate_sparse_2d_scan(xs, ys, es, grid_resolution=10.0)
329+
# Check shapes match
330+
self.assertEqual(gx.shape, gy.shape)
331+
self.assertEqual(gx.shape, ge.shape)
332+
n = int(360.0 / 10.0) + 1
333+
self.assertEqual(gx.shape, (n, n))
334+
# No NaN values
335+
self.assertFalse(np.any(np.isnan(ge)))
336+
337+
def test_plot_sparse_2d_no_crash(self):
338+
"""Test that plotting a sparse scan doesn't crash."""
339+
import tempfile
340+
results = self._make_sparse_results(30)
341+
with tempfile.TemporaryDirectory() as tmpdir:
342+
# Should not raise
343+
plotter.plot_2d_rotor_scan(results, path=tmpdir)
344+
# Check that a file was created
345+
files = os.listdir(tmpdir)
346+
self.assertTrue(any('adaptive' in f for f in files),
347+
f'Expected adaptive plot file, got: {files}')
348+
349+
def test_plot_dense_2d_unchanged(self):
350+
"""Test that plotting a dense scan still works through the legacy path."""
351+
# This exercises the existing code path; if it crashes, the dense path is broken
352+
results = self._make_dense_results()
353+
# Don't save to disk, just ensure no crash
354+
try:
355+
plotter.plot_2d_rotor_scan(results, path=None)
356+
except (ValueError, KeyError):
357+
# Dense path might fail on this small test grid due to missing points,
358+
# but it should NOT dispatch to sparse path
359+
self.assertFalse(plotter.is_sparse_2d_scan(results))
360+
361+
def test_plot_sparse_too_few_points_no_crash(self):
362+
"""Test that sparse plotting with < 3 points doesn't crash."""
363+
results = self._make_sparse_results(2)
364+
# Should not raise, just warn
365+
plotter.plot_2d_rotor_scan(results, path=None)
366+
367+
239368
if __name__ == '__main__':
240369
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))

0 commit comments

Comments
 (0)