1616)
1717from deepmd .pt .model .network .network import (
1818 TypeEmbedNet ,
19- TypeEmbedNetConsistent ,
2019)
2120from deepmd .pt .model .task .fitting import (
2221 Fitting ,
@@ -130,7 +129,6 @@ def __init__(
130129 type_map : list [str ] | None = None ,
131130 use_aparam_as_mask : bool = False ,
132131 default_fparam : list [float ] | None = None ,
133- use_type_embed_for_bias_q : bool | None = None ,
134132 ** kwargs : Any ,
135133 ) -> None :
136134 super ().__init__ ()
@@ -176,17 +174,9 @@ def __init__(
176174 assert self .ntypes == bias_atom_e .shape [0 ], "Element count mismatches!"
177175 self .register_buffer ("bias_atom_e" , bias_atom_e )
178176
179- if bias_atom_q is not None :
180- self .use_type_embed_for_bias_q = False
181- elif use_type_embed_for_bias_q is None :
182- self .use_type_embed_for_bias_q = True
183- else :
184- self .use_type_embed_for_bias_q = use_type_embed_for_bias_q
185- self .bias_atom_q_type_embed : TypeEmbedNet | None = None
186- self .bias_atom_q : torch .nn .Parameter | None = None
187- if self .use_type_embed_for_bias_q :
188- # Build per-type lr bias from a type-only embedding network.
189- self .bias_atom_q_type_embed = TypeEmbedNet (
177+ if bias_atom_q is None :
178+ # No external bias provided: learn per-type bias via TypeEmbedNet.
179+ self .bias_atom_q : torch .nn .Parameter | TypeEmbedNet = TypeEmbedNet (
190180 type_nums = self .ntypes ,
191181 embed_dim = self .lr_net_dim_out ,
192182 precision = self .precision ,
@@ -195,9 +185,6 @@ def __init__(
195185 trainable = self .trainable ,
196186 )
197187 else :
198- if bias_atom_q is None :
199- # small random initialization to break saddle point
200- bias_atom_q = np .random .randn (self .ntypes , self .lr_net_dim_out ) * 0.01
201188 bias_atom_q = torch .tensor (
202189 bias_atom_q , dtype = env .GLOBAL_PT_FLOAT_PRECISION , device = device
203190 )
@@ -340,8 +327,7 @@ def change_type_map(
340327 )
341328 self .bias_atom_e = torch .cat ([self .bias_atom_e , extend_bias_atom_e ], dim = 0 )
342329
343- if not self .use_type_embed_for_bias_q :
344- assert self .bias_atom_q is not None
330+ if isinstance (self .bias_atom_q , torch .nn .Parameter ):
345331 extend_shape_q = [len (type_map ), * list (self .bias_atom_q .shape [1 :])]
346332 extend_bias_atom_q = torch .zeros (
347333 extend_shape_q ,
@@ -354,15 +340,13 @@ def change_type_map(
354340 )
355341
356342 self .bias_atom_e = self .bias_atom_e [remap_index ]
357- if self .use_type_embed_for_bias_q :
358- assert self .bias_atom_q_type_embed is not None
359- self .bias_atom_q_type_embed .change_type_map (type_map = type_map )
360- else :
361- assert self .bias_atom_q is not None
343+ if isinstance (self .bias_atom_q , torch .nn .Parameter ):
362344 self .bias_atom_q = torch .nn .Parameter (
363345 self .bias_atom_q .data [remap_index ],
364346 requires_grad = bool (self .trainable ),
365347 )
348+ else :
349+ self .bias_atom_q .change_type_map (type_map = type_map )
366350
367351 def serialize (self ) -> dict :
368352 """Serialize the fitting to dict."""
@@ -388,12 +372,6 @@ def serialize(self) -> dict:
388372 "nets_lr" : self .filter_layers_lr .serialize (),
389373 "rcond" : self .rcond ,
390374 "exclude_types" : self .exclude_types ,
391- "use_type_embed_for_bias_q" : self .use_type_embed_for_bias_q ,
392- "bias_atom_q_type_embed" : (
393- self .bias_atom_q_type_embed .embedding .serialize ()
394- if self .bias_atom_q_type_embed is not None
395- else None
396- ),
397375 "@variables" : {
398376 "bias_atom_e" : to_numpy_array (self .bias_atom_e ),
399377 "bias_atom_q" : to_numpy_array (self ._get_bias_atom_q_table ()),
@@ -421,26 +399,15 @@ def serialize(self) -> dict:
421399 @classmethod
422400 def deserialize (cls , data : dict ) -> "LRFittingNet" :
423401 data = data .copy ()
424- use_type_embed_for_bias_q = data . get ( "use_type_embed_for_bias_q" , False )
425- data [ "use_type_embed_for_bias_q" ] = use_type_embed_for_bias_q
426- bias_atom_q_type_embed = data .pop ("bias_atom_q_type_embed" , None )
402+ # Compatibility with old checkpoints.
403+ data . pop ( "use_type_embed_for_bias_q" , None )
404+ data .pop ("bias_atom_q_type_embed" , None )
427405 variables = data .pop ("@variables" )
428406 nets_sr = data .pop ("nets_sr" )
429407 nets_lr = data .pop ("nets_lr" )
430408 obj = cls (** data )
431- if obj .use_type_embed_for_bias_q and bias_atom_q_type_embed is not None :
432- assert obj .bias_atom_q_type_embed is not None
433- obj .bias_atom_q_type_embed .embedding = TypeEmbedNetConsistent .deserialize (
434- bias_atom_q_type_embed
435- )
436409 for kk in variables .keys ():
437410 if variables [kk ] is not None :
438- if (
439- kk == "bias_atom_q"
440- and obj .use_type_embed_for_bias_q
441- and bias_atom_q_type_embed is not None
442- ):
443- continue
444411 obj [kk ] = to_torch_tensor (variables [kk ])
445412 obj .filter_layers_sr = NetworkCollection .deserialize (nets_sr )
446413 obj .filter_layers_lr = NetworkCollection .deserialize (nets_lr )
@@ -500,15 +467,10 @@ def __setitem__(self, key: str, value: torch.Tensor) -> None:
500467 self .bias_atom_e = value
501468 elif key in ["bias_atom_q" ]:
502469 value = value .view ([self .ntypes , self ._lr_net_out_dim ()])
503- if self .bias_atom_q is None :
504- self .use_type_embed_for_bias_q = False
505- self .bias_atom_q_type_embed = None
506- self .bias_atom_q = torch .nn .Parameter (
507- value ,
508- requires_grad = bool (self .trainable ),
509- )
510- else :
511- self .bias_atom_q .data .copy_ (value )
470+ self .bias_atom_q = torch .nn .Parameter (
471+ value ,
472+ requires_grad = bool (self .trainable ),
473+ )
512474 elif key in ["fparam_avg" ]:
513475 self .fparam_avg = value
514476 elif key in ["fparam_inv_std" ]:
@@ -565,22 +527,20 @@ def _compress_bias_atom_q(self, bias: torch.Tensor) -> torch.Tensor:
565527 return self .bias_atom_q_bound * torch .tanh (bias / self .bias_atom_q_bound )
566528
567529 def _get_bias_atom_q_table (self ) -> torch .Tensor :
568- if self .bias_atom_q is not None :
530+ if isinstance ( self .bias_atom_q , torch . nn . Parameter ) :
569531 return self ._compress_bias_atom_q (self .bias_atom_q )
570- assert self .bias_atom_q_type_embed is not None
571532 # `TypeEmbedNet` appends one zero-padding row; keep only real atom types.
572- bias_table = self .bias_atom_q_type_embed .get_full_embedding (self .bias_atom_e .device )[
533+ bias_table = self .bias_atom_q .get_full_embedding (self .bias_atom_e .device )[
573534 : self .ntypes
574535 ]
575536 return self ._compress_bias_atom_q (bias_table )
576537
577538 def _get_lr_bias (self , atype : torch .Tensor ) -> torch .Tensor :
578539 atype_long = atype .to (torch .long )
579- if self .bias_atom_q is not None :
540+ if isinstance ( self .bias_atom_q , torch . nn . Parameter ) :
580541 return self ._compress_bias_atom_q (self .bias_atom_q [atype_long ].to (self .prec ))
581- assert self .bias_atom_q_type_embed is not None
582542 return self ._compress_bias_atom_q (
583- self .bias_atom_q_type_embed (atype_long ).to (self .prec )
543+ self .bias_atom_q (atype_long ).to (self .prec )
584544 )
585545
586546 def _extend_f_avg_std (self , xx : torch .Tensor , nb : int ) -> torch .Tensor :
0 commit comments