diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index cfdcfdca96..e745b28f94 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -78,7 +78,7 @@ def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array: xp = array_api_compat.array_namespace(input) # Create flat index array matching input shape - idx = xp.arange(input.size, dtype=xp.int64) + idx = xp.arange(input.size, dtype=xp.int64, device=array_api_compat.device(input)) idx = xp.reshape(idx, input.shape) # Get flat indices where we want to add values @@ -190,6 +190,10 @@ def xp_bincount(x: Array, weights: Array | None = None, minlength: int = 0) -> A else: if weights is None: weights = xp.ones_like(x) - result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype) + result = xp.zeros( + (max(minlength, int(xp.max(x)) + 1),), + dtype=weights.dtype, + device=array_api_compat.device(weights), + ) result = xp_add_at(result, x, weights) return result diff --git a/deepmd/dpmodel/atomic_model/polar_atomic_model.py b/deepmd/dpmodel/atomic_model/polar_atomic_model.py index 2180e48265..bdad31dcb1 100644 --- a/deepmd/dpmodel/atomic_model/polar_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/polar_atomic_model.py @@ -48,6 +48,7 @@ def apply_out_stat( if self.fitting.shift_diag: nframes, nloc = atype.shape dtype = out_bias[self.bias_keys[0]].dtype + device = array_api_compat.device(out_bias[self.bias_keys[0]]) for kk in self.bias_keys: ntypes = out_bias[kk].shape[0] temp = xp.mean( @@ -61,7 +62,7 @@ def apply_out_stat( modified_bias[..., xp.newaxis] * (self.fitting.scale[atype]) ) - eye = xp.eye(3, dtype=dtype) + eye = xp.eye(3, dtype=dtype, device=device) eye = xp.tile(eye, (nframes, nloc, 1, 1)) # (nframes, nloc, 3, 3) modified_bias = modified_bias[..., xp.newaxis] * eye diff --git a/source/checker/deepmd_checker.py b/source/checker/deepmd_checker.py index c066fe0c10..dc1f31dc69 100644 --- a/source/checker/deepmd_checker.py +++ b/source/checker/deepmd_checker.py @@ -60,7 +60,7 @@ def visit_call(self, node) -> None: no_device = False if kw.arg == "dtype": no_dtype = False - if no_device and node.func.expr.name == "torch": + if no_device and node.func.expr.name in {"torch", "xp"}: # only PT needs device self.add_message("no-explicit-device", node=node) if no_dtype: