@@ -343,7 +343,7 @@ def forward(
343343 # use default fparam
344344 assert self .default_fparam_tensor is not None
345345 fparam = torch .tile (self .default_fparam_tensor .unsqueeze (0 ), [nf , 1 ])
346- atomic_number = self .atype_to_idx [atype .view (- 1 )].view (nf , nloc )
346+ atomic_number = self .atype_to_idx [atype .view (- 1 )].view (nf , nloc ) + 1
347347 tags = torch .zeros_like (atomic_number )
348348 pbc = torch .ones ([3 ], dtype = torch .bool , device = coord .device )
349349 fixed_idx = torch .zeros_like (atomic_number )
@@ -368,12 +368,12 @@ def forward(
368368 model_ret = self .model (batch )
369369
370370 # apply energy bias
371- model_ret ["energy" ] = model_ret ["energy" ] + self .out_bias [0 , :, 0 ][atype ].sum (- 1 )
371+ model_ret ["energy" ] = model_ret ["energy" ] * 0.8125 + self .out_bias [0 , :, 0 ][atype ].sum (- 1 )
372372
373373 model_predict = {}
374374 model_predict ["energy" ] = model_ret ["energy" ]
375- model_predict ["force" ] = model_ret ["forces" ].view (nf , nloc , 3 )
376- model_predict ["virial" ] = model_ret ["virial" ].view (nf , 9 )
375+ model_predict ["force" ] = model_ret ["forces" ].view (nf , nloc , 3 )* 0.8125
376+ model_predict ["virial" ] = model_ret ["virial" ].view (nf , 9 )* 0.8125
377377 return model_predict
378378
379379 @torch .jit .export
0 commit comments