8 changes: 8 additions & 0 deletions dpdata/deepmd/mixed.py
Member

Please check other files in this directory to see how we support ANY new keys with the plugin system.

if labels:
    dtypes = dpdata.system.LabeledSystem.DTYPES
else:
    dtypes = dpdata.system.System.DTYPES
for dtype in dtypes:
    if dtype.name in (
        "atom_numbs",

Contributor Author

@anyangml anyangml Jun 5, 2025

> Please check other files in this directory to see how we support ANY new keys with the plugin system.

@njzjz I am not quite following. Are you suggesting that some missing changes need to be added, or do you want the feature implemented using a different approach? This PR only fixes the fparam bug in mixed systems.

Member

fparam is implemented with plugins, so I don't suggest handling it specially. Ideally, we should handle any registered data type.

Contributor Author

> fparam is implemented with plugins, so I don't suggest handling it specially. Ideally, we should handle any registered data type.

If I understand correctly, you are suggesting refactoring the temp_idx logic into comp.py, not only for fparam but also for all the other dtypes?

temp_idx = np.arange(all_real_atom_types_concat.shape[0])[
    (all_real_atom_types_concat == all_real_atom_types_concat[0]).all(-1)
]
rest_idx = np.arange(all_real_atom_types_concat.shape[0])[
    (all_real_atom_types_concat != all_real_atom_types_concat[0]).any(-1)
]
temp_data = data.copy()
temp_data["atom_names"] = data["atom_names"].copy()
temp_data["atom_numbs"] = temp_atom_numbs
temp_data["atom_types"] = all_real_atom_types_concat[0]
all_real_atom_types_concat = all_real_atom_types_concat[rest_idx]
temp_data["cells"] = all_cells_concat[temp_idx]
all_cells_concat = all_cells_concat[rest_idx]
temp_data["coords"] = all_coords_concat[temp_idx]
all_coords_concat = all_coords_concat[rest_idx]
if labels:
    if all_eners_concat is not None and all_eners_concat.size > 0:
        temp_data["energies"] = all_eners_concat[temp_idx]
        all_eners_concat = all_eners_concat[rest_idx]
    if all_forces_concat is not None and all_forces_concat.size > 0:
        temp_data["forces"] = all_forces_concat[temp_idx]
        all_forces_concat = all_forces_concat[rest_idx]
    if all_virs_concat is not None and all_virs_concat.size > 0:
        temp_data["virials"] = all_virs_concat[temp_idx]
        all_virs_concat = all_virs_concat[rest_idx]
data_list.append(temp_data)

That should probably be done in a separate PR as a refactor. This PR only aims to fix the bug.
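
For reference, the generic handling discussed in this thread could look roughly like the sketch below. It is not dpdata's implementation: `split_frames` and the `per_frame` dictionary are hypothetical names, and the keys are supplied by hand rather than read from the registered data types. It only illustrates applying the same `temp_idx` / `rest_idx` split to every per-frame array instead of naming each key separately.

```python
import numpy as np


def split_frames(per_frame, atom_types):
    """Split each per-frame array by whether a frame's atom types match frame 0.

    Hypothetical helper for illustration only; `per_frame` maps a key name
    (e.g. "coords", "fparam") to an array whose first axis is the frame axis,
    or to None when the key is absent.
    """
    same = (atom_types == atom_types[0]).all(-1)
    temp_idx = np.flatnonzero(same)   # frames sharing frame 0's atom types
    rest_idx = np.flatnonzero(~same)  # every other frame

    matched, rest = {}, {}
    for name, arr in per_frame.items():
        # Skip keys that are absent or empty, mirroring the energies/forces checks.
        if arr is None or arr.size == 0:
            continue
        matched[name] = arr[temp_idx]
        rest[name] = arr[rest_idx]
    return matched, rest, atom_types[rest_idx]


# Three frames of three atoms; frames 0 and 2 share the same atom types.
atom_types = np.array([[0, 0, 1], [0, 1, 1], [0, 0, 1]])
per_frame = {
    "coords": np.arange(27, dtype=float).reshape(3, 3, 3),
    "fparam": np.array([[1.0], [2.0], [3.0]]),
    "aparam": None,  # absent key, skipped by the helper
}
matched, rest, remaining_types = split_frames(per_frame, atom_types)
```

With this shape, adding a new per-frame key would not require touching the splitting loop, which is the point raised about not special-casing fparam.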

@@ -27,6 +27,8 @@
all_real_atom_types_concat = index_map[all_real_atom_types_concat]
all_cells_concat = data["cells"]
all_coords_concat = data["coords"]
all_fparam_concat = data.get("fparam", None)
all_aparam_concat = data.get("aparam", None)
if labels:
    all_eners_concat = data.get("energies")
    all_forces_concat = data.get("forces")
@@ -56,6 +58,12 @@
all_cells_concat = all_cells_concat[rest_idx]
temp_data["coords"] = all_coords_concat[temp_idx]
all_coords_concat = all_coords_concat[rest_idx]
if all_fparam_concat:
    temp_data["fparam"] = all_fparam_concat[temp_idx]
    all_fparam_concat = all_fparam_concat[rest_idx]

Check warning on line 63 in dpdata/deepmd/mixed.py (Codecov / codecov/patch): added lines #L62-L63 were not covered by tests.

if all_aparam_concat:
    temp_data["aparam"] = all_aparam_concat[temp_idx]
    all_aparam_concat = all_aparam_concat[rest_idx]

Check warning on line 66 in dpdata/deepmd/mixed.py (Codecov / codecov/patch): added lines #L65-L66 were not covered by tests.

Comment thread
anyangml marked this conversation as resolved.
Outdated
if labels:
    if all_eners_concat is not None and all_eners_concat.size > 0:
        temp_data["energies"] = all_eners_concat[temp_idx]
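
A note on the `if all_fparam_concat:` and `if all_aparam_concat:` checks in the hunk above: assuming these values are NumPy arrays like the other fields (an assumption, since the diff does not show their type), a bare truthiness test raises `ValueError` for any array with more than one element. The energies, forces, and virials branches avoid this with `is not None and .size > 0`. The sketch below uses made-up data and a hypothetical `has_data` helper to show the difference.

```python
import numpy as np

# Made-up frame parameters for three frames, plus indices of the frames to keep.
fparam = np.array([[1.0], [2.0], [3.0]])
temp_idx = np.array([0, 2])

# A bare truthiness test on a multi-element array is ambiguous and raises.
try:
    if fparam:
        pass
    raised = False
except ValueError:
    raised = True


def has_data(arr):
    """Hypothetical helper: same guard as the energies/forces branches."""
    return arr is not None and arr.size > 0


# The explicit guard handles None, empty arrays, and populated arrays alike.
selected = fparam[temp_idx] if has_data(fparam) else None
```

This may also be why Codecov reports the added lines as uncovered: the flagged lines are the bodies of the two `if` blocks, so no test appears to reach them with fparam or aparam present.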