@@ -196,15 +196,23 @@ def test_forcefield_adapter_requires_both_spin_and_charge():
196196
197197def test_from_ase_atoms_list_parallel_equivalence ():
198198 """Test that from_ase_atoms_list produces equivalent results to sequential processing."""
199- atoms_list = [
200- Atoms (
199+ n_atoms = 3
200+ atoms_list = []
201+ for i in range (4 ):
202+ a = Atoms (
201203 "H2O" ,
202204 positions = np .array ([[0 , 0 , 0 ], [0 , 1 , 0 ], [1 , 0 , 0 ]]) + i * 0.1 ,
203205 pbc = True ,
204206 cell = np .diag ([5 , 5 , 5 ]),
205207 )
206- for i in range (4 )
207- ]
208+ a .info ["graph_features" ] = {"bandgap" : torch .tensor ([float (i )])}
209+ a .info ["graph_targets" ] = {
210+ "energy" : torch .tensor ([float (i ) * 0.5 ]),
211+ "stress" : torch .randn (6 ),
212+ }
213+ a .info ["node_features" ] = {"mulliken_charges" : torch .randn (n_atoms , 1 )}
214+ a .info ["node_targets" ] = {"forces" : torch .randn (n_atoms , 3 )}
215+ atoms_list .append (a )
208216 adapter = ForcefieldAtomsAdapter (radius = 6.0 , max_num_neighbors = 20 )
209217
210218 batch_result = adapter .from_ase_atoms_list (atoms_list )
@@ -226,6 +234,88 @@ def test_from_ase_atoms_list_parallel_equivalence():
226234 sd = sequential_batch .edge_features ["vectors" ][start :end ].cpu ().norm (dim = 1 ).sort ()[0 ]
227235 assert torch .allclose (bd , sd , atol = 1e-4 )
228236
237+ # Verify features/targets keys match between batched and sequential
238+ assert batch_result .node_features .keys () == sequential_batch .node_features .keys ()
239+ assert batch_result .system_features .keys () == sequential_batch .system_features .keys ()
240+ assert batch_result .system_targets .keys () == sequential_batch .system_targets .keys ()
241+ assert batch_result .node_targets .keys () == sequential_batch .node_targets .keys ()
242+
243+ # Verify feature/target values
244+ torch .testing .assert_close (
245+ batch_result .system_features ["bandgap" ].cpu (),
246+ sequential_batch .system_features ["bandgap" ].cpu (),
247+ )
248+ torch .testing .assert_close (
249+ batch_result .system_targets ["energy" ].cpu (),
250+ sequential_batch .system_targets ["energy" ].cpu (),
251+ )
252+ torch .testing .assert_close (
253+ batch_result .system_targets ["stress" ].cpu (),
254+ sequential_batch .system_targets ["stress" ].cpu (),
255+ )
256+ torch .testing .assert_close (
257+ batch_result .node_features ["mulliken_charges" ].cpu (),
258+ sequential_batch .node_features ["mulliken_charges" ].cpu (),
259+ )
260+ torch .testing .assert_close (
261+ batch_result .node_targets ["forces" ].cpu (),
262+ sequential_batch .node_targets ["forces" ].cpu (),
263+ )
264+
265+
266+ def test_from_ase_atoms_list_inconsistent_info_raises ():
267+ """Test that from_ase_atoms_list raises when atoms have inconsistent info keys."""
268+ adapter = ForcefieldAtomsAdapter (radius = 6.0 , max_num_neighbors = 20 )
269+
270+ a0 = Atoms (
271+ "H2O" ,
272+ positions = np .array ([[0 , 0 , 0 ], [0 , 1 , 0 ], [1 , 0 , 0 ]]),
273+ pbc = True ,
274+ cell = np .diag ([5 , 5 , 5 ]),
275+ )
276+ a0 .info ["node_targets" ] = {"forces" : torch .randn (3 , 3 )}
277+ a0 .info ["graph_targets" ] = {"energy" : torch .tensor ([1.0 ])}
278+
279+ a1 = Atoms (
280+ "H2O" ,
281+ positions = np .array ([[0 , 0 , 0 ], [0 , 1 , 0 ], [1 , 0 , 0 ]]) + 0.1 ,
282+ pbc = True ,
283+ cell = np .diag ([5 , 5 , 5 ]),
284+ )
285+
286+ with pytest .raises (ValueError , match = "same set of keys" ):
287+ adapter .from_ase_atoms_list ([a0 , a1 ])
288+
289+
290+ def test_from_ase_atoms_list_none_values_collapse_to_none ():
291+ """Test that if any atom has None for a key, the batched result is None for that key."""
292+ adapter = ForcefieldAtomsAdapter (radius = 6.0 , max_num_neighbors = 20 )
293+ atoms_list = []
294+ for i in range (3 ):
295+ a = Atoms (
296+ "H2O" ,
297+ positions = np .array ([[0 , 0 , 0 ], [0 , 1 , 0 ], [1 , 0 , 0 ]]) + i * 0.1 ,
298+ pbc = True ,
299+ cell = np .diag ([5 , 5 , 5 ]),
300+ )
301+ a .info ["node_targets" ] = {
302+ "forces" : None if i == 0 else torch .randn (3 , 3 ),
303+ }
304+ a .info ["graph_targets" ] = {
305+ "energy" : torch .tensor ([float (i )]),
306+ "stress" : None if i == 1 else torch .randn (6 ),
307+ }
308+ atoms_list .append (a )
309+
310+ batch_result = adapter .from_ase_atoms_list (atoms_list )
311+
312+ assert batch_result .node_targets ["forces" ] is None
313+ assert batch_result .system_targets ["stress" ] is None
314+ torch .testing .assert_close (
315+ batch_result .system_targets ["energy" ].cpu (),
316+ torch .tensor ([0.0 , 1.0 , 2.0 ]),
317+ )
318+
229319
230320def test_from_ase_atoms_list_nonperiodic ():
231321 """Test from_ase_atoms_list with non-periodic systems."""
0 commit comments