Skip to content

Commit b79e71d

Browse files
committed
Clarify CPU-specific thread warning
1 parent 3f0087b commit b79e71d

4 files changed

Lines changed: 15 additions & 9 deletions

File tree

deepmd/env.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,21 @@ def set_env_if_empty(key: str, value: str, verbose: bool = True) -> None:
6868
)
6969

7070

71-
def set_default_nthreads() -> None:
71+
def set_default_nthreads(use_cpu: bool = False) -> None:
7272
"""Set internal number of threads to default=automatic selection.
7373
74+
Parameters
75+
----------
76+
use_cpu : bool, optional
77+
If ``True``, suppress warnings about thread configuration,
78+
by default ``False``.
79+
7480
Notes
7581
-----
7682
`DP_INTRA_OP_PARALLELISM_THREADS` and `DP_INTER_OP_PARALLELISM_THREADS`
7783
control configuration of multithreading.
7884
"""
79-
if (
85+
if not use_cpu and (
8086
"OMP_NUM_THREADS" not in os.environ
8187
# for backward compatibility
8288
or (
@@ -89,10 +95,10 @@ def set_default_nthreads() -> None:
8995
)
9096
):
9197
log.warning(
92-
"To get the best performance, it is recommended to adjust "
93-
"the number of threads by setting the environment variables "
94-
"OMP_NUM_THREADS, DP_INTRA_OP_PARALLELISM_THREADS, and "
95-
"DP_INTER_OP_PARALLELISM_THREADS. See "
98+
"To get the best CPU performance, adjust the number of threads by "
99+
"setting the environment variables OMP_NUM_THREADS, "
100+
"DP_INTRA_OP_PARALLELISM_THREADS, and DP_INTER_OP_PARALLELISM_THREADS. "
101+
"These variables are only effective when running on CPU. See "
96102
"https://deepmd.rtfd.io/parallelism/ for more information."
97103
)
98104
if "TF_INTRA_OP_PARALLELISM_THREADS" not in os.environ:

deepmd/pd/utils/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def to_bool(flag: int | bool | str) -> bool:
113113
DEFAULT_PRECISION = "float64"
114114

115115
# throw warnings if threads not set
116-
set_default_nthreads()
116+
set_default_nthreads(use_cpu=DEVICE == "cpu")
117117
inter_nthreads, intra_nthreads = get_default_nthreads()
118118
# if inter_nthreads > 0: # the behavior of 0 is not documented
119119
# os.environ['OMP_NUM_THREADS'] = str(inter_nthreads)

deepmd/pt/utils/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
DEFAULT_PRECISION = "float64"
8080

8181
# throw warnings if threads not set
82-
set_default_nthreads()
82+
set_default_nthreads(use_cpu=DEVICE.type == "cpu")
8383
inter_nthreads, intra_nthreads = get_default_nthreads()
8484
if inter_nthreads > 0: # the behavior of 0 is not documented
8585
torch.set_num_interop_threads(inter_nthreads)

deepmd/tf/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def get_tf_session_config() -> Any:
263263
Any
264264
session configure object
265265
"""
266-
set_tf_default_nthreads()
266+
set_tf_default_nthreads(use_cpu=os.environ.get("DEVICE") == "cpu")
267267
intra, inter = get_tf_default_nthreads()
268268
if int(os.environ.get("DP_JIT", 0)):
269269
set_env_if_empty("TF_XLA_FLAGS", "--tf_xla_auto_jit=2")

0 commit comments

Comments
 (0)