Skip to content

Commit d5e3fae

Browse files
committed
Handle node/graph features/targets in from_ase_atoms_list
1 parent 4f7220a commit d5e3fae

3 files changed

Lines changed: 163 additions & 7 deletions

File tree

orb_models/forcefield/forcefield_adapter.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def from_ase_atoms_list(
209209
if len(atoms) == 1:
210210
return self.from_ase_atoms(
211211
atoms[0],
212+
max_num_neighbors=max_num_neighbors,
212213
edge_method=edge_method,
213214
wrap=wrap,
214215
device=device,
@@ -311,6 +312,7 @@ def from_ase_atoms_list(
311312
"positions": positions,
312313
"atomic_numbers": atomic_numbers,
313314
"atomic_numbers_embedding": atomic_numbers_embedding,
315+
"atom_identity": torch.cat([torch.arange(n, dtype=torch.long) for n in n_atoms], dim=0),
314316
}
315317
edge_feats = {
316318
"vectors": edge_vectors,
@@ -320,6 +322,17 @@ def from_ase_atoms_list(
320322
"cell": cells,
321323
"pbc": pbcs,
322324
}
325+
326+
# Merge extra features from atoms.info
327+
node_feats.update(_batch_info_tensors(atoms, "node_features"))
328+
edge_feats.update(_batch_info_tensors(atoms, "edge_features"))
329+
graph_feats.update(_batch_info_tensors(atoms, "graph_features", system_level=True))
330+
331+
# Collect targets from atoms.info
332+
node_targets = _batch_info_tensors(atoms, "node_targets")
333+
edge_targets = _batch_info_tensors(atoms, "edge_targets")
334+
system_targets = _batch_info_tensors(atoms, "graph_targets", system_level=True)
335+
323336
# Collect charge and spin: all-or-nothing semantics
324337
charge_spin_list = [_get_charge_and_spin(a) for a in atoms]
325338
has_charge_spin = [bool(cs) for cs in charge_spin_list]
@@ -341,9 +354,9 @@ def from_ase_atoms_list(
341354
node_features=node_feats,
342355
edge_features=edge_feats,
343356
system_features=graph_feats,
344-
node_targets={},
345-
edge_targets={},
346-
system_targets={},
357+
node_targets=node_targets,
358+
edge_targets=edge_targets,
359+
system_targets=system_targets,
347360
system_id=None,
348361
fix_atoms=fix_atoms,
349362
tags=tags,
@@ -482,6 +495,39 @@ def is_compatible_with(self, other: AbstractAtomsAdapter):
482495
return True
483496

484497

498+
def _batch_info_tensors(
499+
atoms_list: list[ase.Atoms],
500+
info_key: str,
501+
system_level: bool = False,
502+
) -> dict:
503+
"""Collect tensor dicts from atoms.info[info_key] across a list and batch them.
504+
505+
For system-level tensors, matches from_ase_atoms behavior: unsqueeze non-scalars
506+
(numel > 1) before concatenating, so cat produces [N, ...] for non-scalars
507+
and [N] for scalars.
508+
"""
509+
dicts = [a.info.get(info_key, {}) for a in atoms_list]
510+
all_keys = [set(d.keys()) for d in dicts]
511+
keys = set().union(*all_keys)
512+
if not keys:
513+
return {}
514+
if any(ks != keys for ks in all_keys):
515+
raise ValueError(
516+
f"All atoms must have the same set of keys in info['{info_key}']. "
517+
f"Got: {[sorted(ks) for ks in all_keys]}"
518+
)
519+
out: dict = {}
520+
for k in keys:
521+
values = [d[k] for d in dicts]
522+
if any(v is None for v in values):
523+
out[k] = None
524+
elif system_level:
525+
out[k] = torch.cat([v.unsqueeze(0) if v.numel() > 1 else v for v in values], dim=0)
526+
else:
527+
out[k] = torch.cat(values, dim=0)
528+
return out
529+
530+
485531
def _get_charge_and_spin(atoms: ase.Atoms | ts.SimState) -> dict[str, torch.Tensor]:
486532
out = {}
487533
if isinstance(atoms, ase.Atoms) and ("charge" in atoms.info or "spin" in atoms.info):

tests/common/atoms/test_graph_batch.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,26 @@ def test_random_batching():
131131
assert batched.equals(AtomGraphs.batch(graphs))
132132

133133

134+
def test_batch_none_values_collapse():
135+
"""Batching graphs where a target value is None in one graph produces None for that key."""
136+
g1 = graph()
137+
g2 = graph()
138+
g2.node_targets["node_target"] = None
139+
140+
batched = AtomGraphs.batch([g1, g2])
141+
assert batched.node_targets["node_target"] is None
142+
143+
144+
def test_batch_mismatched_keys_raises():
145+
"""Batching graphs with different feature/target keys raises."""
146+
g1 = graph()
147+
g2 = graph()
148+
g2.system_targets["extra"] = torch.tensor([1.0])
149+
150+
with pytest.raises(ValueError, match="same nested structure"):
151+
AtomGraphs.batch([g1, g2])
152+
153+
134154
def test_refeaturization_pos_substitution(dataset_and_loader):
135155
dataset = dataset_and_loader[0]
136156
datapoint = dataset[0]

tests/forcefield/test_forcefield_adapter.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,23 @@ def test_forcefield_adapter_requires_both_spin_and_charge():
196196

197197
def test_from_ase_atoms_list_parallel_equivalence():
198198
"""Test that from_ase_atoms_list produces equivalent results to sequential processing."""
199-
atoms_list = [
200-
Atoms(
199+
n_atoms = 3
200+
atoms_list = []
201+
for i in range(4):
202+
a = Atoms(
201203
"H2O",
202204
positions=np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) + i * 0.1,
203205
pbc=True,
204206
cell=np.diag([5, 5, 5]),
205207
)
206-
for i in range(4)
207-
]
208+
a.info["graph_features"] = {"bandgap": torch.tensor([float(i)])}
209+
a.info["graph_targets"] = {
210+
"energy": torch.tensor([float(i) * 0.5]),
211+
"stress": torch.randn(6),
212+
}
213+
a.info["node_features"] = {"mulliken_charges": torch.randn(n_atoms, 1)}
214+
a.info["node_targets"] = {"forces": torch.randn(n_atoms, 3)}
215+
atoms_list.append(a)
208216
adapter = ForcefieldAtomsAdapter(radius=6.0, max_num_neighbors=20)
209217

210218
batch_result = adapter.from_ase_atoms_list(atoms_list)
@@ -226,6 +234,88 @@ def test_from_ase_atoms_list_parallel_equivalence():
226234
sd = sequential_batch.edge_features["vectors"][start:end].cpu().norm(dim=1).sort()[0]
227235
assert torch.allclose(bd, sd, atol=1e-4)
228236

237+
# Verify features/targets keys match between batched and sequential
238+
assert batch_result.node_features.keys() == sequential_batch.node_features.keys()
239+
assert batch_result.system_features.keys() == sequential_batch.system_features.keys()
240+
assert batch_result.system_targets.keys() == sequential_batch.system_targets.keys()
241+
assert batch_result.node_targets.keys() == sequential_batch.node_targets.keys()
242+
243+
# Verify feature/target values
244+
torch.testing.assert_close(
245+
batch_result.system_features["bandgap"].cpu(),
246+
sequential_batch.system_features["bandgap"].cpu(),
247+
)
248+
torch.testing.assert_close(
249+
batch_result.system_targets["energy"].cpu(),
250+
sequential_batch.system_targets["energy"].cpu(),
251+
)
252+
torch.testing.assert_close(
253+
batch_result.system_targets["stress"].cpu(),
254+
sequential_batch.system_targets["stress"].cpu(),
255+
)
256+
torch.testing.assert_close(
257+
batch_result.node_features["mulliken_charges"].cpu(),
258+
sequential_batch.node_features["mulliken_charges"].cpu(),
259+
)
260+
torch.testing.assert_close(
261+
batch_result.node_targets["forces"].cpu(),
262+
sequential_batch.node_targets["forces"].cpu(),
263+
)
264+
265+
266+
def test_from_ase_atoms_list_inconsistent_info_raises():
267+
"""Test that from_ase_atoms_list raises when atoms have inconsistent info keys."""
268+
adapter = ForcefieldAtomsAdapter(radius=6.0, max_num_neighbors=20)
269+
270+
a0 = Atoms(
271+
"H2O",
272+
positions=np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]),
273+
pbc=True,
274+
cell=np.diag([5, 5, 5]),
275+
)
276+
a0.info["node_targets"] = {"forces": torch.randn(3, 3)}
277+
a0.info["graph_targets"] = {"energy": torch.tensor([1.0])}
278+
279+
a1 = Atoms(
280+
"H2O",
281+
positions=np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) + 0.1,
282+
pbc=True,
283+
cell=np.diag([5, 5, 5]),
284+
)
285+
286+
with pytest.raises(ValueError, match="same set of keys"):
287+
adapter.from_ase_atoms_list([a0, a1])
288+
289+
290+
def test_from_ase_atoms_list_none_values_collapse_to_none():
291+
"""Test that if any atom has None for a key, the batched result is None for that key."""
292+
adapter = ForcefieldAtomsAdapter(radius=6.0, max_num_neighbors=20)
293+
atoms_list = []
294+
for i in range(3):
295+
a = Atoms(
296+
"H2O",
297+
positions=np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) + i * 0.1,
298+
pbc=True,
299+
cell=np.diag([5, 5, 5]),
300+
)
301+
a.info["node_targets"] = {
302+
"forces": None if i == 0 else torch.randn(3, 3),
303+
}
304+
a.info["graph_targets"] = {
305+
"energy": torch.tensor([float(i)]),
306+
"stress": None if i == 1 else torch.randn(6),
307+
}
308+
atoms_list.append(a)
309+
310+
batch_result = adapter.from_ase_atoms_list(atoms_list)
311+
312+
assert batch_result.node_targets["forces"] is None
313+
assert batch_result.system_targets["stress"] is None
314+
torch.testing.assert_close(
315+
batch_result.system_targets["energy"].cpu(),
316+
torch.tensor([0.0, 1.0, 2.0]),
317+
)
318+
229319

230320
def test_from_ase_atoms_list_nonperiodic():
231321
"""Test from_ase_atoms_list with non-periodic systems."""

0 commit comments

Comments
 (0)