Skip to content

Commit 6d67e25

Browse files
committed
enable consistent tests
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 519b7bb commit 6d67e25

9 files changed

Lines changed: 68 additions & 12 deletions

File tree

deepmd/tf/fit/dipole.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,6 @@ def serialize(self, suffix: str) -> dict:
423423
"dim_descrpt": self.dim_descrpt,
424424
"embedding_width": self.dim_rot_mat_1,
425425
"mixed_types": self.mixed_types,
426-
"dim_out": 3,
427426
"neuron": self.n_neuron,
428427
"resnet_dt": self.resnet_dt,
429428
"numb_fparam": self.numb_fparam,
@@ -458,6 +457,15 @@ def serialize(self, suffix: str) -> dict:
458457
),
459458
},
460459
"type_map": self.type_map,
460+
"var_name": "dipole",
461+
"rcond": None,
462+
"tot_ener_zero": False,
463+
"trainable": self.trainable,
464+
"layer_name": None,
465+
"use_aparam_as_mask": False,
466+
"spin": None,
467+
"r_differentiable": True,
468+
"c_differentiable": True,
461469
}
462470
return data
463471

deepmd/tf/fit/dos.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,11 @@ def serialize(self, suffix: str = "") -> dict:
750750
"case_embd": None,
751751
},
752752
"type_map": self.type_map,
753+
"tot_ener_zero": False,
754+
"layer_name": None,
755+
"use_aparam_as_mask": False,
756+
"spin": None,
757+
"atom_ener": None,
753758
}
754759
return data
755760

deepmd/tf/fit/polar.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
import numpy as np
88

9+
from deepmd.env import (
10+
GLOBAL_NP_FLOAT_PRECISION,
11+
)
912
from deepmd.tf.common import (
1013
cast_precision,
1114
get_activation_func,
@@ -641,7 +644,6 @@ def serialize(self, suffix: str) -> dict:
641644
"dim_descrpt": self.dim_descrpt,
642645
"embedding_width": self.dim_rot_mat_1,
643646
"mixed_types": self.mixed_types,
644-
"dim_out": 3,
645647
"neuron": self.n_neuron,
646648
"resnet_dt": self.resnet_dt,
647649
"numb_fparam": self.numb_fparam,
@@ -652,7 +654,6 @@ def serialize(self, suffix: str) -> dict:
652654
"precision": self.fitting_precision.name,
653655
"exclude_types": [],
654656
"fit_diag": self.fit_diag,
655-
"scale": list(self.scale),
656657
"shift_diag": self.shift_diag,
657658
"nets": self.serialize_network(
658659
ntypes=self.ntypes,
@@ -674,8 +675,18 @@ def serialize(self, suffix: str) -> dict:
674675
"case_embd": None,
675676
"scale": self.scale.reshape(-1, 1),
676677
"constant_matrix": self.constant_matrix.reshape(-1),
678+
"bias_atom_e": np.zeros(
679+
(self.ntypes, self.dim_rot_mat_1), dtype=GLOBAL_NP_FLOAT_PRECISION
680+
),
677681
},
678682
"type_map": self.type_map,
683+
"var_name": "polar",
684+
"rcond": None,
685+
"tot_ener_zero": False,
686+
"trainable": self.trainable,
687+
"layer_name": None,
688+
"use_aparam_as_mask": False,
689+
"spin": None,
679690
}
680691
return data
681692

deepmd/tf/model/tensor.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
Union,
55
)
66

7+
import numpy as np
8+
9+
from deepmd.env import (
10+
GLOBAL_NP_FLOAT_PRECISION,
11+
)
712
from deepmd.tf.env import (
13+
GLOBAL_TF_FLOAT_PRECISION,
814
MODEL_VERSION,
915
global_cvt_2_ener_float,
1016
tf,
@@ -169,10 +175,35 @@ def build(
169175
)
170176

171177
# Apply out_bias and out_std directly to tensor output
172-
atype_selected = self._get_selected_atype(atype, natoms)
173-
output = self._apply_out_bias_std(
174-
output, atype, natoms, coord, selected_atype=atype_selected
175-
)
178+
# dipole not applying bias but polar does, per dpmodel
179+
if self.model_type in {"polar"} and self.fitting.shift_diag:
180+
v_constant_matrix = np.zeros(
181+
self.ntypes,
182+
dtype=GLOBAL_NP_FLOAT_PRECISION,
183+
)
184+
for itype in range(len(self.get_sel_type())):
185+
v_constant_matrix[self.get_sel_type()[itype]] = np.mean(
186+
np.diagonal(self.out_bias[0, itype].reshape((3, 3)))
187+
)
188+
nframes = input_dict["nframes"]
189+
nloc_mask = tf.reshape(
190+
tf.tile(tf.repeat(self.fitting.sel_mask, natoms[2:]), [nframes]),
191+
[nframes, -1],
192+
)
193+
constant_matrix = tf.reshape(
194+
tf.reshape(
195+
tf.tile(tf.repeat(v_constant_matrix, natoms[2:]), [nframes]),
196+
[nframes, -1],
197+
)[nloc_mask],
198+
[nframes, -1],
199+
)
200+
201+
# nf x nloc x odims, out_bias: ntypes x odims
202+
output = output + tf.reshape(
203+
tf.expand_dims(tf.expand_dims(constant_matrix, -1), -1)
204+
* tf.eye(3, batch_shape=[1, 1], dtype=GLOBAL_TF_FLOAT_PRECISION),
205+
tf.shape(output),
206+
)
176207
framesize = nout if "global" in self.model_type else natomsel * nout
177208
output = tf.reshape(
178209
output, [-1, framesize], name="o_" + self.model_type + suffix

source/tests/consistent/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def test_tf_consistent_with_ref(self) -> None:
354354
data1.pop("@version")
355355
data2.pop("@version")
356356

357-
if tf_obj.__class__.__name__.startswith("Polar"):
357+
if tf_obj.__class__.__name__.startswith("PolarFitting"):
358358
data1["@variables"].pop("bias_atom_e")
359359

360360
np.testing.assert_equal(data1, data2)

source/tests/consistent/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def build_tf_model(
7777
]
7878
elif ret_key == "polar":
7979
ret_list = [
80-
ret["polar"],
8180
ret["global_polar"],
81+
ret["polar"],
8282
]
8383
else:
8484
raise NotImplementedError

source/tests/consistent/model/test_dipole.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def data(self) -> dict:
7373
pt_class = DipoleModelPT
7474
jax_class = DipoleModelJAX
7575
args = model_args()
76+
atol = 1e-8
7677

7778
def get_reference_backend(self):
7879
"""Get the reference backend.
@@ -89,7 +90,7 @@ def get_reference_backend(self):
8990

9091
@property
9192
def skip_tf(self):
92-
return True # need to fix tf consistency
93+
return not INSTALLED_TF
9394

9495
@property
9596
def skip_jax(self) -> bool:

source/tests/consistent/model/test_dos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_reference_backend(self):
9090

9191
@property
9292
def skip_tf(self):
93-
return True # need to fix tf consistency
93+
return not INSTALLED_TF
9494

9595
@property
9696
def skip_jax(self) -> bool:

source/tests/consistent/model/test_polar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get_reference_backend(self):
8989

9090
@property
9191
def skip_tf(self):
92-
return True # need to fix tf consistency
92+
return not INSTALLED_TF
9393

9494
@property
9595
def skip_jax(self) -> bool:

0 commit comments

Comments
 (0)