|
64 | 64 | to_numpy_array, |
65 | 65 | to_torch_tensor, |
66 | 66 | ) |
| 67 | +from deepmd.utils.econf_embd import ( |
| 68 | + sort_element_type, |
| 69 | +) |
67 | 70 |
|
68 | 71 | if TYPE_CHECKING: |
69 | 72 | import ase.neighborlist |
@@ -98,6 +101,7 @@ def __init__( |
98 | 101 | auto_batch_size: Union[bool, int, AutoBatchSize] = True, |
99 | 102 | neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, |
100 | 103 | head: Optional[Union[str, int]] = None, |
| 104 | + no_jit: bool = False, |
101 | 105 | **kwargs: Any, |
102 | 106 | ) -> None: |
103 | 107 | self.output_def = output_def |
@@ -130,7 +134,7 @@ def __init__( |
130 | 134 | ] = state_dict[item].clone() |
131 | 135 | state_dict = state_dict_head |
132 | 136 | model = get_model(self.input_param).to(DEVICE) |
133 | | - if not self.input_param.get("hessian_mode"): |
| 137 | + if not self.input_param.get("hessian_mode") and not no_jit: |
134 | 138 | model = torch.jit.script(model) |
135 | 139 | self.dp = ModelWrapper(model) |
136 | 140 | self.dp.load_state_dict(state_dict) |
@@ -648,6 +652,22 @@ def get_model_size(self) -> dict: |
648 | 652 | "total": sum_param_des + sum_param_fit, |
649 | 653 | } |
650 | 654 |
|
| 655 | + def get_observed_types(self) -> dict: |
| 656 | + """Get observed types (elements) of the model during data statistics. |
| 657 | +
|
| 658 | + Returns |
| 659 | + ------- |
| 660 | + dict |
| 661 | + A dictionary containing the information of observed type in the model: |
| 662 | + - 'type_num': the total number of observed types in this model. |
| 663 | + - 'observed_type': a list of the observed types in this model. |
| 664 | + """ |
| 665 | + observed_type_list = self.dp.model["Default"].get_observed_type_list() |
| 666 | + return { |
| 667 | + "type_num": len(observed_type_list), |
| 668 | + "observed_type": sort_element_type(observed_type_list), |
| 669 | + } |
| 670 | + |
651 | 671 | def eval_descriptor( |
652 | 672 | self, |
653 | 673 | coords: np.ndarray, |
|
0 commit comments