Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
PR_CHANGES_OPTION = "--changed-samples-only-from"


def is_plugin_active(config: pytest.Config) -> bool:
"""Return whether any of the plugin provided options were provided on commandline."""
return get_diff_paths_function(config) is not None


def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
WORKING_TREE_CHANGES_OPTION,
Expand Down Expand Up @@ -42,10 +47,18 @@ def pytest_configure(config: pytest.Config) -> None:

@pytest.hookimpl(hookwrapper=True)
def pytest_collection(session: pytest.Session) -> None:
"""Set up path filtering based on git diff."""
config = session.config
diff_path_trie = Trie()

for p in get_diff_paths_function(config)():
paths_filter = get_diff_paths_function(config)

if paths_filter is None:
# Exit early if there's no path filter
yield
return

for p in paths_filter():
diff_path_trie.insert(p.parts)

config.stash[DIFF_PATH_TRIE_KEY] = diff_path_trie
Expand All @@ -56,8 +69,9 @@ def pytest_collection(session: pytest.Session) -> None:


def pytest_ignore_collect(collection_path: Path, config: pytest.Config) -> Optional[bool]:
"""Ignore paths that were not touched by the current git diff."""
if DIFF_PATH_TRIE_KEY not in config.stash:
# Occures when calling `pytest --fixtures`
# Occurs when calling `pytest --fixtures`
return None

diff_path_trie = config.stash[DIFF_PATH_TRIE_KEY]
Expand All @@ -72,7 +86,16 @@ def pytest_ignore_collect(collection_path: Path, config: pytest.Config) -> Optio
return (not diff_path_trie.is_prefix(ignore_dir.resolve().parts)) or None


def get_diff_paths_function(config: pytest.Config) -> Callable[[], Iterable[Path]]:
@pytest.hookimpl(trylast=True)
def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
if not is_plugin_active(session.config):
return

if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
session.exitstatus = pytest.ExitCode.OK


def get_diff_paths_function(config: pytest.Config) -> Optional[Callable[[], Iterable[Path]]]:
"""Get the function that returns paths present in a diff specfied by cmdline arguments

:param pytest.Config config: The pytest config
Expand All @@ -87,7 +110,7 @@ def get_diff_paths_function(config: pytest.Config) -> Callable[[], Iterable[Path
if ref := config.getoption(opt_var(PR_CHANGES_OPTION)):
return lambda: get_branch_diff_paths(ref)

return lambda: ()
return None


def opt_var(s: str) -> str:
Expand Down