-
Notifications
You must be signed in to change notification settings - Fork 611
Expand file tree
/
Copy pathconvert_backend.py
More file actions
46 lines (41 loc) · 1.29 KB
/
convert_backend.py
File metadata and controls
46 lines (41 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)
from deepmd.backend.backend import (
Backend,
)
def convert_backend(
    *,  # Enforce keyword-only arguments
    INPUT: str,
    OUTPUT: str,
    atomic_virial: bool = False,
    **kwargs: Any,
) -> None:
    """Convert a model file from one backend to another.

    The input backend is detected from ``INPUT`` and the output backend from
    ``OUTPUT``. The model is read with the input backend's ``serialize_hook``
    and written with the output backend's ``deserialize_hook``.

    Parameters
    ----------
    INPUT : str
        The input model file.
    OUTPUT : str
        The output model file.
    atomic_virial : bool
        If True, export .pt2/.pte models with per-atom virial correction.
        This adds ~2.5x inference cost. Default False.
    **kwargs
        Extra keyword arguments. They are accepted but ignored (presumably
        so a full CLI argument namespace can be passed through; TODO confirm).

    Raises
    ------
    ValueError
        If ``atomic_virial`` is True but the output backend's
        ``deserialize_hook`` has no ``do_atomic_virial`` parameter.
    """
    import inspect

    inp_backend: Backend = Backend.detect_backend_by_model(INPUT)()
    out_backend: Backend = Backend.detect_backend_by_model(OUTPUT)()
    inp_hook = inp_backend.serialize_hook
    out_hook = out_backend.deserialize_hook

    # Decide up front whether the output hook accepts the flag. An
    # unsupported atomic-virial request then fails before the (potentially
    # expensive) input model is loaded and serialized.
    supports_atomic_virial = (
        "do_atomic_virial" in inspect.signature(out_hook).parameters
    )
    if atomic_virial and not supports_atomic_virial:
        raise ValueError(
            "--atomic-virial is only supported for pt_expt .pt2/.pte outputs"
        )

    data = inp_hook(INPUT)
    if supports_atomic_virial:
        # Forward atomic_virial to pt_expt deserialize_to_file.
        out_hook(OUTPUT, data, do_atomic_virial=atomic_virial)
    else:
        out_hook(OUTPUT, data)