11# SPDX-License-Identifier: LGPL-3.0-or-later
2- import os
32
43import numpy as np
54import torch
1615from deepmd .pt .modifier .base_modifier import (
1716 BaseModifier ,
1817)
18+ from deepmd .pt .modifier .dp_modifier import (
19+ DPModifier ,
20+ )
1921from deepmd .pt .utils import (
2022 env ,
2123)
22- from deepmd .pt .utils .serialization import (
23- serialize_from_file ,
24- )
2524from deepmd .pt .utils .utils import (
2625 to_numpy_array ,
2726 to_torch_tensor ,
2827)
2928
3029
3130@BaseModifier .register ("dipole_charge" )
32- class DipoleChargeModifier (BaseModifier ):
31+ class DipoleChargeModifier (DPModifier ):
3332 """Parameters
3433 ----------
3534 model_name
@@ -46,36 +45,22 @@ class DipoleChargeModifier(BaseModifier):
4645
4746 def __init__ (
4847 self ,
49- model_name : str | None ,
48+ dp_model : DipoleModel | None ,
5049 model_charge_map : list [float ],
5150 sys_charge_map : list [float ],
5251 ewald_h : float = 1.0 ,
5352 ewald_beta : float = 1.0 ,
54- ewald_batch_size : int = 5 ,
55- dp_batch_size : int | None = None ,
56- model : DipoleModel | None = None ,
53+ dp_model_file_name : str | None = None ,
5754 use_cache : bool = True ,
5855 ) -> None :
5956 """Constructor."""
60- super ().__init__ (use_cache = use_cache )
6157 self .modifier_type = "dipole_charge"
58+ super ().__init__ (
59+ dp_model = dp_model ,
60+ dp_model_file_name = dp_model_file_name ,
61+ use_cache = use_cache ,
62+ )
6263
63- if model_name is None and model is None :
64- raise AttributeError ("`model_name` or `model` should be specified." )
65- if model_name is not None and model is not None :
66- raise AttributeError (
67- "`model_name` and `model` cannot be used simultaneously."
68- )
69-
70- if model is not None :
71- self ._model = model .to (env .DEVICE )
72- if model_name is not None :
73- data = serialize_from_file (model_name )
74- self ._model = DipoleModel .deserialize (data ["model" ]).to (env .DEVICE )
75- self ._model .eval ()
76-
77- # use jit model for inference
78- self .model = torch .jit .script (self ._model )
7964 self .rcut = self .model .get_rcut ()
8065 self .type_map = self .model .get_type_map ()
8166 sel_type = self .model .get_sel_type ()
@@ -95,23 +80,22 @@ def __init__(
9580 # init ewald recp
9681 self .ewald_h = ewald_h
9782 self .ewald_beta = ewald_beta
98- self . er = CoulombForceModule (
83+ er = CoulombForceModule (
9984 rcut = self .rcut ,
10085 rspace = False ,
10186 kappa = ewald_beta ,
10287 spacing = ewald_h ,
103- )
88+ ).to (env .GLOBAL_PT_FLOAT_PRECISION )
89+ self .er = torch .jit .script (er )
90+ self .er .eval ()
10491 self .placeholder_pairs = torch .ones ((1 , 2 ), device = env .DEVICE , dtype = torch .long )
105- self .placeholder_ds = torch .ones ((1 ), device = env .DEVICE , dtype = torch .float64 )
92+ self .placeholder_ds = torch .ones (
93+ (1 ), device = env .DEVICE , dtype = env .GLOBAL_PT_FLOAT_PRECISION
94+ )
10695 self .placeholder_buffer_scales = torch .zeros (
107- (1 ), device = env .DEVICE , dtype = torch . float64
96+ (1 ), device = env .DEVICE , dtype = env . GLOBAL_PT_FLOAT_PRECISION
10897 )
10998
110- self .ewald_batch_size = ewald_batch_size
111- if dp_batch_size is None :
112- dp_batch_size = int (os .environ .get ("DP_INFER_BATCH_SIZE" , 1 ))
113- self .dp_batch_size = dp_batch_size
114-
11599 def serialize (self ) -> dict :
116100 """Serialize the modifier.
117101
@@ -120,16 +104,13 @@ def serialize(self) -> dict:
120104 dict
121105 The serialized data
122106 """
123- dd = BaseModifier .serialize (self )
107+ dd = super () .serialize ()
124108 dd .update (
125109 {
126- "model" : self ._model .serialize (),
127110 "model_charge_map" : self ._model_charge_map ,
128111 "sys_charge_map" : self ._sys_charge_map ,
129112 "ewald_h" : self .ewald_h ,
130113 "ewald_beta" : self .ewald_beta ,
131- "ewald_batch_size" : self .ewald_batch_size ,
132- "dp_batch_size" : self .dp_batch_size ,
133114 }
134115 )
135116 return dd
@@ -140,9 +121,9 @@ def deserialize(cls, data: dict) -> "DipoleChargeModifier":
140121 data .pop ("@class" , None )
141122 data .pop ("type" , None )
142123 data .pop ("@version" , None )
143- model_obj = DipoleModel .deserialize (data .pop ("model " ))
144- data ["model " ] = model_obj
145- data ["model_name " ] = None
124+ model_obj = DipoleModel .deserialize (data .pop ("dp_model " ))
125+ data ["dp_model " ] = model_obj
126+ data ["dp_model_file_name " ] = None
146127 return cls (** data )
147128
148129 def forward (
@@ -213,30 +194,23 @@ def forward(
213194 )
214195
215196 # add Ewald reciprocal correction
216- tot_e : list [torch .Tensor ] = []
217- chunk_coord = torch .split (
218- extended_coord .reshape (nframes , - 1 , 3 ), self .dp_batch_size , dim = 0
197+ placeholder_pairs = torch .tile (
198+ self .placeholder_pairs .unsqueeze (0 ), (nframes , 1 , 1 )
219199 )
220- chunk_box = torch .split (
221- input_box .reshape (nframes , 3 , 3 ), self .dp_batch_size , dim = 0
200+ placeholder_ds = torch .tile (self .placeholder_ds .unsqueeze (0 ), (nframes , 1 ))
201+ placeholder_buffer_scales = torch .tile (
202+ self .placeholder_buffer_scales .unsqueeze (0 ), (nframes , 1 )
222203 )
223- chunk_charge = torch .split (
224- extended_charge .reshape (nframes , - 1 ), self .dp_batch_size , dim = 0
204+ self .er (
205+ extended_coord .reshape (nframes , 2 * natoms , 3 ),
206+ input_box .reshape (nframes , 3 , 3 ),
207+ placeholder_pairs ,
208+ placeholder_ds ,
209+ placeholder_buffer_scales ,
210+ {"charge" : extended_charge .reshape (nframes , 2 * natoms )},
225211 )
226- for _coord , _box , _charge in zip (
227- chunk_coord , chunk_box , chunk_charge , strict = True
228- ):
229- self .er (
230- _coord ,
231- _box ,
232- self .placeholder_pairs ,
233- self .placeholder_ds ,
234- self .placeholder_buffer_scales ,
235- {"charge" : _charge },
236- )
237- tot_e .append (self .er .reciprocal_energy .unsqueeze (0 ))
238212 # nframe,
239- tot_e = torch . concat ( tot_e , dim = 0 )
213+ tot_e = self . er . reciprocal_energy
240214 # nframe, nat * 3
241215 tot_f = - calc_grads (tot_e , input_coord )
242216 # nframe, nat, 3
@@ -346,37 +320,17 @@ def extend_system_coord(
346320 nframes = coord .shape [0 ]
347321 natoms = coord .shape [1 ] // 3
348322
349- all_dipole : list [torch .Tensor ] = []
350- chunk_coord = torch .split (coord , self .dp_batch_size , dim = 0 )
351- chunk_atype = torch .split (atype , self .dp_batch_size , dim = 0 )
352- chunk_box = torch .split (box , self .dp_batch_size , dim = 0 )
353- # use placeholder to make the jit happy for fparam/aparam is None
354- chunk_fparam = (
355- torch .split (fparam , self .dp_batch_size , dim = 0 )
356- if fparam is not None
357- else chunk_atype
358- )
359- chunk_aparam = (
360- torch .split (aparam , self .dp_batch_size , dim = 0 )
361- if aparam is not None
362- else chunk_atype
323+ model_pred = self .model (
324+ coord = coord ,
325+ atype = atype ,
326+ box = box ,
327+ do_atomic_virial = False ,
328+ fparam = fparam if fparam is not None else None ,
329+ aparam = aparam if aparam is not None else None ,
363330 )
364- for _coord , _atype , _box , _fparam , _aparam in zip (
365- chunk_coord , chunk_atype , chunk_box , chunk_fparam , chunk_aparam , strict = True
366- ):
367- dipole_batch = self .model (
368- coord = _coord ,
369- atype = _atype ,
370- box = _box ,
371- do_atomic_virial = False ,
372- fparam = _fparam if fparam is not None else None ,
373- aparam = _aparam if aparam is not None else None ,
374- )
375- # Extract dipole from the output dictionary
376- all_dipole .append (dipole_batch ["dipole" ])
377331
378- # nframe x natoms x 3
379- dipole = torch . cat ( all_dipole , dim = 0 )
332+ # nframe, natoms, 3
333+ dipole = model_pred [ "dipole" ]
380334 if dipole .shape [0 ] != nframes :
381335 raise RuntimeError (
382336 f"Dipole shape mismatch: expected { nframes } frames, got { dipole .shape [0 ]} "
0 commit comments