|
4 | 4 | Union, |
5 | 5 | ) |
6 | 6 |
|
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from deepmd.env import ( |
| 10 | + GLOBAL_NP_FLOAT_PRECISION, |
| 11 | +) |
7 | 12 | from deepmd.tf.env import ( |
| 13 | + GLOBAL_TF_FLOAT_PRECISION, |
8 | 14 | MODEL_VERSION, |
9 | 15 | global_cvt_2_ener_float, |
10 | 16 | tf, |
@@ -169,10 +175,35 @@ def build( |
169 | 175 | ) |
170 | 176 |
|
171 | 177 | # Apply out_bias and out_std directly to tensor output |
172 | | - atype_selected = self._get_selected_atype(atype, natoms) |
173 | | - output = self._apply_out_bias_std( |
174 | | - output, atype, natoms, coord, selected_atype=atype_selected |
175 | | - ) |
| 178 | + # dipole not applying bias but polar does, per dpmodel |
| 179 | + if self.model_type in {"polar"} and self.fitting.shift_diag: |
| 180 | + v_constant_matrix = np.zeros( |
| 181 | + self.ntypes, |
| 182 | + dtype=GLOBAL_NP_FLOAT_PRECISION, |
| 183 | + ) |
| 184 | + for itype in range(len(self.get_sel_type())): |
| 185 | + v_constant_matrix[self.get_sel_type()[itype]] = np.mean( |
| 186 | + np.diagonal(self.out_bias[0, itype].reshape((3, 3))) |
| 187 | + ) |
| 188 | + nframes = input_dict["nframes"] |
| 189 | + nloc_mask = tf.reshape( |
| 190 | + tf.tile(tf.repeat(self.fitting.sel_mask, natoms[2:]), [nframes]), |
| 191 | + [nframes, -1], |
| 192 | + ) |
| 193 | + constant_matrix = tf.reshape( |
| 194 | + tf.reshape( |
| 195 | + tf.tile(tf.repeat(v_constant_matrix, natoms[2:]), [nframes]), |
| 196 | + [nframes, -1], |
| 197 | + )[nloc_mask], |
| 198 | + [nframes, -1], |
| 199 | + ) |
| 200 | + |
| 201 | + # nf x nloc x odims, out_bias: ntypes x odims |
| 202 | + output = output + tf.reshape( |
| 203 | + tf.expand_dims(tf.expand_dims(constant_matrix, -1), -1) |
| 204 | + * tf.eye(3, batch_shape=[1, 1], dtype=GLOBAL_TF_FLOAT_PRECISION), |
| 205 | + tf.shape(output), |
| 206 | + ) |
176 | 207 | framesize = nout if "global" in self.model_type else natomsel * nout |
177 | 208 | output = tf.reshape( |
178 | 209 | output, [-1, framesize], name="o_" + self.model_type + suffix |
|
0 commit comments