Skip to content

Commit 8bdb1f8

Browse files
author
Han Wang
committed
fix bug
1 parent aeef15a commit 8bdb1f8

2 files changed

Lines changed: 18 additions & 6 deletions

File tree

source/tests/pt/test_env_threads.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import deepmd.env as common_env
88

99

10-
def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None:
10+
def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None:
1111
def raise_err(*_args, **_kwargs) -> None:
1212
raise RuntimeError("boom")
1313

@@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None:
1818
monkeypatch.setattr(torch, "get_num_threads", lambda: 2)
1919
monkeypatch.setattr(torch, "set_num_threads", raise_err)
2020

21-
caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env")
21+
messages: list[str] = []
22+
original_warning = logging.Logger.warning
23+
24+
def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def]
25+
messages.append(str(msg))
26+
return original_warning(self, msg, *args, **kwargs)
27+
28+
monkeypatch.setattr(logging.Logger, "warning", capture_warning)
2229
import deepmd.pt.utils.env as env
2330

2431
importlib.reload(env)
2532

26-
messages = [record.getMessage() for record in caplog.records]
2733
assert any("Could not set torch interop threads" in msg for msg in messages)
2834
assert any("Could not set torch intra threads" in msg for msg in messages)

source/tests/pt_expt/utils/test_env.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import deepmd.env as common_env
88

99

10-
def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None:
10+
def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None:
1111
def raise_err(*_args, **_kwargs) -> None:
1212
raise RuntimeError("boom")
1313

@@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None:
1818
monkeypatch.setattr(torch, "get_num_threads", lambda: 2)
1919
monkeypatch.setattr(torch, "set_num_threads", raise_err)
2020

21-
caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env")
21+
messages: list[str] = []
22+
original_warning = logging.Logger.warning
23+
24+
def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def]
25+
messages.append(str(msg))
26+
return original_warning(self, msg, *args, **kwargs)
27+
28+
monkeypatch.setattr(logging.Logger, "warning", capture_warning)
2229
import deepmd.pt_expt.utils.env as env
2330

2431
importlib.reload(env)
2532

26-
messages = [record.getMessage() for record in caplog.records]
2733
assert any("Could not set torch interop threads" in msg for msg in messages)
2834
assert any("Could not set torch intra threads" in msg for msg in messages)

0 commit comments

Comments
 (0)