Skip to content

Commit 5ab8d27

Browse files
daiyippyglove authors
authored andcommitted
Release pyglove/dev for supporting module reloading.
PiperOrigin-RevId: 918567274
1 parent 4387242 commit 5ab8d27

6 files changed

Lines changed: 398 additions & 4 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ jobs:
1414
runs-on: "${{ matrix.os }}"
1515
strategy:
1616
matrix:
17-
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
17+
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
1818
os: [ubuntu-latest]
1919
steps:
20-
- uses: actions/checkout@v2
20+
- uses: actions/checkout@v4
2121
- name: Set up Python ${{ matrix.python-version }}
22-
uses: actions/setup-python@v1
22+
uses: actions/setup-python@v5
2323
with:
2424
python-version: ${{ matrix.python-version }}
2525
- name: Install dependencies
@@ -30,7 +30,8 @@ jobs:
3030
pip install -r requirements.txt
3131
- name: Test with pytest and generate coverage report
3232
run: |
33-
pytest -n auto --cov=pyglove --cov-report=xml
33+
pytest -n auto --ignore=pyglove/dev/reloader_test.py --cov=pyglove --cov-report=xml
34+
pytest pyglove/dev/reloader_test.py --cov=pyglove --cov-append --cov-report=xml
3435
- name: Upload coverage to Codecov
3536
uses: codecov/codecov-action@v1
3637
with:

pyglove/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
# Placeholder for Google-internal imports.
3434

35+
import pyglove.dev
36+
3537
# pylint: enable=g-import-not-at-top
3638
# pylint: enable=reimported
3739
# pylint: enable=unused-import

