Skip to content

Commit 913b302

Browse files
committed
Incorporate internal-review comment
Signed-off-by: refai06 <refai.ahamed06@gmail.com>
1 parent ac18d33 commit 913b302

2 files changed

Lines changed: 39 additions & 27 deletions

File tree

openfl/experimental/workflow/notebooktools/code_analyzer.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99
from importlib import import_module
1010
from pathlib import Path
11-
from typing import Any, Dict, List, Optional
11+
from typing import Any, Dict, List, Optional, Tuple
1212

1313
import nbformat
1414
from nbdev.export import nb_export
@@ -117,41 +117,23 @@ def __extract_user_defined_imports(self, notebook_path) -> List[str]:
117117

118118
return list(user_imports)
119119

120-
def _is_user_defined_module(self, module_name: str, notebook_path: Path) -> bool:
121-
"""
122-
Check if a module is user-defined
123-
124-
Args:
125-
notebook_path: Path to Jupyter notebook.
126-
"""
127-
notebook_dir = notebook_path.parent
128-
module_path = notebook_dir / f"{module_name}.py"
129-
130-
module_dir = notebook_dir / module_name
131-
132-
if (module_path.exists() and module_path.is_file()) or module_dir.exists():
133-
return True
134-
135-
return False
136-
137120
def __copy_user_defined_modules(self, module_names: List[str], notebook_path: Path) -> None:
138121
"""
139122
Copies user-defined modules/packages to the workspace's src directory
140123
141124
Args:
142-
module_name: List of module name to copy.
125+
module_names: List of module names.
143126
notebook_path: Path to Jupyter notebook.
144127
"""
145128
src_dir = self.script_path.parent
146129
for module_name in module_names:
147-
module_file = notebook_path.parent / f"{module_name}.py"
148-
module_dir = notebook_path.parent / module_name
149-
if module_file.exists() and module_file.is_file():
150-
shutil.copy(module_file, src_dir)
151-
print(f"Copied used-defined module: {module_name}.py")
130+
module_path, module_dir = self._get_module_paths(module_name, notebook_path)
131+
if module_path.exists() and module_path.is_file():
132+
shutil.copy(module_path, src_dir)
133+
print(f"Copied user-defined module: {module_name}.py")
152134
elif module_dir.exists() and module_dir.is_dir():
153135
shutil.copytree(module_dir, src_dir / module_name, dirs_exist_ok=True)
154-
print(f"Copied used-defined directory: {module_name}/")
136+
print(f"Copied user-defined directory: {module_name}/")
155137

156138
def __modify_experiment_script(self) -> None:
157139
"""Modifies the given python script by commenting out following code:
@@ -361,6 +343,34 @@ def _clean_value(self, value: str) -> str:
361343
value = value.lstrip("[").rstrip("]")
362344
return value
363345

346+
def _is_user_defined_module(self, module_name: str, notebook_path: Path) -> bool:
347+
"""
348+
Check if a module is user-defined
349+
350+
Args:
351+
module_name: Name of the module.
352+
notebook_path: Path to Jupyter notebook.
353+
"""
354+
if not isinstance(module_name, str) or not module_name.strip():
355+
return False
356+
357+
module_path, module_dir = self._get_module_paths(module_name, notebook_path)
358+
359+
return (module_path.exists() and module_path.is_file()) or module_dir.exists()
360+
361+
def _get_module_paths(self, module_name: str, notebook_path: Path) -> Tuple:
362+
"""
363+
Get the file and directory paths for a user-defined module
364+
365+
Args:
366+
module_name: Name of the module.
367+
notebook_path: Path to the Jupyter notebook.
368+
"""
369+
notebook_dir = notebook_path.parent
370+
module_path = notebook_dir / f"{module_name}.py"
371+
module_dir = notebook_dir / module_name
372+
return module_path, module_dir
373+
364374
def _get_requirements(self) -> List[str]:
365375
"""Extract pip libraries from the script
366376

openfl/utilities/workspace.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,10 @@ def __enter__(self):
105105
os.chdir(self.experiment_work_dir)
106106

107107
# This is needed for python module finder
108-
sys.path.append(str(self.experiment_work_dir))
109-
sys.path.append(str(self.experiment_work_dir / "src"))
108+
for path in [self.experiment_work_dir, self.experiment_work_dir / "src"]:
109+
path_str = str(path)
110+
if path_str not in sys.path:
111+
sys.path.append(path_str)
110112

111113
def __exit__(self, exc_type, exc_value, traceback):
112114
"""Remove the workspace."""

0 commit comments

Comments
 (0)