Skip to content

Commit 8aa6a9e

Browse files
authored
Added explicit pathwaysutils.initialize() call to JetStream (#233)
1 parent 838f614 commit 8aa6a9e

2 files changed

Lines changed: 81 additions & 4 deletions

File tree

jetstream/engine/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515
"""Initialization for any Engine implementation."""
1616

1717
import jax
18+
from importlib import util
1819

19-
try:
20+
if util.find_spec("pathwaysutils"):
2021
import pathwaysutils
21-
except ImportError as e:
22-
print("Proxy backend support is not added")
23-
pass
22+
23+
pathwaysutils.initialize()
24+
else:
25+
print(
26+
"Running JetStream without Pathways. "
27+
"Module pathwaysutils is not imported."
28+
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for initializing jetstream.engine module."""
16+
17+
import contextlib
18+
import io
19+
import unittest
20+
from unittest import mock
21+
22+
import importlib
23+
import jetstream.engine
24+
25+
26+
class InitTest(unittest.TestCase):
27+
28+
def test_init_with_error(self):
29+
def mock_find_spec(name):
30+
if name == "pathwaysutils":
31+
return None
32+
return "some_spec"
33+
34+
with mock.patch(
35+
"importlib.util.find_spec", side_effect=mock_find_spec
36+
), contextlib.redirect_stdout(io.StringIO()) as captured_output:
37+
38+
importlib.reload(jetstream.engine)
39+
40+
self.assertIn(
41+
"Running JetStream without Pathways.", captured_output.getvalue()
42+
)
43+
44+
def test_init(self):
45+
orig_import = __import__
46+
p_mock = mock.Mock()
47+
48+
def mock_import(name, *args):
49+
50+
if name == "pathwaysutils":
51+
return p_mock
52+
return orig_import(name, *args)
53+
54+
def mock_find_spec(name):
55+
if name == "pathwaysutils":
56+
return "pathwaysutils_spec"
57+
return "some_spec"
58+
59+
with mock.patch(
60+
"importlib.util.find_spec", side_effect=mock_find_spec
61+
), mock.patch(
62+
"builtins.__import__", side_effect=mock_import
63+
), contextlib.redirect_stdout(
64+
io.StringIO()
65+
) as captured_output:
66+
importlib.reload(jetstream.engine)
67+
68+
p_mock.initialize.assert_called_once()
69+
70+
self.assertNotIn(
71+
"Running JetStream without Pathways.", captured_output.getvalue()
72+
)

0 commit comments

Comments
 (0)