Skip to content

Commit f2fbe88

Browse files
author
Han Wang
committed
mv to_torch_tensor to common
1 parent 1cc001f commit f2fbe88

2 files changed

Lines changed: 39 additions & 12 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
from typing import (
4+
Any,
5+
overload,
6+
)
7+
8+
import numpy as np
9+
10+
from deepmd.pt_expt.utils import (
11+
env,
12+
)
13+
14+
torch = importlib.import_module("torch")
15+
16+
17+
@overload
18+
def to_torch_array(array: np.ndarray) -> torch.Tensor: ...
19+
20+
21+
@overload
22+
def to_torch_array(array: None) -> None: ...
23+
24+
25+
@overload
26+
def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...
27+
28+
29+
def to_torch_array(array: Any) -> torch.Tensor | None:
30+
"""Convert input to a torch tensor on the pt-expt device."""
31+
if array is None:
32+
return None
33+
if torch.is_tensor(array):
34+
return array
35+
return torch.as_tensor(array, device=env.DEVICE)

deepmd/pt_expt/utils/network.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,13 @@
1919
make_fitting_network,
2020
make_multilayer_network,
2121
)
22-
from deepmd.pt_expt.utils import (
23-
env,
22+
from deepmd.pt_expt.common import (
23+
to_torch_array,
2424
)
2525

2626
torch = importlib.import_module("torch")
2727

2828

29-
def _to_torch_array(value: Any) -> torch.Tensor | None:
30-
if value is None:
31-
return None
32-
if torch.is_tensor(value):
33-
return value
34-
return torch.as_tensor(value, device=env.DEVICE)
35-
36-
3729
class TorchArrayParam(torch.nn.Parameter):
3830
def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self:
3931
return torch.nn.Parameter.__new__(cls, data, requires_grad)
@@ -52,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
5244
for name in ("w", "b", "idt"):
5345
if name in self._parameters or name in self._buffers:
5446
continue
55-
val = _to_torch_array(getattr(self, name))
47+
val = to_torch_array(getattr(self, name))
5648
if val is None:
5749
continue
5850
if self.trainable:
@@ -66,7 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
6658

6759
def __setattr__(self, name: str, value: Any) -> None:
6860
if name in {"w", "b", "idt"} and "_parameters" in self.__dict__:
69-
val = _to_torch_array(value)
61+
val = to_torch_array(value)
7062
if val is None:
7163
return super().__setattr__(name, None)
7264
if getattr(self, "trainable", False):

0 commit comments

Comments
 (0)