@@ -422,3 +422,239 @@ def forward(
422422 # energy
423423 out = out + edge_energy / self .norm_e_fact
424424 return {self .var_name : out .to (env .GLOBAL_PT_FLOAT_PRECISION )}
425+
426+
@Fitting.register("ener_direct")
@fitting_check_output
class EnergyFittingNetDirectHead(InvarFitting):
    """Energy fitting net with direct (non-gradient) force and noise heads.

    The energy itself is predicted by the inherited invariant fitting net.
    In addition, a per-edge weight network turns edge features into a scalar
    weight per edge; the weighted edge vectors are summed onto their owner
    atoms to give a directly-predicted force ("dforce") and, optionally, a
    second identically-structured head ("dnoise").
    """

    # make jit happy with torch 2.0.0
    exclude_types: list[int]

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: Optional[list[int]] = None,
        bias_atom_e: Optional[torch.Tensor] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        dim_case_embd: int = 0,
        embedding_width: int = 128,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = True,
        seed: Optional[Union[int, list[int]]] = None,
        type_map: Optional[list[str]] = None,
        additional_gradient: bool = False,
        additional_noise_head: bool = False,
        **kwargs: Any,
    ) -> None:
        """Construct a fitting net for energy with direct output heads.

        Args:
            - ntypes: Element count.
            - dim_descrpt: Dimension of the per-atom descriptor.
            - neuron: Number of neurons in each hidden layer of the fitting
              net. Defaults to [128, 128, 128].
            - bias_atom_e: Average energy per atom for each element.
            - resnet_dt: Using time-step in the ResNet construction.
            - embedding_width: Width of the per-edge feature fed to the
              direct force / noise heads.
            - additional_gradient: Make the energy output r/c-differentiable
              so gradient-based forces can also be derived from it.
            - additional_noise_head: Build a second edge head producing the
              "dnoise" output.
        """
        # Avoid the shared-mutable-default-argument pitfall.
        if neuron is None:
            neuron = [128, 128, 128]
        self.additional_gradient = additional_gradient
        self.additional_noise_head = additional_noise_head
        super().__init__(
            "energy",
            ntypes,
            dim_descrpt,
            1,
            neuron=neuron,
            bias_atom_e=bias_atom_e,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            dim_case_embd=dim_case_embd,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            seed=seed,
            type_map=type_map,
            **kwargs,
        )

        # embedding for direct force
        self.force_input_dim = embedding_width  # can add force embedding if needed
        self.force_embed = self._build_edge_head(self.force_input_dim, 100)
        # additional noise head
        self.noise_input_dim = embedding_width  # can add noise embedding if needed
        if self.additional_noise_head:
            # dforce for force; dnoise for noise
            self.noise_embed: Optional[NetworkCollection] = self._build_edge_head(
                self.noise_input_dim, 200
            )
        else:
            # dforce for noise
            self.noise_embed = None

        # set trainable
        for param in self.parameters():
            param.requires_grad = self.trainable

    def _build_edge_head(self, input_dim: int, seed_offset: int) -> NetworkCollection:
        """Build the per-edge scalar-weight fitting nets (one per type, or a
        single shared net when ``mixed_types`` is set)."""
        # NOTE(review): `self.seed + seed_offset` assumes an int seed; a None
        # or list-valued seed (both allowed by the signature) would raise
        # here -- confirm against child_seed's expected usage.
        return NetworkCollection(
            1 if not self.mixed_types else 0,
            self.ntypes,
            network_type="fitting_network",
            networks=[
                FittingNet(
                    input_dim,
                    1,
                    self.neuron,
                    self.activation_function,
                    self.resnet_dt,
                    self.precision,
                    bias_out=True,
                    seed=child_seed(self.seed + seed_offset, ii),
                )
                for ii in range(self.ntypes if not self.mixed_types else 1)
            ],
        )

    def output_def(self) -> FittingOutputDef:
        """Declare the outputs: reducible energy, per-atom direct force, and
        (when enabled) the per-atom noise head."""
        out_list = [
            OutputVariableDef(
                self.var_name,
                [self.dim_out],
                reducible=True,
                # Only differentiable when gradient forces are also requested.
                r_differentiable=self.additional_gradient,
                c_differentiable=self.additional_gradient,
            ),
            OutputVariableDef(
                "dforce",
                [3],
                reducible=False,
                r_differentiable=False,
                c_differentiable=False,
            ),
        ]
        if self.additional_noise_head:
            out_list.append(
                OutputVariableDef(
                    "dnoise",
                    [3],
                    reducible=False,
                    r_differentiable=False,
                    c_differentiable=False,
                )
            )

        return FittingOutputDef(out_list)

    def need_additional_input(self) -> bool:
        """This head consumes edge-level inputs (``g2``/``diff``) in forward."""
        return True

    def serialize(self) -> dict:
        raise NotImplementedError

    @classmethod
    def deserialize(cls, data: dict) -> "EnergyFittingNetDirectHead":
        raise NotImplementedError

    def change_type_map(
        self, type_map: list[str], model_with_new_type_stat: Optional[Any] = None
    ) -> None:
        raise NotImplementedError

    def get_type_map(self) -> list[str]:
        raise NotImplementedError

    def _reduce_edges(
        self,
        edge_out: torch.Tensor,
        edge_index: Optional[torch.Tensor],
        nf: int,
        nloc: int,
    ) -> torch.Tensor:
        """Sum per-edge 3-vectors onto their owner atoms.

        Returns an nf x nloc x 3 tensor. With dynamic selection
        (``edge_index`` given) the edges arrive flattened as nedge x 3 and
        column 0 of ``edge_index`` names the owner atom; otherwise the edges
        arrive as nf x nloc x nnei x 3 and are summed over the neighbor axis.
        """
        if edge_index is not None:
            # use dynamic sel; column 1 (extended-atom index) is not needed here
            n2e_index = edge_index[:, 0]
            # nf x nloc x 3
            return aggregate(
                edge_out,
                n2e_index,
                average=False,
                num_owner=nf * nloc,
            ).reshape(nf, nloc, 3)
        # nf x nloc x 3
        return torch.sum(edge_out, dim=-2)

    def forward(
        self,
        descriptor: torch.Tensor,
        atype: torch.Tensor,
        gr: Optional[torch.Tensor] = None,
        g2: Optional[torch.Tensor] = None,
        h2: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        diff: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        sw: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """Based on embedding net output, calculate total energy and the
        direct force / noise heads.

        Args:
            - descriptor: Per-atom descriptor, nf x nloc x dim_descrpt.
            - atype: Atom types, nf x nloc.
            - g2: Per-edge features (required by the direct heads),
              nf x nloc x nnei x d, or nedge x d with dynamic sel.
            - diff: Per-edge displacement vectors (required),
              nf x nloc x nnei x 3, or nedge x 3 with dynamic sel.
            - edge_index: When given, enables dynamic sel; column 0 is the
              owner-atom index, column 1 the extended-atom index.

        Returns
        -------
        - dict with "energy" (nf x nloc x 1), "dforce" (nf x nloc x 3) and,
          when the noise head is enabled, "dnoise" (nf x nloc x 3).
        """
        out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
            self.var_name
        ]
        # energy
        result = {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

        # direct force: the edge heads cannot run without edge inputs
        assert diff is not None
        assert g2 is not None

        nf, nloc, _ = descriptor.shape

        # nf x nloc x nnei x 3 [OR] nedge x 3
        edge_vec = diff
        # nf x nloc x nnei x d [OR] nedge x d
        edge_feature = g2
        # nf x nloc x nnei x 1 [OR] nedge x 1
        edge_weight = self.force_embed.networks[0](edge_feature)
        # nf x nloc x nnei x 3 [OR] nedge x 3
        fij = edge_weight * edge_vec
        # nf x nloc x 3
        result["dforce"] = self._reduce_edges(fij, edge_index, nf, nloc)

        if self.additional_noise_head:
            assert self.noise_embed is not None
            edge_weight = self.noise_embed.networks[0](edge_feature)
            # nf x nloc x nnei x 3 [OR] nedge x 3
            nij = edge_weight * edge_vec
            # nf x nloc x 3
            result["dnoise"] = self._reduce_edges(nij, edge_index, nf, nloc)

        return result
0 commit comments