Skip to content

Commit ffc5824

Browse files
lukebaumanncopybara-github
authored andcommitted
Updated some docstrings in pathwaysutils/__init__.py, removed a warning, updated tests, and changed warnings to debug logs.
PiperOrigin-RevId: 746481798
1 parent 7bcd91d commit ffc5824

2 files changed

Lines changed: 82 additions & 40 deletions

File tree

pathwaysutils/__init__.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import datetime
1717
import logging
1818
import os
19-
import warnings
2019

2120
import jax
2221
from pathwaysutils import cloud_logging
@@ -33,11 +32,29 @@
3332

3433
# This is a brittle implementation since the platforms value is not necessarily
3534
# which backend is ultimately selected
36-
def _is_pathways_used():
35+
def is_pathways_backend_used() -> bool:
36+
"""Returns whether Pathways backend is used.
37+
38+
This function checks the JAX platforms configuration to determine whether
39+
Pathways is used. If the platforms configuration contains the string "proxy",
40+
Pathways is used. This is a brittle implementation since the platforms value
41+
is not necessarily which backend is ultimately selected or there may be more
42+
than one platform specified and another may have higher priority.
43+
"""
3744
return jax.config.jax_platforms and "proxy" in jax.config.jax_platforms
3845

3946

40-
def _is_persistence_enabled():
47+
def _is_persistence_enabled() -> bool:
48+
"""Returns whether persistence is enabled.
49+
50+
This function checks the environment variable ENABLE_PATHWAYS_PERSISTENCE to
51+
determine whether persistence is enabled. If the variable is set to "1",
52+
persistence is enabled. If the variable is set to "0" or unset, persistence is
53+
disabled.
54+
55+
Returns:
56+
True if persistence is enabled, False otherwise.
57+
"""
4158
if "ENABLE_PATHWAYS_PERSISTENCE" in os.environ:
4259
if os.environ["ENABLE_PATHWAYS_PERSISTENCE"] == "1":
4360
return True
@@ -51,31 +68,29 @@ def _is_persistence_enabled():
5168
return False
5269

5370

54-
def initialize():
55-
"""Initializes pathwaysutils."""
71+
def initialize() -> None:
72+
"""Initializes pathwaysutils.
73+
74+
This function is called by the user to initialize pathwaysutils. It is
75+
responsible for setting up the logging, profiling, and persistence handlers
76+
through various monkey patching functions. It is also responsible for
77+
registering the proxy backend factory.
78+
"""
5679
global _initialization_count
5780
_initialization_count += 1
5881

59-
if _initialization_count == 1:
60-
warnings.warn(
61-
"pathwaysutils: Legacy initialization. Ensure you also call"
62-
" pathwaysutils.initialize(). This warning will be removed in a future"
63-
" release."
64-
)
65-
66-
# Ignoring the second call to initialize() is a temporary measure so that this warning is not triggered for customers who are following our instructions and using the new initialize() function only once but have already had the legacy initialization triggered.
82+
# Ignoring the second call to initialize() is a temporary measure so that this
83+
# debug log is not triggered for customers who are following our instructions
84+
# and using the new initialize() function only once but have already had the
85+
# legacy initialization triggered.
6786
if _initialization_count > 2:
68-
warnings.warn(
69-
"pathwaysutils: Already initialized. Ignoring duplicate call."
70-
)
87+
_logger.debug("Already initialized. Ignoring duplicate call.")
7188

7289
if _initialization_count > 1:
7390
return
7491

75-
if _is_pathways_used():
76-
_logger.debug(
77-
"pathwaysutils: Detected Pathways-on-Cloud backend. Applying changes."
78-
)
92+
if is_pathways_backend_used():
93+
_logger.debug("Detected Pathways-on-Cloud backend. Applying changes.")
7994
proxy_backend.register_backend_factory()
8095
profiling.monkey_patch_jax()
8196
# TODO: b/365549911 - Remove when OCDBT-compatible
@@ -90,14 +105,13 @@ def initialize():
90105
cloud_logging.setup()
91106
except Exception as error: # pylint: disable=broad-except
92107
_logger.debug(
93-
"pathwaysutils: Failed to set up cloud logging due to the following"
94-
" error: %s",
108+
"Failed to set up cloud logging due to the following error: %s",
95109
error,
96110
)
97111
else:
98112
_logger.debug(
99-
"pathwaysutils: Did not detect Pathways-on-Cloud backend. No changes"
100-
" applied."
113+
"Did not detect Pathways-on-Cloud backend. No changes applied."
101114
)
102115

