Skip to content

Commit 622b88a

Browse files
Fix RDFs for ase 3.28 (#691)
Co-authored-by: Elliott Kasoar <45317199+ElliottKasoar@users.noreply.github.com>
1 parent 0ee04be commit 622b88a

4 files changed

Lines changed: 99 additions & 41 deletions

File tree

janus_core/calculations/md.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from warnings import warn
1515

1616
from ase import Atoms
17-
from ase.geometry.analysis import Analysis
1817
from ase.io import read
1918
from ase.md.bussi import Bussi
2019
from ase.md.langevin import Langevin
@@ -1293,8 +1292,6 @@ def _post_process(self) -> None:
12931292

12941293
data = read(self.traj_file, index=":")
12951294

1296-
ana = Analysis(data)
1297-
12981295
if self.post_process_kwargs.get("rdf_compute", False):
12991296
rdf_args = {
13001297
name: self.post_process_kwargs.get(key, default)
@@ -1312,7 +1309,7 @@ def _post_process(self) -> None:
13121309
)
13131310
rdf_args["index"] = slice_
13141311

1315-
compute_rdf(data, ana, filenames=self._rdf_files, **rdf_args)
1312+
compute_rdf(data, filenames=self._rdf_files, **rdf_args)
13161313

13171314
if self.post_process_kwargs.get("vaf_compute", False):
13181315
use_vel = self.post_process_kwargs.get("vaf_velocities", False)

janus_core/processing/post_process.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
from collections.abc import Sequence
66
from itertools import combinations_with_replacement
7+
from warnings import warn
78

89
from ase import Atoms
910
from ase.geometry.analysis import Analysis
11+
from ase.geometry.rdf import get_rdf
1012
import numpy as np
1113
from numpy import float64
1214
from numpy.typing import NDArray
@@ -40,7 +42,7 @@ def compute_rdf(
4042
data
4143
Dataset to compute RDF of.
4244
ana
43-
ASE Analysis object for data reuse.
45+
Deprecated. Please do not use. ASE Analysis object for data reuse.
4446
filenames
4547
Filenames to output data to. Must match number of RDFs computed.
4648
by_elements
@@ -69,6 +71,18 @@ def compute_rdf(
6971
If `by_elements` is true returns a `dict` of RDF by element pairs.
7072
Otherwise returns RDF of total system filtered by elements.
7173
"""
74+
if ana is not None:
75+
warn(
76+
"ana has been deprecated.",
77+
FutureWarning,
78+
stacklevel=2,
79+
)
80+
if by_elements:
81+
raise ValueError(
82+
"Analysis.get_rdf has known bugs with by_elements."
83+
"Call without ana to use ase.geometry.rdf.get_rdf directly."
84+
)
85+
7286
index = slicelike_to_startstopstep(index)
7387

7488
if not isinstance(data, Sequence):
@@ -82,32 +96,34 @@ def compute_rdf(
8296
):
8397
volume = (2 * rmax) ** 3
8498

85-
if ana is None:
86-
ana = Analysis(data)
87-
8899
if by_elements:
89100
elements = (
90101
tuple(sorted(set(data[0].get_chemical_symbols())))
91102
if elements is None
92103
else elements
93104
)
94105

95-
rdf = {
96-
element: ana.get_rdf(
97-
rmax=rmax,
98-
nbins=nbins,
99-
elements=element,
100-
imageIdx=slice(*index),
101-
return_dists=True,
102-
volume=volume,
103-
)
106+
rdfs = {
107+
element: [
108+
get_rdf(
109+
atoms,
110+
rmax,
111+
nbins,
112+
elements=element,
113+
volume=volume,
114+
)
115+
for atoms in data
116+
]
104117
for element in combinations_with_replacement(elements, 2)
105118
}
106119

107120
# Compute RDF average
108121
rdf = {
109-
element: (rdf[0][1], np.average([rdf_i[0] for rdf_i in rdf], axis=0))
110-
for element, rdf in rdf.items()
122+
element: (
123+
element_rdfs[0][1],
124+
np.average([rdf_i[0] for rdf_i in element_rdfs], axis=0),
125+
)
126+
for element, element_rdfs in rdfs.items()
111127
}
112128

113129
if filenames is not None:
@@ -129,14 +145,24 @@ def compute_rdf(
129145
print(dist, rdf_i, file=out_file)
130146

131147
else:
132-
rdf = ana.get_rdf(
133-
rmax=rmax,
134-
nbins=nbins,
135-
elements=elements,
136-
imageIdx=slice(*index),
137-
return_dists=True,
138-
volume=volume,
139-
)
148+
if ana is not None:
149+
rdf = ana.get_rdf(
150+
rmax=rmax,
151+
nbins=nbins,
152+
imageIdx=slice(*index),
153+
return_dists=True,
154+
volume=volume,
155+
)
156+
else:
157+
rdf = [
158+
get_rdf(
159+
atoms,
160+
rmax,
161+
nbins,
162+
volume=volume,
163+
)
164+
for atoms in data
165+
]
140166

141167
assert isinstance(rdf, list)
142168

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
readme = "README.md"
2727

2828
dependencies = [
29-
"ase<4.0,==3.27",
29+
"ase<4.0,>=3.28",
3030
"click<9,>=8.2.1",
3131
"codecarbon<4.0.0,>=3.0.7",
3232
"numpy<3.0.0,>=1.26.4",
@@ -73,7 +73,7 @@ visualise = [
7373

7474
# MLIPs with updated e3nn
7575
mattersim = [
76-
"mattersim == 1.2.0; sys_platform != 'win32'",
76+
"mattersim == 1.2.2; sys_platform != 'win32'",
7777
]
7878
nequip = [
7979
"nequip == 0.14.0",

tests/test_post_process.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pathlib import Path
66

7+
from ase.geometry.analysis import Analysis
78
from ase.io import read
89
import numpy as np
910
import pytest
@@ -132,6 +133,51 @@ def test_rdf():
132133
assert (np.isclose(expected_peaks, rdf[0][peaks])).all()
133134

134135

136+
def test_rdf_with_analysis_deprecation():
137+
"""Test computation of RDF and ana deprecation."""
138+
data = read(DATA_PATH / "benzene.xyz")
139+
ana = Analysis(data)
140+
141+
with pytest.warns(FutureWarning, match="ana has been deprecated."):
142+
rdf = post_process.compute_rdf(data, ana, index=0, rmax=5.0, nbins=100)
143+
144+
assert isinstance(rdf, tuple)
145+
assert isinstance(rdf[0], np.ndarray)
146+
147+
expected_peaks = np.asarray(
148+
(
149+
1.075,
150+
1.375,
151+
2.175,
152+
2.425,
153+
2.475,
154+
2.775,
155+
3.425,
156+
3.875,
157+
4.275,
158+
4.975,
159+
)
160+
)
161+
peaks = np.where(rdf[1] > 0.0)
162+
assert (np.isclose(expected_peaks, rdf[0][peaks])).all()
163+
164+
165+
def test_rdf_by_elements_analysis_error():
166+
"""Test the by_elements method with Analysis raises ValueError."""
167+
data = read(DATA_PATH / "benzene.xyz")
168+
ana = Analysis(data)
169+
170+
with pytest.raises(ValueError, match="Analysis.get_rdf has known bugs"):
171+
post_process.compute_rdf(
172+
data,
173+
ana,
174+
index=0,
175+
rmax=5.0,
176+
nbins=100,
177+
by_elements=True,
178+
)
179+
180+
135181
def test_rdf_by_elements():
136182
"""Test the by_elements method of compute rdf."""
137183
data = read(DATA_PATH / "benzene.xyz")
@@ -150,18 +196,7 @@ def test_rdf_by_elements():
150196

151197
expected_peaks = {
152198
("C", "C"): (1.375, 2.425, 2.775),
153-
("C", "H"): (
154-
1.075,
155-
1.375,
156-
2.175,
157-
2.425,
158-
2.475,
159-
2.775,
160-
3.425,
161-
3.875,
162-
4.275,
163-
4.975,
164-
),
199+
("C", "H"): (1.075, 2.175, 3.425, 3.875),
165200
("H", "H"): (2.475, 4.275, 4.975),
166201
}
167202

0 commit comments

Comments
 (0)