Skip to content

Commit 9646d71

Browse files
author
Han Wang
committed
revert extend_coord_with_ghosts
1 parent 2384835 commit 9646d71

1 file changed

Lines changed: 2 additions & 6 deletions

File tree

source/tests/consistent/descriptor/common.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,17 +153,13 @@ def eval_pt_expt_descriptor(
153153
box: np.ndarray,
154154
mixed_types: bool = False,
155155
) -> Any:
156-
# Use the torch-native neighbor list utilities to avoid array_api_compat
157-
# allocations on CUDA. The array_api path can hit torch empty/ones/eye/etc
158-
# on CUDA, which all rely on aten::empty_strided and fail in CI builds
159-
# where that CUDA kernel is not available.
160-
ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt(
156+
ext_coords, ext_atype, mapping = extend_coord_with_ghosts(
161157
torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3),
162158
torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1),
163159
torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3),
164160
pt_expt_obj.get_rcut(),
165161
)
166-
nlist = build_neighbor_list_pt(
162+
nlist = build_neighbor_list(
167163
ext_coords,
168164
ext_atype,
169165
natoms[0],

0 commit comments

Comments
 (0)