116+
103117
initialize()

pathwaysutils/test/pathwaysutils_test.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
import os
1616
from unittest import mock
17-
import warnings
1817

18+
import google.cloud.logging
1919
import jax
2020
import pathwaysutils
2121
from pathwaysutils import cloud_logging
@@ -32,37 +32,65 @@ def setUp(self):
3232
cloud_logging, "setup", autospec=True
3333
)
3434

35-
def test_legacy_initialize(self):
35+
def test_first_initialize(self):
36+
jax.config.update("jax_platforms", "proxy")
3637
pathwaysutils._initialization_count = 0
3738

38-
with self.assertWarns(UserWarning, msg="Legacy initialization"):
39+
self.enter_context(
40+
mock.patch.object(google.cloud.logging, "Client", autospec=True)
41+
)
42+
43+
with self.assertLogs(pathwaysutils._logger, level="DEBUG") as logs:
3944
pathwaysutils.initialize()
4045

41-
def test_legacy_and_new_initialize(self):
46+
self.assertLen(logs.output, 1)
47+
self.assertIn(
48+
"Detected Pathways-on-Cloud backend. Applying changes.", logs.output[0]
49+
)
50+
51+
def test_second_initialize(self):
52+
jax.config.update("jax_platforms", "proxy")
4253
pathwaysutils._initialization_count = 1
4354

44-
with warnings.catch_warnings(record=True) as caught_warnings:
55+
with self.assertNoLogs(pathwaysutils._logger, level="DEBUG"):
4556
pathwaysutils.initialize()
4657

47-
self.assertEmpty(caught_warnings)
48-
4958
@parameterized.named_parameters(
5059
("initialization_count 2", 2),
5160
("initialization_count 5", 5),
61+
("initialization_count 1000", 1000),
5262
)
53-
def test_initialize_more_than_once(self, initialization_count):
63+
def test_initialize_more_than_twice(self, initialization_count):
5464
pathwaysutils._initialization_count = initialization_count
5565

56-
with self.assertWarns(UserWarning, msg="Already initialized"):
66+
with self.assertLogs(pathwaysutils._logger, level="DEBUG") as logs:
5767
pathwaysutils.initialize()
5868

59-
def test_is_pathways_used(self):
60-
for platform in ["", "cpu", "tpu", "gpu", "cpu,tpu,gpu"]:
61-
jax.config.update("jax_platforms", platform)
62-
self.assertFalse(pathwaysutils._is_pathways_used())
63-
for platform in ["proxy", "proxy,cpu", "cpu,proxy", "tpu,cpu,proxy,gpu"]:
64-
jax.config.update("jax_platforms", platform)
65-
self.assertTrue(pathwaysutils._is_pathways_used())
69+
self.assertLen(logs.output, 1)
70+
self.assertIn(
71+
"Already initialized. Ignoring duplicate call.", logs.output[0]
72+
)
73+
74+
@parameterized.named_parameters(
75+
("empty", ""),
76+
("cpu", "cpu"),
77+
("tpu", "tpu"),
78+
("gpu", "gpu"),
79+
("cpu,tpu,gpu", "cpu,tpu,gpu"),
80+
)
81+
def test_not_is_pathways_backend_used(self, platform: str):
82+
jax.config.update("jax_platforms", platform)
83+
self.assertFalse(pathwaysutils.is_pathways_backend_used())
84+
85+
@parameterized.named_parameters(
86+
("proxy", "proxy"),
87+
("proxy,cpu", "proxy,cpu"),
88+
("cpu,proxy", "cpu,proxy"),
89+
("tpu,cpu,proxy,gpu", "tpu,cpu,proxy,gpu"),
90+
)
91+
def test_is_pathways_backend_used(self, platform: str):
92+
jax.config.update("jax_platforms", platform)
93+
self.assertTrue(pathwaysutils.is_pathways_backend_used())
6694

6795
def test_persistence_enabled(self):
6896
os.environ["ENABLE_PATHWAYS_PERSISTENCE"] = "1"

0 commit comments

Comments
 (0)