pyglove/dev/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2023 The PyGlove Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""PyGlove dev tools."""
15+
16+
# pylint: disable=g-importing-member
17+
18+
from pyglove.dev.reloader import adhoc_import
19+
from pyglove.dev.reloader import reload
20+
21+
from pyglove.dev.unittest import enable_test
22+
from pyglove.dev.unittest import run_tests
23+
24+
# pylint: enable=g-importing-member

pyglove/dev/reloader.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Copyright 2023 The PyGlove Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Utilities for reloading modules."""
15+
16+
import contextlib
17+
import getpass
18+
import importlib
19+
import inspect
20+
import re
21+
import sys
22+
import time
23+
import types
24+
from typing import Callable, List, Optional, Sequence, Union
25+
26+
27+
def reload(
28+
module: Union[
29+
types.ModuleType, # Module
30+
str, # Module name.
31+
Sequence[Union[types.ModuleType, str]], # List of module/module names.
32+
None
33+
] = None, # pylint: disable=bad-whitespace
34+
*,
35+
workspace: Optional[str] = None,
36+
user: Optional[str] = None,
37+
cl: Optional[int] = None,
38+
reset_flags: bool = True,
39+
reload_pattern: str = 'pyglove.*',
40+
behavior: Optional[str] = 'preferred',
41+
verbose: bool = False,
42+
) -> Union[types.ModuleType, List[types.ModuleType]]:
43+
"""Reloads a module with refreshing its sub-modules based on filter.
44+
45+
Args:
46+
module: The root module(s) to reload. If None, module `pyglove` will be
47+
reloaded.
48+
workspace: Cider-V workspace to sync code from. If None, use a specific
49+
CL when `cl` is specified, or sync code from HEAD.
50+
user: The user LDAP. If None, the current user will be used.
51+
cl: A Change Number to sync code from. If None, refer to `workspace`.
52+
reset_flags: If True, removes all the flags in the module that is being
53+
reloaded. This is to avoid flags being defined twice when reloading.
54+
reload_pattern: An optional regular expression to whitelist the dependent
55+
module names that need to be reloaded. If None, it will reload all the
56+
dependent modules of `module`.
57+
behavior: The adhoc_import behavior string. Among 'preferred' or None (
58+
'fallback').
59+
verbose: If True, print the reloaded sub-modules.
60+
61+
Returns:
62+
The reloaded module(s).
63+
"""
64+
reload_multiple = isinstance(module, (list, tuple))
65+
66+
if module is None:
67+
module = sys.modules['pyglove']
68+
69+
modules = list(module) if isinstance(module, (list, tuple)) else [module]
70+
71+
regex = re.compile(reload_pattern)
72+
filter_fn = lambda m: regex.match(m.__name__)
73+
74+
import_lib = adhoc_import_lib()
75+
76+
def _reload(m: types.ModuleType):
77+
try:
78+
setattr(m, '__reloading__', True)
79+
if import_lib is None:
80+
return importlib.reload(m)
81+
else:
82+
return import_lib.Reload(m, reset_flags=reset_flags)
83+
finally:
84+
delattr(m, '__reloading__')
85+
86+
start_time = time.time()
87+
with adhoc_import(workspace, user, cl=cl, behavior=behavior):
88+
# Step 1: Load module from names.
89+
for i, m in enumerate(modules):
90+
if isinstance(m, str):
91+
if verbose:
92+
print(f'Loading [{m}]...')
93+
modules[i] = importlib.import_module(m)
94+
95+
# Step 2: Compute and reload dependencies.
96+
for m in module_dependencies(modules, transitive=True, filter=filter_fn):
97+
if verbose:
98+
print(f'Reloading [{m.__name__}]...')
99+
_ = _reload(m)
100+
101+
# Reload the root modules.
102+
reloaded_modules = []
103+
for m in modules:
104+
if verbose:
105+
print(f'Reloading [{m.__name__}]...')
106+
reloaded_modules.append(_reload(m))
107+
108+
elapse = time.time() - start_time
109+
print(f'Sync completed in {elapse:.2f} seconds.')
110+
return reloaded_modules if reload_multiple else reloaded_modules[0]
111+
112+
113+
_BUILTIN_MODULE_NAMES = frozenset(sys.builtin_module_names)
114+
115+
116+
def module_dependencies(
117+
module: Union[types.ModuleType, Sequence[types.ModuleType]],
118+
transitive: bool = False,
119+
filter: Optional[Callable[[types.ModuleType], bool]] = None # pylint: disable=redefined-builtin
120+
) -> List[types.ModuleType]:
121+
"""Returns a list of module dependencies for a given module."""
122+
if transitive and not filter:
123+
raise ValueError(
124+
'`filter` must be provided when `transitive` is set to True.')
125+
126+
filter = filter or (lambda m: True)
127+
128+
dependencies = []
129+
seen = set()
130+
max_depth = None if transitive else 1
131+
132+
def _visit(m: types.ModuleType, depth: int) -> None:
133+
if max_depth is not None and depth >= max_depth:
134+
return
135+
136+
if not hasattr(m, '__file__'):
137+
return
138+
139+
try:
140+
lines = inspect.getsource(m).split('\n')
141+
except OSError:
142+
return
143+
144+
for line in lines:
145+
symbols = _imported_symbols(line)
146+
147+
for symbol in symbols:
148+
dependency = _dependent_module(symbol)
149+
if not dependency or not filter(dependency):
150+
continue
151+
152+
if dependency not in seen:
153+
seen.add(dependency)
154+
_visit(dependency, depth + 1)
155+
dependencies.append(dependency)
156+
157+
if not isinstance(module, (list, tuple)):
158+
module = [module]
159+
160+
for m in module:
161+
_visit(m, 0)
162+
return dependencies
163+
164+
165+
_IMPORT_REGEX = re.compile('^import (.*)')
166+
_FROM_IMPORT_REGEX = re.compile('^from (.*) import (.*)')
167+
168+
169+
def _imported_symbols(import_statement: str) -> List[str]:
170+
"""Gets the fully qualified names of the imported symbols."""
171+
m = _FROM_IMPORT_REGEX.match(import_statement)
172+
if m:
173+
parent_module = m.group(1).strip()
174+
symbol_names = (
175+
m.group(2).split(' as ')[0] # Remove 'as' sub-statements.
176+
.split('#')[0] # Remove comments.
177+
.split(','))
178+
return [
179+
f'{parent_module}.{symbol_name.strip()}'
180+
for symbol_name in symbol_names
181+
]
182+
183+
m = _IMPORT_REGEX.match(import_statement)
184+
if m:
185+
symbol_name = (
186+
m.group(1).split(' as ')[0] # Remove 'as' sub-statements.
187+
.split('#')[0] # Remove comments.
188+
.split(','))
189+
return [n.strip() for n in symbol_name]
190+
return []
191+
192+
193+
def _dependent_module(symbol_name: str):
194+
"""Gets the immediate module for a fully qualified symbol name."""
195+
if symbol_name.startswith(('_', '.')):
196+
return None
197+
198+
module = sys.modules.get(symbol_name)
199+
if module is None:
200+
module_name = symbol_name[:symbol_name.rindex('.')]
201+
if module_name.endswith('_pb2'):
202+
return None
203+
module = sys.modules.get(module_name)
204+
if (module is not None
205+
and (module.__name__ in _BUILTIN_MODULE_NAMES
206+
or not hasattr(module, '__file__')
207+
or module.__name__.endswith('_pb2'))):
208+
return None
209+
return module
210+
211+
212+
def adhoc_import(
213+
workspace: Optional[str],
214+
user: Optional[str] = None,
215+
cl: Optional[int] = None,
216+
behavior: Optional[str] = 'preferred'):
217+
"""Returns a context manager for importing libraries."""
218+
import_lib = adhoc_import_lib()
219+
if import_lib is None:
220+
return contextlib.nullcontext()
221+
# Placeholder for Google-internal adhoc import logic.
222+
223+
224+
def adhoc_import_lib():
225+
try:
226+
_ = get_ipython() # pytype: disable=name-error
227+
# pytype: disable=import-error
228+
from colabtools import adhoc_import as import_lib # pylint: disable=g-import-not-at-top
229+
# pytype: enable=import-error
230+
return import_lib
231+
except NameError:
232+
return None

