Skip to content

Commit eedcbaf

Browse files
author
Han Wang
committed
fix(pt,pt-expt): guard thread setters
1 parent 57433d3 commit eedcbaf

4 files changed

Lines changed: 82 additions & 4 deletions

File tree

deepmd/pt/utils/env.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,20 @@
9393
set_default_nthreads()
# Requested intra-op / inter-op thread counts, read after set_default_nthreads() has run.
intra_nthreads, inter_nthreads = get_default_nthreads()

if inter_nthreads > 0:  # the behavior of 0 is not documented
    # The interop thread count can be set only once per process, so a second
    # call (e.g. when multiple backends are imported) raises RuntimeError.
    # Skip the setter when the value already matches; otherwise downgrade a
    # failure to a warning instead of aborting the import.
    try:
        if inter_nthreads != torch.get_num_interop_threads():
            torch.set_num_interop_threads(inter_nthreads)
    except RuntimeError as err:
        log.warning(f"Could not set torch interop threads: {err}")

if intra_nthreads > 0:
    # The intra-op setter can also fail if called after threads are created,
    # so it gets the same "only if different, warn on failure" guard.
    try:
        if intra_nthreads != torch.get_num_threads():
            torch.set_num_threads(intra_nthreads)
    except RuntimeError as err:
        log.warning(f"Could not set torch intra threads: {err}")
99110

100111
__all__ = [
101112
"CACHE_PER_SYS",

deepmd/pt_expt/utils/env.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,20 @@
9393
set_default_nthreads()
# Requested intra-op / inter-op thread counts, read after set_default_nthreads() has run.
intra_nthreads, inter_nthreads = get_default_nthreads()

if inter_nthreads > 0:  # the behavior of 0 is not documented
    # The interop thread count can be set only once per process, so a second
    # call (e.g. when both the pt and pt_expt env modules are imported) raises
    # RuntimeError. Skip the setter when the value already matches; otherwise
    # downgrade a failure to a warning instead of aborting the import.
    try:
        if inter_nthreads != torch.get_num_interop_threads():
            torch.set_num_interop_threads(inter_nthreads)
    except RuntimeError as err:
        log.warning(f"Could not set torch interop threads: {err}")

if intra_nthreads > 0:
    # The intra-op setter can also fail if called after threads are created,
    # so it gets the same "only if different, warn on failure" guard.
    try:
        if intra_nthreads != torch.get_num_threads():
            torch.set_num_threads(intra_nthreads)
    except RuntimeError as err:
        log.warning(f"Could not set torch intra threads: {err}")
99110

100111
__all__ = [
101112
"CACHE_PER_SYS",
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
import logging
4+
5+
import torch
6+
7+
import deepmd.env as common_env
8+
9+
10+
def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None:
    """Check that reloading ``deepmd.pt.utils.env`` warns when torch rejects thread settings.

    Both torch thread setters are replaced with a function that raises
    ``RuntimeError``; the test passes if one warning per setter is captured.
    """

    def _always_fail(*_args, **_kwargs) -> None:
        raise RuntimeError("boom")

    # Stub the thread-count helpers on deepmd.env so (1, 1) is requested
    # without touching the real environment.
    monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None)
    monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1))

    # torch reports 2 threads (different from the requested 1), which forces
    # the guarded code to call the setters, and every setter call fails.
    for getter_name, setter_name in (
        ("get_num_interop_threads", "set_num_interop_threads"),
        ("get_num_threads", "set_num_threads"),
    ):
        monkeypatch.setattr(torch, getter_name, lambda: 2)
        monkeypatch.setattr(torch, setter_name, _always_fail)

    caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env")
    import deepmd.pt.utils.env as env

    # Re-execute the module body under the patches above.
    importlib.reload(env)

    logged = [record.getMessage() for record in caplog.records]
    for expected in (
        "Could not set torch interop threads",
        "Could not set torch intra threads",
    ):
        assert any(expected in text for text in logged)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
import logging
4+
5+
import torch
6+
7+
import deepmd.env as common_env
8+
9+
10+
def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None:
    """Check that reloading ``deepmd.pt_expt.utils.env`` warns when torch rejects thread settings.

    Both torch thread setters are replaced with a function that raises
    ``RuntimeError``; the test passes if one warning per setter is captured.
    """

    def _always_fail(*_args, **_kwargs) -> None:
        raise RuntimeError("boom")

    # Stub the thread-count helpers on deepmd.env so (1, 1) is requested
    # without touching the real environment.
    monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None)
    monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1))

    # torch reports 2 threads (different from the requested 1), which forces
    # the guarded code to call the setters, and every setter call fails.
    for getter_name, setter_name in (
        ("get_num_interop_threads", "set_num_interop_threads"),
        ("get_num_threads", "set_num_threads"),
    ):
        monkeypatch.setattr(torch, getter_name, lambda: 2)
        monkeypatch.setattr(torch, setter_name, _always_fail)

    caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env")
    import deepmd.pt_expt.utils.env as env

    # Re-execute the module body under the patches above.
    importlib.reload(env)

    logged = [record.getMessage() for record in caplog.records]
    for expected in (
        "Could not set torch interop threads",
        "Could not set torch intra threads",
    ):
        assert any(expected in text for text in logged)

0 commit comments

Comments
 (0)