Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions pathwaysutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def initialize() -> None:
# debug log is not triggered for customers who are following our instructions
# and using the new initialize() function only once but have already had the
# legacy initialization triggered.
if _initialization_count > 2:
_logger.debug("Already initialized. Ignoring duplicate call.")

if _initialization_count > 1:
_logger.debug("Already initialized. Ignoring duplicate call.")
return

_logger.debug("Starting initialize.")

if is_pathways_backend_used():
_logger.debug("Detected Pathways-on-Cloud backend. Applying changes.")
proxy_backend.register_backend_factory()
Expand All @@ -112,6 +112,3 @@ def initialize() -> None:
_logger.debug(
"Did not detect Pathways-on-Cloud backend. No changes applied."
)


initialize()
17 changes: 7 additions & 10 deletions pathwaysutils/test/pathwaysutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,21 @@ def test_first_initialize(self):
with self.assertLogs(pathwaysutils._logger, level="DEBUG") as logs:
pathwaysutils.initialize()

self.assertLen(logs.output, 1)
self.assertLen(logs.output, 2)
self.assertIn(
"Detected Pathways-on-Cloud backend. Applying changes.", logs.output[0]
"Starting initialize.", logs.output[0]
)
self.assertIn(
"Detected Pathways-on-Cloud backend. Applying changes.", logs.output[1]
)

def test_second_initialize(self):
jax.config.update("jax_platforms", "proxy")
pathwaysutils._initialization_count = 1

with self.assertNoLogs(pathwaysutils._logger, level="DEBUG"):
pathwaysutils.initialize()

@parameterized.named_parameters(
("initialization_count 1", 1),
("initialization_count 2", 2),
("initialization_count 5", 5),
("initialization_count 1000", 1000),
)
def test_initialize_more_than_twice(self, initialization_count):
def test_initialize_more_than_once(self, initialization_count):
pathwaysutils._initialization_count = initialization_count

with self.assertLogs(pathwaysutils._logger, level="DEBUG") as logs:
Expand Down