pyglove/dev/reloader_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 The PyGlove Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import sys
15+
import unittest
16+
import pyglove.core as pg
17+
from pyglove.dev import reloader
18+
19+
20+
class ReloaderTest(unittest.TestCase):
21+
22+
def test_module_dependencies(self):
23+
dependencies = reloader.module_dependencies(pg)
24+
self.assertEqual(
25+
dependencies,
26+
[
27+
sys.modules['pyglove.core.symbolic'],
28+
sys.modules['pyglove.core.typing'],
29+
sys.modules['pyglove.core.geno'],
30+
sys.modules['pyglove.core.hyper'],
31+
sys.modules['pyglove.core.tuning'],
32+
sys.modules['pyglove.core.detouring'],
33+
sys.modules['pyglove.core.patching'],
34+
sys.modules['pyglove.core.utils'],
35+
sys.modules['pyglove.core.views'],
36+
sys.modules['pyglove.core.views.html.controls'],
37+
sys.modules['pyglove.core.io'],
38+
sys.modules['pyglove.core.coding'],
39+
sys.modules['pyglove.core.logging'],
40+
sys.modules['pyglove.core.monitoring'],
41+
],
42+
)
43+
44+
def test_module_dependencies_transitive(self):
45+
dependencies = reloader.module_dependencies(
46+
pg.symbolic,
47+
transitive=True,
48+
filter=lambda m: m.__name__.startswith('pyglove.core.symbolic'))
49+
50+
def index(module_name):
51+
return dependencies.index(
52+
sys.modules['pyglove.core.symbolic.' + module_name])
53+
54+
self.assertLess(index('base'), index('list'))
55+
self.assertLess(index('origin'), index('base'))
56+
self.assertLess(index('pure_symbolic'), index('base'))
57+
self.assertLess(index('object'), index('class_wrapper'))
58+
59+
def test_module_dependencies_transitive_multiple(self):
60+
dependencies = reloader.module_dependencies(
61+
(pg.symbolic, pg.symbolic.origin),
62+
transitive=True,
63+
filter=lambda m: m.__name__.startswith('pyglove.core.symbolic'))
64+
65+
def index(module_name):
66+
return dependencies.index(
67+
sys.modules['pyglove.core.symbolic.' + module_name])
68+
69+
self.assertLess(index('base'), index('list'))
70+
self.assertLess(index('origin'), index('base'))
71+
self.assertLess(index('pure_symbolic'), index('base'))
72+
self.assertLess(index('object'), index('class_wrapper'))
73+
74+
def test_reload(self):
75+
_ = reloader.reload(pg)
76+
77+
78+
if __name__ == '__main__':
79+
unittest.main()

0 commit comments

Comments
 (0)