Skip to content

Commit 4d16274

Browse files
fix xpu
1 parent 23f7cc2 commit 4d16274

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

deepmd/pd/utils/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,10 @@ def get_generator(
348348
)
349349
elif DEVICE == "xpu":
350350
generator = paddle.framework.core.default_xpu_generator(0)
351+
elif DEVICE.startswith("xpu:"):
352+
generator = paddle.framework.core.default_cuda_generator(
353+
int(DEVICE.split("xpu:")[1])
354+
)
351355
else:
352356
# return none for compability in different devices
353357
warnings.warn(

0 commit comments

Comments
 (0)