File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 9393set_default_nthreads ()
9494intra_nthreads , inter_nthreads = get_default_nthreads ()
9595if inter_nthreads > 0 : # the behavior of 0 is not documented
96- torch .set_num_interop_threads (inter_nthreads )
96+ # torch.set_num_interop_threads can only be called once per process.
97+ # Guard to avoid RuntimeError when multiple backends are imported.
98+ try :
99+ if torch .get_num_interop_threads () != inter_nthreads :
100+ torch .set_num_interop_threads (inter_nthreads )
101+ except RuntimeError as err :
102+ log .warning (f"Could not set torch interop threads: { err } " )
97103if intra_nthreads > 0 :
98- torch .set_num_threads (intra_nthreads )
104+ # torch.set_num_threads can also fail if called after threads are created.
105+ try :
106+ if torch .get_num_threads () != intra_nthreads :
107+ torch .set_num_threads (intra_nthreads )
108+ except RuntimeError as err :
109+ log .warning (f"Could not set torch intra threads: { err } " )
99110
100111__all__ = [
101112 "CACHE_PER_SYS" ,
Original file line number Diff line number Diff line change 9393set_default_nthreads ()
9494intra_nthreads , inter_nthreads = get_default_nthreads ()
9595if inter_nthreads > 0 : # the behavior of 0 is not documented
96- torch .set_num_interop_threads (inter_nthreads )
96+ # torch.set_num_interop_threads can only be called once per process.
97+ # Guard to avoid RuntimeError when both pt and pt_expt env modules are imported.
98+ try :
99+ if torch .get_num_interop_threads () != inter_nthreads :
100+ torch .set_num_interop_threads (inter_nthreads )
101+ except RuntimeError as err :
102+ log .warning (f"Could not set torch interop threads: { err } " )
97103if intra_nthreads > 0 :
98- torch .set_num_threads (intra_nthreads )
104+ # torch.set_num_threads can also fail if called after threads are created.
105+ try :
106+ if torch .get_num_threads () != intra_nthreads :
107+ torch .set_num_threads (intra_nthreads )
108+ except RuntimeError as err :
109+ log .warning (f"Could not set torch intra threads: { err } " )
99110
100111__all__ = [
101112 "CACHE_PER_SYS" ,
Original file line number Diff line number Diff line change 1+ # SPDX-License-Identifier: LGPL-3.0-or-later
2+ import importlib
3+ import logging
4+
5+ import torch
6+
7+ import deepmd .env as common_env
8+
9+
10+ def test_env_threads_guard_handles_runtimeerror (monkeypatch , caplog ) -> None :
11+ def raise_err (* _args , ** _kwargs ) -> None :
12+ raise RuntimeError ("boom" )
13+
14+ monkeypatch .setattr (common_env , "set_default_nthreads" , lambda : None )
15+ monkeypatch .setattr (common_env , "get_default_nthreads" , lambda : (1 , 1 ))
16+ monkeypatch .setattr (torch , "get_num_interop_threads" , lambda : 2 )
17+ monkeypatch .setattr (torch , "set_num_interop_threads" , raise_err )
18+ monkeypatch .setattr (torch , "get_num_threads" , lambda : 2 )
19+ monkeypatch .setattr (torch , "set_num_threads" , raise_err )
20+
21+ caplog .set_level (logging .WARNING , logger = "deepmd.pt.utils.env" )
22+ import deepmd .pt .utils .env as env
23+
24+ importlib .reload (env )
25+
26+ messages = [record .getMessage () for record in caplog .records ]
27+ assert any ("Could not set torch interop threads" in msg for msg in messages )
28+ assert any ("Could not set torch intra threads" in msg for msg in messages )
Original file line number Diff line number Diff line change 1+ # SPDX-License-Identifier: LGPL-3.0-or-later
2+ import importlib
3+ import logging
4+
5+ import torch
6+
7+ import deepmd .env as common_env
8+
9+
10+ def test_env_threads_guard_handles_runtimeerror (monkeypatch , caplog ) -> None :
11+ def raise_err (* _args , ** _kwargs ) -> None :
12+ raise RuntimeError ("boom" )
13+
14+ monkeypatch .setattr (common_env , "set_default_nthreads" , lambda : None )
15+ monkeypatch .setattr (common_env , "get_default_nthreads" , lambda : (1 , 1 ))
16+ monkeypatch .setattr (torch , "get_num_interop_threads" , lambda : 2 )
17+ monkeypatch .setattr (torch , "set_num_interop_threads" , raise_err )
18+ monkeypatch .setattr (torch , "get_num_threads" , lambda : 2 )
19+ monkeypatch .setattr (torch , "set_num_threads" , raise_err )
20+
21+ caplog .set_level (logging .WARNING , logger = "deepmd.pt_expt.utils.env" )
22+ import deepmd .pt_expt .utils .env as env
23+
24+ importlib .reload (env )
25+
26+ messages = [record .getMessage () for record in caplog .records ]
27+ assert any ("Could not set torch interop threads" in msg for msg in messages )
28+ assert any ("Could not set torch intra threads" in msg for msg in messages )
You can’t perform that action at this time.
0 commit comments