|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +"""Shared test mixins and utilities used by both pt and pt_expt tests. |
| 3 | +
|
| 4 | +These are kept in ``tests.common`` so that importing them does NOT trigger |
| 5 | +``tests.pt.__init__`` (which sets ``torch.set_default_device("cuda:9999999")`` |
| 6 | +and pushes a ``DeviceContext`` onto the mode stack, breaking pt_expt tests |
| 7 | +on CPU-only machines). |
| 8 | +""" |
| 9 | + |
| 10 | +import numpy as np |
| 11 | + |
| 12 | + |
| 13 | +class TestCaseSingleFrameWithNlist: |
| 14 | + """Mixin providing a small 2-frame, 2-type test system with neighbor list.""" |
| 15 | + |
| 16 | + def setUp(self) -> None: |
| 17 | + # nloc == 3, nall == 4 |
| 18 | + self.nloc = 3 |
| 19 | + self.nall = 4 |
| 20 | + self.nf, self.nt = 2, 2 |
| 21 | + self.coord_ext = np.array( |
| 22 | + [ |
| 23 | + [0, 0, 0], |
| 24 | + [0, 1, 0], |
| 25 | + [0, 0, 1], |
| 26 | + [0, -2, 0], |
| 27 | + ], |
| 28 | + dtype=np.float64, |
| 29 | + ).reshape([1, self.nall, 3]) |
| 30 | + self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) |
| 31 | + self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) |
| 32 | + # sel = [5, 2] |
| 33 | + self.sel = [5, 2] |
| 34 | + self.sel_mix = [7] |
| 35 | + self.natoms = [3, 3, 2, 1] |
| 36 | + self.nlist = np.array( |
| 37 | + [ |
| 38 | + [1, 3, -1, -1, -1, 2, -1], |
| 39 | + [0, -1, -1, -1, -1, 2, -1], |
| 40 | + [0, 1, -1, -1, -1, -1, -1], |
| 41 | + ], |
| 42 | + dtype=int, |
| 43 | + ).reshape([1, self.nloc, sum(self.sel)]) |
| 44 | + self.rcut = 2.2 |
| 45 | + self.rcut_smth = 0.4 |
| 46 | + # permutations |
| 47 | + self.perm = np.array([2, 0, 1, 3], dtype=np.int32) |
| 48 | + inv_perm = np.array([1, 2, 0, 3], dtype=np.int32) |
| 49 | + # permute the coord and atype |
| 50 | + self.coord_ext = np.concatenate( |
| 51 | + [self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0 |
| 52 | + ).reshape(self.nf, self.nall * 3) |
| 53 | + self.atype_ext = np.concatenate( |
| 54 | + [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 |
| 55 | + ) |
| 56 | + self.mapping = np.concatenate( |
| 57 | + [self.mapping, self.mapping[:, self.perm]], axis=0 |
| 58 | + ) |
| 59 | + |
| 60 | + # permute the nlist |
| 61 | + nlist1 = self.nlist[:, self.perm[: self.nloc], :] |
| 62 | + mask = nlist1 == -1 |
| 63 | + nlist1 = inv_perm[nlist1] |
| 64 | + nlist1 = np.where(mask, -1, nlist1) |
| 65 | + self.nlist = np.concatenate([self.nlist, nlist1], axis=0) |
| 66 | + self.atol = 1e-12 |
| 67 | + |
| 68 | + |
| 69 | +def get_tols(prec): |
| 70 | + """Return (rtol, atol) for a given precision string.""" |
| 71 | + if prec in ["single", "float32"]: |
| 72 | + rtol, atol = 0.0, 1e-4 |
| 73 | + elif prec in ["double", "float64"]: |
| 74 | + rtol, atol = 0.0, 1e-12 |
| 75 | + else: |
| 76 | + raise ValueError(f"unknown prec {prec}") |
| 77 | + return rtol, atol |
0 commit comments