|
13 | 13 | OutputVariableDef, |
14 | 14 | fitting_check_output, |
15 | 15 | ) |
| 16 | +from deepmd.dpmodel.utils.seed import ( |
| 17 | + child_seed, |
| 18 | +) |
| 19 | +from deepmd.pt.model.network.mlp import ( |
| 20 | + FittingNet, |
| 21 | + NetworkCollection, |
| 22 | +) |
16 | 23 | from deepmd.pt.model.network.network import ( |
17 | 24 | ResidualDeep, |
18 | 25 | ) |
| 26 | +from deepmd.pt.model.network.utils import ( |
| 27 | + aggregate, |
| 28 | +) |
19 | 29 | from deepmd.pt.model.task.fitting import ( |
20 | 30 | Fitting, |
21 | 31 | GeneralFitting, |
@@ -257,3 +267,155 @@ def forward( |
257 | 267 | "energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION), |
258 | 268 | "dforce": vec_out, |
259 | 269 | } |
| 270 | + |
| 271 | + |
@Fitting.register("ener_readout")
@fitting_check_output
class EnergyFittingNetReadout(InvarFitting):
    """Energy fitting net with an optional per-edge readout term.

    Extends ``InvarFitting`` with an edge-level readout network: per-edge
    contributions (from ``g2``) are passed through a small fitting net,
    weighted by the switch function ``sw``, summed per atom, scaled by
    ``1 / norm_fact[0]`` and added to the atomic invariant output.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: list[int] = [128, 128, 128],
        bias_atom_e: Optional[torch.Tensor] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        dim_case_embd: int = 0,
        embedding_width: int = 128,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = True,
        seed: Optional[Union[int, list[int]]] = None,
        type_map: Optional[list[str]] = None,
        norm_fact: list[float] = [120.0],
        add_edge_readout: bool = True,
        slim_edge_readout: bool = False,
        **kwargs,
    ) -> None:
        """Construct a fitting net for energy.

        Args:
            - ntypes: Element count.
            - dim_descrpt: Descriptor dimension per atom.
            - embedding_width: Input width of the edge readout network
              (dimension of the edge feature ``g2``).
            - neuron: Number of neurons in each hidden layers of the fitting net.
            - bias_atom_e: Average energy per atom for each element.
            - resnet_dt: Using time-step in the ResNet construction.
            - norm_fact: Normalization factors; only the first element is used,
              as the divisor of the edge readout energy.
            - add_edge_readout: Whether to build and apply the edge readout net.
            - slim_edge_readout: If True, the edge readout net keeps only the
              first hidden layer of ``neuron``.
        """
        # Set before super().__init__ in case the base constructor invokes
        # overridden methods that read this attribute.
        self.add_edge_readout = add_edge_readout
        super().__init__(
            "energy",
            ntypes,
            dim_descrpt,
            1,
            neuron=neuron,
            bias_atom_e=bias_atom_e,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            dim_case_embd=dim_case_embd,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            seed=seed,
            type_map=type_map,
            **kwargs,
        )

        # embedding for edge readout
        self.embedding_width = embedding_width
        self.slim_edge_readout = slim_edge_readout
        self.norm_e_fact = norm_fact[0]

        if self.add_edge_readout:
            # Derive a dedicated seed branch for the edge readout networks.
            # The previous `self.seed + 100` raised TypeError for the default
            # seed=None and for list seeds; child_seed handles both.
            edge_seed = child_seed(self.seed, 100)
            self.edge_embed = NetworkCollection(
                1 if not self.mixed_types else 0,
                self.ntypes,
                network_type="fitting_network",
                networks=[
                    FittingNet(
                        self.embedding_width,
                        1,
                        self.neuron if not self.slim_edge_readout else self.neuron[:1],
                        self.activation_function,
                        self.resnet_dt,
                        self.precision,
                        bias_out=True,
                        seed=child_seed(edge_seed, ii),
                    )
                    for ii in range(self.ntypes if not self.mixed_types else 1)
                ],
            )
        else:
            self.edge_embed = None

        # set trainable
        for param in self.parameters():
            param.requires_grad = self.trainable

    # make jit happy with torch 2.0.0
    exclude_types: list[int]

    def need_additional_input(self) -> bool:
        """Signal that this fitting consumes extra forward inputs (g2/sw/edge_index)."""
        return True

    def serialize(self) -> dict:
        raise NotImplementedError

    @classmethod
    def deserialize(cls, data: dict) -> "EnergyFittingNetReadout":
        raise NotImplementedError

    def forward(
        self,
        descriptor: torch.Tensor,
        atype: torch.Tensor,
        gr: Optional[torch.Tensor] = None,
        g2: Optional[torch.Tensor] = None,
        h2: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        sw: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
    ):
        """Based on embedding net output, calculate total energy.

        Args:
            - descriptor: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt].
            - g2: Edge features; [nf, nloc, nnei, d] or, with dynamic sel, [nedge, d].
            - sw: Switch function values weighting each edge contribution.
            - edge_index: When given, row 0 maps each edge to its owner atom
              (dynamic sel); when None, edges are summed over the nnei axis.

        Returns
        -------
        - `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
        """
        out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
            self.var_name
        ]
        nf, nloc, _ = descriptor.shape

        if self.add_edge_readout:
            assert g2 is not None
            assert sw is not None
            assert self.edge_embed is not None
            # nf x nloc x nnei x d [OR] nedge x d
            edge_feature = g2
            # nf x nloc x nnei x 1 [OR] nedge x 1
            # NOTE(review): always uses networks[0]; when mixed_types=False the
            # collection holds ntypes networks — confirm this is intended.
            edge_atomic_contrib = self.edge_embed.networks[0](edge_feature)
            # nf x nloc x nnei x 1 [OR] nedge x 1
            edge_atomic_contrib = edge_atomic_contrib * sw.unsqueeze(-1)
            if edge_index is not None:
                # use dynamic sel
                n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
                # nf x nloc x 1
                edge_energy = aggregate(
                    edge_atomic_contrib,
                    n2e_index,
                    average=False,
                    num_owner=nf * nloc,
                ).reshape(nf, nloc, 1)
            else:
                # nf x nloc x 1
                edge_energy = torch.sum(edge_atomic_contrib, dim=-2)
            # energy
            out = out + edge_energy / self.norm_e_fact
        return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
0 commit comments