Skip to content
38 changes: 23 additions & 15 deletions plugboard/cli/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import importlib
import inspect
import os
from pathlib import Path
import typing as _t

Expand Down Expand Up @@ -43,20 +42,29 @@ async def _post_to_api(url: str, data: dict) -> None:
def _import_recursive(path: Path, base_package: _t.Optional[str] = None) -> None:
"""Import all modules recursively from the given path."""
logger = DI.logger.resolve_sync()
for root, _dirs, files in os.walk(path):
for file in files:
if file.endswith(".py") and not file.startswith("__"):
# Construct module name
rel_path = os.path.relpath(os.path.join(root, file), path)
module_name = rel_path.replace(os.sep, ".")[:-3]

if base_package:
module_name = f"{base_package}.{module_name}"

try:
importlib.import_module(module_name)
except (ModuleNotFoundError, ImportError, SyntaxError) as e:
logger.warning(f"Failed to import module {module_name}: {e}")

def _walk(current_path: Path) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually keep the inplace update to dirs as it is less code @copilot

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in 58f2032 to restore the os.walk + dirs[:] pruning approach, while keeping the module-name handling a bit cleaner with Path.relative_to(...).with_suffix("").

for child in current_path.iterdir():
if child.name.startswith("."):
continue

if child.is_dir():
_walk(child)
continue

if child.suffix != ".py" or child.name.startswith("__"):
continue

module_name = ".".join(child.relative_to(path).with_suffix("").parts)
if base_package:
module_name = f"{base_package}.{module_name}"

try:
importlib.import_module(module_name)
except (ModuleNotFoundError, ImportError, SyntaxError) as e:
logger.warning(f"Failed to import module {module_name}: {e}")

_walk(path)


def _get_all_subclasses(cls: type) -> set:
Expand Down
73 changes: 66 additions & 7 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
marked async so that they do not interfere with pytest-asyncio's event loop.
"""

import json
from pathlib import Path
import tempfile
import textwrap
import typing as _t
from unittest.mock import AsyncMock, MagicMock, patch

Expand All @@ -20,6 +22,41 @@
runner = CliRunner()


def _create_test_project(
base_path: Path,
*,
as_package: bool = True,
include_hidden_dir: bool = False,
) -> Path:
"""Create a minimal Python project for CLI discovery tests."""
project_dir = base_path / "test_project"
project_dir.mkdir()

if as_package:
(project_dir / "__init__.py").write_text("")
(project_dir / "test_file.py").write_text("")
else:
(project_dir / "test_file.py").write_text(
textwrap.dedent("""
from plugboard.component import Component, IOController as IO


class VisibleComponent(Component):
io = IO(outputs=["out"])

async def step(self) -> None:
self.out = 1
""").strip()
)

if include_hidden_dir:
hidden_dir = project_dir / ".venv"
hidden_dir.mkdir()
(hidden_dir / "bad_module.py").write_text('raise RuntimeError("should not import")')

return project_dir


def test_cli_version() -> None:
"""Tests the version command."""
result = runner.invoke(app, ["version"])
Expand All @@ -35,11 +72,7 @@ def test_cli_version() -> None:
def test_project_dir() -> _t.Iterator[Path]:
"""Create a minimal Python package for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
project_dir = Path(tmpdir) / "test_project"
project_dir.mkdir()
(project_dir / "__init__.py").write_text("")
(project_dir / "test_file.py").write_text("")
yield project_dir
yield _create_test_project(Path(tmpdir))


@pytest.mark.asyncio
Expand Down Expand Up @@ -191,8 +224,28 @@ def test_cli_ai_agents_template_is_packaged_file() -> None:
assert not _AGENTS_MD.is_symlink()


def test_cli_server_discover(test_project_dir: Path) -> None:
@pytest.mark.parametrize(
("as_package", "include_hidden_dir", "expected_component_name"),
[
(True, False, None),
(True, True, None),
(False, False, "VisibleComponent"),
(False, True, "VisibleComponent"),
],
)
def test_cli_server_discover(
tmp_path: Path,
as_package: bool,
include_hidden_dir: bool,
expected_component_name: str | None,
) -> None:
"""Tests the server discover command."""
project_dir = _create_test_project(
tmp_path,
as_package=as_package,
include_hidden_dir=include_hidden_dir,
)

with respx.mock:
# Mock all the API endpoints
component_route = respx.post("http://test:8000/types/component").respond(
Expand All @@ -209,14 +262,15 @@ def test_cli_server_discover(test_project_dir: Path) -> None:
[
"server",
"discover",
str(test_project_dir),
str(project_dir),
"--api-url",
"http://test:8000",
],
)

# CLI must run without error
assert result.exit_code == 0
assert result.exception is None
assert "Discovery complete" in result.stdout

# At minimum, should have discovered plugboard's built-in types
Expand All @@ -225,6 +279,11 @@ def test_cli_server_discover(test_project_dir: Path) -> None:
assert connector_route.called
assert event_route.called
assert process_route.called
if expected_component_name is not None:
assert any(
json.loads(call.request.content)["name"] == expected_component_name
for call in component_route.calls
)


def test_cli_server_discover_with_env_var(test_project_dir: Path) -> None:
Expand Down