Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ node_modules/
test_dp_test/
test_dp_test_*.out
*_detail.out

# Training and model output files
*.pth
*.ckpt*
checkpoint
lcurve.out
out.json
input_v2_compat.json
frozen_model.*
10 changes: 7 additions & 3 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def __init__(
self._networks = [None for ii in range(ntypes**ndim)]
for ii, network in enumerate(networks):
self[ii] = network
if len(networks):
if len(networks) and all(net is not None for net in networks):
self.check_completeness()

def check_completeness(self) -> None:
Expand Down Expand Up @@ -969,7 +969,9 @@ def __getitem__(self, key):
return self._networks[self._convert_key(key)]

def __setitem__(self, key, value) -> None:
if isinstance(value, self.network_type):
if value is None:
pass
elif isinstance(value, self.network_type):
pass
elif isinstance(value, dict):
value = self.network_type.deserialize(value)
Expand All @@ -993,7 +995,9 @@ def serialize(self) -> dict:
"ndim": self.ndim,
"ntypes": self.ntypes,
"network_type": network_type_name,
"networks": [nn.serialize() for nn in self._networks],
"networks": [
nn.serialize() if nn is not None else None for nn in self._networks
],
}

@classmethod
Expand Down
8 changes: 8 additions & 0 deletions deepmd/tf/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def serialize(self, suffix: str) -> dict:
"activation_function": self.activation_function_name,
"precision": self.fitting_precision.name,
"exclude_types": [],
"sel_type": self.sel_type,
"nets": self.serialize_network(
ntypes=self.ntypes,
ndim=0 if self.mixed_types else 1,
Expand All @@ -434,6 +435,13 @@ def serialize(self, suffix: str) -> dict:
trainable=self.trainable,
suffix=suffix,
),
"@variables": {
"fparam_avg": self.fparam_avg,
"fparam_inv_std": self.fparam_inv_std,
"aparam_avg": self.aparam_avg,
"aparam_inv_std": self.aparam_inv_std,
"case_embd": None,
},
"type_map": self.type_map,
}
return data
Expand Down
4 changes: 3 additions & 1 deletion deepmd/tf/fit/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ def deserialize_network(cls, data: dict, suffix: str = "") -> dict:
else:
raise ValueError(f"Invalid ndim: {fittings.ndim}")
network = fittings[net_idx]
assert network is not None
if network is None:
# Skip types that are not selected (when sel_type is used)
continue
for layer_idx, layer in enumerate(network.layers):
if layer_idx == len(network.layers) - 1:
layer_name = "final_layer"
Expand Down
Loading