diff --git a/.github/workflows/code_test_and_deploy.yml b/.github/workflows/code_test_and_deploy.yml index 52975585f..2fee3d260 100644 --- a/.github/workflows/code_test_and_deploy.yml +++ b/.github/workflows/code_test_and_deploy.yml @@ -32,7 +32,7 @@ jobs: # macos-14 is M1, macos-13 is intel. Run on earliest and # latest python versions. All python versions are tested in # the weekly cron job. - os: [windows-latest, ubuntu-latest, macos-14, macos-13] + os: [ ubuntu-latest, windows-latest, macos-14, macos-13] # Test all Python versions for cron job, and only first/last for other triggers python-version: ${{ fromJson(github.event_name == 'schedule' && '["3.9", "3.10", "3.11", "3.12"]' || '["3.9", "3.12"]') }} @@ -57,8 +57,17 @@ jobs: run: | python -m pip install --upgrade pip pip install .[dev] - - name: Test - run: pytest + # run SSH tests only on Linux because Windows and macOS + # are already run within a virtual container and so cannot + # run Linux containers because nested containerisation is disabled. + - name: Test SSH (Linux only) + if: runner.os == 'Linux' + run: | + sudo service mysql stop # free up port 3306 for ssh tests + pytest tests/tests_transfers/ssh + - name: All Other Tests + run: | + pytest --ignore tests/tests_transfers/ssh build_sdist_wheels: name: Build source distribution diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 381d04f73..542cf7f78 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -9,6 +9,7 @@ from __future__ import annotations +import os from typing import ( TYPE_CHECKING, Dict, @@ -52,6 +53,14 @@ def keys_str_on_file_but_path_in_class() -> list[str]: ] +def get_default_ssh_port() -> int: + """Get the default port used for SSH connections.""" + if "DS_SSH_PORT" in os.environ: + return int(os.environ["DS_SSH_PORT"]) + else: + return 22 + + # ----------------------------------------------------------------------------- # Check Configs # ----------------------------------------------------------------------------- diff --git a/datashuttle/utils/data_transfer.py b/datashuttle/utils/data_transfer.py index 1db712859..4129bb712 100644 --- a/datashuttle/utils/data_transfer.py +++ b/datashuttle/utils/data_transfer.py @@ -164,7 +164,6 @@ def build_a_list_of_all_files_and_folders_to_transfer(self) -> List[str]: self.update_list_with_non_ses_sub_level_folders( extra_folder_names, extra_filenames, sub ) - continue # Datatype (sub and ses level) -------------------------------- diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 2d66a8cc1..c6b57b24e 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -6,6 +6,7 @@ from subprocess import CompletedProcess from typing import Dict, List, Literal +from datashuttle.configs import canonical_configs from datashuttle.configs.config_class import Configs from datashuttle.utils import utils from datashuttle.utils.custom_types import TopLevelFolder @@ -161,7 +162,7 @@ def setup_rclone_config_for_ssh( f"sftp " f"host {cfg['central_host_id']} " f"user {cfg['central_host_username']} " - f"port 22 " + f"port {canonical_configs.get_default_ssh_port()} " f"key_file {ssh_key_path.as_posix()}", pipe_std=True, ) diff --git a/datashuttle/utils/ssh.py b/datashuttle/utils/ssh.py index 587e9416a..3cbf1c492 100644 --- a/datashuttle/utils/ssh.py +++ b/datashuttle/utils/ssh.py @@ -14,6 +14,7 @@ import paramiko +from datashuttle.configs import canonical_configs from datashuttle.utils import utils # ----------------------------------------------------------------------------- @@ -58,6 +59,7 @@ def connect_client_core( else None ), look_for_keys=True, + port=canonical_configs.get_default_ssh_port(), ) @@ -122,7 +124,9 @@ def get_remote_server_key(central_host_id: str): """ transport: paramiko.Transport - with paramiko.Transport(central_host_id) as transport: + with paramiko.Transport( + (central_host_id, canonical_configs.get_default_ssh_port()) + ) as transport: transport.connect() key = transport.get_remote_server_key() return key @@ -148,7 +152,15 @@ def save_hostkey_locally(key, central_host_id, hostkeys_path) -> None: """ client = paramiko.SSHClient() - client.get_host_keys().add(central_host_id, key.get_name(), key) + + port = canonical_configs.get_default_ssh_port() + host_key = f"[{central_host_id}]:{port}" if port != 22 else central_host_id + + client.get_host_keys().add( + host_key, + key.get_name(), + key, + ) client.get_host_keys().save(hostkeys_path.as_posix()) @@ -242,7 +254,7 @@ def connect_client_with_logging( f"Connection to {cfg['central_host_id']} made successfully." ) - except Exception: + except Exception as e: utils.log_and_raise_error( f"Could not connect to server. Ensure that \n" f"1) You have run setup_ssh_connection() \n" @@ -250,7 +262,8 @@ def connect_client_with_logging( f"3) The central_host_id: {cfg['central_host_id']} is" f" correct.\n" f"4) The central username:" - f" {cfg['central_host_username']}, and password are correct.", + f" {cfg['central_host_username']}, and password are correct." + f"Original error: {e}", ConnectionError, ) diff --git a/pyproject.toml b/pyproject.toml index 87af8ef45..054747f59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,8 @@ select = [ "D", # pydocstyle ] per-file-ignores = { "tests/*" = [ - "D" # ignore docstring formatting in tests for now + "D", # ignore docstring formatting in tests for now + "TID252" ], "examples/*" = [ "D400", # first line should end with a period. "D415", # first line should end with a period, question mark... diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/base.py b/tests/base.py similarity index 98% rename from tests/tests_integration/base.py rename to tests/base.py index 5ce5359c1..f0e45588b 100644 --- a/tests/tests_integration/base.py +++ b/tests/base.py @@ -1,7 +1,8 @@ import warnings import pytest -import test_utils + +from . import test_utils TEST_PROJECT_NAME = "test_project" diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 50c27d849..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Test configs, used for setting up SSH tests. - -Before running these tests, it is necessary to setup -an SSH key. This can be done through datashuttle -ssh.setup_ssh_key(project.cfg, log=False). - -Store this path somewhere outside of the test environment, -and it will be copied to the project test folder before testing. - -FILESYSTEM_PATH and SERVER_PATH these must point -to the same folder on the HPC, filesystem, -as a mounted drive and server as the linux path to -connect through SSH -""" - -import platform -from types import SimpleNamespace - -import pytest -import test_utils - -test_ssh = False -username = "jziminski" -central_host_id = "ssh.swc.ucl.ac.uk" -server_path = r"/ceph/neuroinformatics/neuroinformatics/scratch/datashuttle_tests/fake_data" - - -if platform.system() == "Windows": - ssh_key_path = r"C:\Users\Joe\.datashuttle\test_file_conflicts_ssh_key" - filesystem_path = "X:/neuroinformatics/scratch/datashuttle_tests/fake_data" - -else: - ssh_key_path = "/home/joe/test_file_conflicts_ssh_key" - filesystem_path = "/home/joe/ceph_mount/neuroinformatics/scratch/datashuttle_tests/fake_data" - - -def pytest_configure(config): - pytest.ssh_config = SimpleNamespace( - TEST_SSH=test_ssh, - SSH_KEY_PATH=ssh_key_path, - USERNAME=username, - CENTRAL_HOST_ID=central_host_id, - FILESYSTEM_PATH=filesystem_path, # FILESYSTEM_PATH and SERVER_PATH these must point to the same folder on the HPC, filesystem - SERVER_PATH=server_path, # as a mounted drive and server as the linux path to connect through SSH - ) - test_utils.set_datashuttle_loggers(disable=True) diff --git a/tests/ssh_test_utils.py b/tests/ssh_test_utils.py deleted file mode 100644 index a7af1a65c..000000000 --- a/tests/ssh_test_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import builtins -import copy - -from datashuttle.utils import rclone, ssh - - -def setup_project_for_ssh( - project, central_path, central_host_id, central_host_username -): - """Set up the project configs to use SSH connection - to central. - """ - project.update_config_file( - central_path=central_path, - ) - project.update_config_file(central_host_id=central_host_id) - project.update_config_file(central_host_username=central_host_username) - project.update_config_file(connection_method="ssh") - - rclone.setup_rclone_config_for_ssh( - project.cfg, - project.cfg.get_rclone_config_name("ssh"), - project.cfg.ssh_key_path, - ) - - -def setup_mock_input(input_): - """Very similar to pytest monkeypatch but - using that was giving me very strange output, - monkeypatch.setattr('builtins.input', lambda _: "n") - i.e. pdb went deep into some unrelated code stack. - """ - orig_builtin = copy.deepcopy(builtins.input) - builtins.input = lambda _: input_ # type: ignore - return orig_builtin - - -def restore_mock_input(orig_builtin): - """orig_builtin: the copied, original builtins.input.""" - builtins.input = orig_builtin - - -def setup_hostkeys(project): - """Convenience function to verify the server hostkey.""" - orig_builtin = setup_mock_input(input_="y") - ssh.verify_ssh_central_host( - project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True - ) - restore_mock_input(orig_builtin) diff --git a/tests/test_utils.py b/tests/test_utils.py index c70928fe5..48a7617d0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,6 @@ from pathlib import Path import yaml -from file_conflicts_pathtable import get_pathtable from datashuttle import DataShuttle from datashuttle.configs import canonical_configs, canonical_folders @@ -146,18 +145,6 @@ def make_test_path(base_path, local_or_central, test_project_name): return Path(base_path) / local_or_central / test_project_name -def create_all_pathtable_files(pathtable): - for i in range(pathtable.shape[0]): - filepath = pathtable["base_folder"][i] / pathtable["path"][i] - filepath.parents[0].mkdir(parents=True, exist_ok=True) - write_file(filepath, contents="test_entry") - - -def quick_create_project(base_path): - pathtable = get_pathtable(base_path) - create_all_pathtable_files(pathtable) - - # ----------------------------------------------------------------------------- # Test Configs # ----------------------------------------------------------------------------- diff --git a/tests/tests_integration/__init__.py b/tests/tests_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/test_configs.py b/tests/tests_integration/test_configs.py index 5156f84b7..f0a394d00 100644 --- a/tests/tests_integration/test_configs.py +++ b/tests/tests_integration/test_configs.py @@ -1,13 +1,14 @@ import os import pytest -import test_utils -from base import BaseTest from datashuttle import DataShuttle from datashuttle.utils import getters from datashuttle.utils.custom_exceptions import ConfigError +from .. import test_utils +from ..base import BaseTest + class TestConfigs(BaseTest): # Test Errors diff --git a/tests/tests_integration/test_create_folders.py b/tests/tests_integration/test_create_folders.py index cb1de20a3..72102d986 100644 --- a/tests/tests_integration/test_create_folders.py +++ b/tests/tests_integration/test_create_folders.py @@ -5,12 +5,13 @@ from os.path import join import pytest -import test_utils -from base import BaseTest from datashuttle.configs import canonical_configs, canonical_folders from datashuttle.configs.canonical_tags import tags +from .. import test_utils +from ..base import BaseTest + class TestCreateFolders(BaseTest): @pytest.mark.parametrize("project", ["local", "full"], indirect=True) diff --git a/tests/tests_integration/test_datatypes.py b/tests/tests_integration/test_datatypes.py index 688da3e36..8b6be04c3 100644 --- a/tests/tests_integration/test_datatypes.py +++ b/tests/tests_integration/test_datatypes.py @@ -1,11 +1,12 @@ import os import pytest -import test_utils -from base import BaseTest from datashuttle.configs import canonical_configs +from .. import test_utils +from ..base import BaseTest + class TestDatatypes(BaseTest): """Tests for creating folders and transfer (very similar to other tests) diff --git a/tests/tests_integration/test_formatting.py b/tests/tests_integration/test_formatting.py index 0b5db8338..3a8c5e207 100644 --- a/tests/tests_integration/test_formatting.py +++ b/tests/tests_integration/test_formatting.py @@ -1,9 +1,10 @@ import pytest -from base import BaseTest from datashuttle.utils import formatting from datashuttle.utils.custom_exceptions import NeuroBlueprintError +from ..base import BaseTest + class TestFormatting(BaseTest): @pytest.mark.parametrize("prefix", ["sub", "ses"]) diff --git a/tests/tests_integration/test_local_only_mode.py b/tests/tests_integration/test_local_only_mode.py index b2853cbad..c4a3dda34 100644 --- a/tests/tests_integration/test_local_only_mode.py +++ b/tests/tests_integration/test_local_only_mode.py @@ -1,13 +1,14 @@ import shutil import pytest -import test_utils -from base import BaseTest from datashuttle.utils.custom_exceptions import ( ConfigError, ) +from .. import test_utils +from ..base import BaseTest + TEST_PROJECT_NAME = "test_project" diff --git a/tests/tests_integration/test_logging.py b/tests/tests_integration/test_logging.py index 9a3ec8855..9b5bd4a0b 100644 --- a/tests/tests_integration/test_logging.py +++ b/tests/tests_integration/test_logging.py @@ -5,7 +5,6 @@ from pathlib import Path import pytest -import test_utils from datashuttle.configs import canonical_configs from datashuttle.configs.canonical_tags import tags @@ -15,6 +14,8 @@ NeuroBlueprintError, ) +from .. import test_utils + class TestLogging: @pytest.fixture(scope="function") diff --git a/tests/tests_integration/test_settings.py b/tests/tests_integration/test_settings.py index bcb9c73ec..ce3a474aa 100644 --- a/tests/tests_integration/test_settings.py +++ b/tests/tests_integration/test_settings.py @@ -2,13 +2,14 @@ import shutil import pytest -import test_utils -from base import BaseTest from datashuttle.configs import canonical_configs from datashuttle.utils import validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError +from .. import test_utils +from ..base import BaseTest + class TestPersistentSettings(BaseTest): @pytest.mark.parametrize("project", ["local", "full"], indirect=True) diff --git a/tests/tests_integration/test_ssh_file_transfer.py b/tests/tests_integration/test_ssh_file_transfer.py deleted file mode 100644 index 4e98894c6..000000000 --- a/tests/tests_integration/test_ssh_file_transfer.py +++ /dev/null @@ -1,305 +0,0 @@ -import copy -import glob -import shutil -import time -from pathlib import Path - -import pandas as pd -import pytest -import ssh_test_utils -import test_utils -from file_conflicts_pathtable import get_pathtable -from pytest import ssh_config - - -class TestFileTransfer: - @pytest.fixture( - scope="class", - params=[ # Set running SSH or local filesystem (see docstring). - # False, - pytest.param( - True, - marks=pytest.mark.skipif( - ssh_config.TEST_SSH is False, - reason="TEST_SSH is set to False.", - ), - ), - ], - ) - def pathtable_and_project(self, request, tmpdir_factory): - """Create a project for SSH testing. Setup - the project as normal, and switch configs - to use SSH connection. - - Although SSH is used for transfer, for SSH tests, - checking the created filepaths is always - done through the local filesystem for speed - and convenience. As such, the drive that is - SSH to must also be mounted and the path - supplied to the location SSH'd to. - - For speed, create the project once, - and all files to transfer. Then in the - test function, the folder are transferred. - Partial cleanup is done in the test function - i.e. deleting the central_path to which the - items have been transferred. This is achieved - by using "class" scope. - - Notes - ----- - - Pytest params - The `params` key sets the - `params` attribute on the pytest `request` fixture. - This attribute is used to set the `testing_ssh` variable - to `True` or `False`. In the first run, this is set to - `False`, meaning local filesystem tests are run. In the - second run, this is set with a pytest parameter that is - `True` (i.e. SSH tests are run) but is skipped if `TEST_SSH` - in `ssh_config` (set in conftest.py` is `False`. - - - For convenience, files are transferred - with SSH and then checked through the local filesystem - mount. This is significantly easier than checking - everything through SFTP. However, on Windows the - mounted filesystem is quite slow to update, taking - a few seconds after SSH transfer. This makes the - tests run very slowly. We can get rid - of this limitation on linux. - - """ - testing_ssh = request.param - tmp_path = tmpdir_factory.mktemp("test") - - if testing_ssh: - base_path = ssh_config.FILESYSTEM_PATH - central_path = ssh_config.SERVER_PATH - else: - base_path = tmp_path / "test with space" - central_path = base_path - test_project_name = "test_file_conflicts" - - project = test_utils.setup_project_fixture( - base_path, test_project_name - ) - - if testing_ssh: - ssh_test_utils.setup_project_for_ssh( - project, - test_utils.make_test_path( - central_path, "central", test_project_name - ), - ssh_config.CENTRAL_HOST_ID, - ssh_config.USERNAME, - ) - - # Initialise the SSH connection - ssh_test_utils.setup_hostkeys(project) - shutil.copy(ssh_config.SSH_KEY_PATH, project.cfg.file_path.parent) - - pathtable = get_pathtable(project.cfg["local_path"]) - test_utils.create_all_pathtable_files(pathtable) - project.testing_ssh = testing_ssh - - yield [pathtable, project] - - test_utils.teardown_project(project) - - if testing_ssh: - for result in glob.glob(ssh_config.FILESYSTEM_PATH): - shutil.rmtree(result) - - # ------------------------------------------------------------------------- - # Utils - # ------------------------------------------------------------------------- - - def central_from_local(self, path_): - return Path(str(copy.copy(path_)).replace("local", "central")) - - # ------------------------------------------------------------------------- - # Test File Transfer - All Options - # ------------------------------------------------------------------------- - - @pytest.mark.parametrize( - "sub_names", - [ - ["all"], - ["all_sub"], - ["all_non_sub"], - ["sub-001"], - ["sub-003_date-20231901"], - ["sub-002", "all_non_sub"], - ], - ) - @pytest.mark.parametrize( - "ses_names", - [ - ["all"], - ["all_non_ses"], - ["all_ses"], - ["ses-001"], - ["ses-002_random-key"], - ["all_non_ses", "ses-001"], - ], - ) - @pytest.mark.parametrize( - "datatype", - [ - ["all"], - ["all_non_datatype"], - ["all_datatype"], - ["behav"], - ["ephys"], - ["anat"], - ["funcimg"], - ["anat", "behav", "all_non_datatype"], - ], - ) - @pytest.mark.parametrize("upload_or_download", ["upload", "download"]) - def test_all_data_transfer_options( - self, - pathtable_and_project, - sub_names, - ses_names, - datatype, - upload_or_download, - ): - """Parse the arguments to filter the pathtable, getting - the files expected to be transferred passed on the arguments - Note files in sub/ses/datatype folders must be handled - separately to those in non-sub, non-ses, non-datatype folders. - - see test_utils.swap_local_and_central_paths() for the logic - on setting up and swapping local / central paths for - upload / download tests. - """ - pathtable, project = pathtable_and_project - - transfer_function = test_utils.handle_upload_or_download( - project, - upload_or_download, - transfer_method="custom", - swap_last_folder_only=project.testing_ssh, - )[0] - - transfer_function( - "rawdata", sub_names, ses_names, datatype, init_log=False - ) - - if upload_or_download == "download": - test_utils.swap_local_and_central_paths( - project, swap_last_folder_only=project.testing_ssh - ) - - sub_names = self.parse_arguments(pathtable, sub_names, "sub") - ses_names = self.parse_arguments(pathtable, ses_names, "ses") - datatype = self.parse_arguments(pathtable, datatype, "datatype") - - # Filter pathtable to get files that were expected - # to be transferred - ( - sub_ses_dtype_arguments, - extra_arguments, - ) = self.make_pathtable_search_filter(sub_names, ses_names, datatype) - - datatype_folders = self.query_table(pathtable, sub_ses_dtype_arguments) - extra_folders = self.query_table(pathtable, extra_arguments) - - expected_paths = pd.concat([datatype_folders, extra_folders]) - expected_paths = expected_paths.drop_duplicates(subset="path") - - central_base_paths = expected_paths.base_folder.map( - lambda x: str(x).replace("local", "central") - ) - expected_transferred_paths = central_base_paths / expected_paths.path - - # When transferring with SSH, there is a delay before - # filesystem catches up - if project.testing_ssh: - time.sleep(0.5) - - # Check what paths were actually moved - # (through the local filesystem), and test - path_to_search = ( - self.central_from_local(project.cfg["local_path"]) / "rawdata" - ) - all_transferred = path_to_search.glob("**/*") - paths_to_transferred_files = list( - filter(Path.is_file, all_transferred) - ) - - assert sorted(paths_to_transferred_files) == sorted( - expected_transferred_paths - ) - - # Teardown here, because we have session scope. - try: - shutil.rmtree(self.central_from_local(project.cfg["local_path"])) - except FileNotFoundError: - pass - - # --------------------------------------------------------------------------------------------------------------- - # Utils - # --------------------------------------------------------------------------------------------------------------- - - def query_table(self, pathtable, arguments): - """Search the table for arguments, return empty - if arguments empty. - """ - if any(arguments): - folders = pathtable.query(" | ".join(arguments)) - else: - folders = pd.DataFrame() - return folders - - def parse_arguments(self, pathtable, list_of_names, field): - """Replicate datashuttle name formatting by parsing - "all" arguments and turning them into a list of all names, - (subject or session), taken from the pathtable. - """ - if list_of_names in [["all"], [f"all_{field}"]]: - entries = pathtable.query(f"parent_{field} != False")[ - f"parent_{field}" - ] - entries = list(set(entries)) - if list_of_names == ["all"]: - entries += ( - [f"all_non_{field}"] - if field != "datatype" - else ["all_non_datatype"] - ) - list_of_names = entries - return list_of_names - - def make_pathtable_search_filter(self, sub_names, ses_names, datatype): - """Create a string of arguments to pass to pd.query() that will - create the table of only transferred sub, ses and datatype. - - Two arguments must be created, one of all sub / ses / datatypes - and the other of all non sub/ non ses / non datatype - folders. These must be handled separately as they are - mutually exclusive. - """ - sub_ses_dtype_arguments = [] - extra_arguments = [] - - for sub in sub_names: - if sub == "all_non_sub": - extra_arguments += ["is_non_sub == True"] - else: - for ses in ses_names: - if ses == "all_non_ses": - extra_arguments += [ - f"(parent_sub == '{sub}' & is_non_ses == True)" - ] - else: - for dtype in datatype: - if dtype == "all_non_datatype": - extra_arguments += [ - f"(parent_sub == '{sub}' & parent_ses == '{ses}' & is_ses_level_non_datatype == True)" - ] - else: - sub_ses_dtype_arguments += [ - f"(parent_sub == '{sub}' & parent_ses == '{ses}' & (parent_datatype == '{dtype}' | parent_datatype == '{dtype}'))" - ] - - return sub_ses_dtype_arguments, extra_arguments diff --git a/tests/tests_integration/test_validation.py b/tests/tests_integration/test_validation.py index c8eaa0d3a..0c40bc2f4 100644 --- a/tests/tests_integration/test_validation.py +++ b/tests/tests_integration/test_validation.py @@ -3,12 +3,13 @@ import warnings import pytest -from base import BaseTest from datashuttle import quick_validate_project from datashuttle.utils import formatting, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError +from ..base import BaseTest + # ----------------------------------------------------------------------------- # Inconsistent sub or ses value lengths # ----------------------------------------------------------------------------- diff --git a/tests/tests_regression/__init__.py b/tests/tests_regression/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_regression/test_backwards_compatibility.py b/tests/tests_regression/test_backwards_compatibility.py index 7d73ff1d0..7fab581ee 100644 --- a/tests/tests_regression/test_backwards_compatibility.py +++ b/tests/tests_regression/test_backwards_compatibility.py @@ -3,7 +3,8 @@ from pathlib import Path import pytest -import test_utils + +from .. import test_utils TEST_PROJECT_NAME = "test_project" diff --git a/tests/tests_transfers/__init__.py b/tests/tests_transfers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_transfers/base_transfer.py b/tests/tests_transfers/base_transfer.py new file mode 100644 index 000000000..803684914 --- /dev/null +++ b/tests/tests_transfers/base_transfer.py @@ -0,0 +1,161 @@ +""" """ + +import copy +from pathlib import Path + +import pandas as pd +import pytest + +from .. import test_utils +from ..base import BaseTest +from .file_conflicts_pathtable import get_pathtable + + +class BaseTransfer(BaseTest): + """ + Class holding fixtures and methods for testing the + custom transfers with keys (e.g. all_non_sub). + """ + + @pytest.fixture( + scope="class", + ) + def pathtable_and_project(self, tmpdir_factory): + """ + Create a new test project with a test project folder + and file structure (see `get_pathtable()` for definition). + """ + tmp_path = tmpdir_factory.mktemp("test") + + base_path = tmp_path / "test with space" + test_project_name = "test_file_conflicts" + + project = test_utils.setup_project_fixture( + base_path, test_project_name + ) + + pathtable = get_pathtable(project.cfg["local_path"]) + + self.create_all_pathtable_files(pathtable) + + yield [pathtable, project] + + test_utils.teardown_project(project) + + def get_expected_transferred_paths( + self, pathtable, sub_names, ses_names, datatype + ): + """ + Process the expected files that are transferred using the logic in + `make_pathtable_search_filter()` to + """ + parsed_sub_names = self.parse_arguments(pathtable, sub_names, "sub") + parsed_ses_names = self.parse_arguments(pathtable, ses_names, "ses") + parsed_datatype = self.parse_arguments(pathtable, datatype, "datatype") + + # Filter pathtable to get files that were expected to be transferred + ( + sub_ses_dtype_arguments, + extra_arguments, + ) = self.make_pathtable_search_filter( + parsed_sub_names, parsed_ses_names, parsed_datatype + ) + + datatype_folders = self.query_table(pathtable, sub_ses_dtype_arguments) + extra_folders = self.query_table(pathtable, extra_arguments) + + expected_paths = pd.concat([datatype_folders, extra_folders]) + expected_paths = expected_paths.drop_duplicates(subset="path") + + expected_paths = self.remove_path_before_rawdata(expected_paths.path) + + return expected_paths + + def make_pathtable_search_filter(self, sub_names, ses_names, datatype): + """ + Create a string of arguments to pass to pd.query() that will + create the table of only transferred sub, ses and datatype. + + Two arguments must be created, one of all sub / ses / datatypes + and the other of all non sub/ non ses / non datatype + folders. These must be handled separately as they are + mutually exclusive. + """ + sub_ses_dtype_arguments = [] + extra_arguments = [] + + for sub in sub_names: + if sub == "all_non_sub": + extra_arguments += ["is_non_sub == True"] + else: + for ses in ses_names: + if ses == "all_non_ses": + extra_arguments += [ + f"(parent_sub == '{sub}' & is_non_ses == True)" + ] + else: + for dtype in datatype: + if dtype == "all_non_datatype": + extra_arguments += [ + f"(parent_sub == '{sub}' & parent_ses == '{ses}' " + f"& is_ses_level_non_datatype == True)" + ] + else: + sub_ses_dtype_arguments += [ + f"(parent_sub == '{sub}' & parent_ses == '{ses}' " + f"& parent_datatype == '{dtype}' )" + ] + + return sub_ses_dtype_arguments, extra_arguments + + def remove_path_before_rawdata(self, list_of_paths): + """ + Remove the path to project files before the "rawdata" so + they can be compared no matter where the project was stored + (e.g. on a central server vs. local filesystem). + """ + cut_paths = [] + for path_ in list_of_paths: + parts = Path(path_).parts + cut_paths.append(Path(*parts[parts.index("rawdata") :])) + return cut_paths + + def query_table(self, pathtable, arguments): + """ + Search the table for arguments, return empty + if arguments empty + """ + if any(arguments): + folders = pathtable.query(" | ".join(arguments)) + else: + folders = pd.DataFrame() + return folders + + def parse_arguments(self, pathtable, list_of_names, field): + """ + Replicate datashuttle name formatting by parsing + "all" arguments and turning them into a list of all names, + (subject or session), taken from the pathtable. + """ + if list_of_names in [["all"], [f"all_{field}"]]: + entries = pathtable.query(f"parent_{field} != False")[ + f"parent_{field}" + ] + entries = list(set(entries)) + if list_of_names == ["all"]: + entries += [f"all_non_{field}"] + list_of_names = entries + return list_of_names + + def create_all_pathtable_files(self, pathtable): + """ + Create the entire test project in the defined + location (usually project's `local_path`). + """ + for i in range(pathtable.shape[0]): + filepath = pathtable["base_folder"][i] / pathtable["path"][i] + filepath.parents[0].mkdir(parents=True, exist_ok=True) + test_utils.write_file(filepath, contents="test_entry") + + def central_from_local(self, path_): + return Path(str(copy.copy(path_)).replace("local", "central")) diff --git a/tests/file_conflicts_pathtable.py b/tests/tests_transfers/file_conflicts_pathtable.py similarity index 100% rename from tests/file_conflicts_pathtable.py rename to tests/tests_transfers/file_conflicts_pathtable.py diff --git a/tests/tests_transfers/local_filesystem/__init__.py b/tests/tests_transfers/local_filesystem/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/test_filesystem_transfer.py b/tests/tests_transfers/local_filesystem/test_transfer.py similarity index 99% rename from tests/tests_integration/test_filesystem_transfer.py rename to tests/tests_transfers/local_filesystem/test_transfer.py index 7dbfab4d5..25ce82363 100644 --- a/tests/tests_integration/test_filesystem_transfer.py +++ b/tests/tests_transfers/local_filesystem/test_transfer.py @@ -4,13 +4,14 @@ from pathlib import Path import pytest -import test_utils -from base import BaseTest from datashuttle.configs import canonical_folders from datashuttle.configs.canonical_configs import get_broad_datatypes from datashuttle.configs.canonical_tags import tags +from ... import test_utils +from ...base import BaseTest + class TestFileTransfer(BaseTest): @pytest.mark.parametrize( diff --git a/tests/tests_integration/test_transfer_checks.py b/tests/tests_transfers/local_filesystem/test_transfer_checks.py similarity index 98% rename from tests/tests_integration/test_transfer_checks.py rename to tests/tests_transfers/local_filesystem/test_transfer_checks.py index 7a6f631f3..5a446d53d 100644 --- a/tests/tests_integration/test_transfer_checks.py +++ b/tests/tests_transfers/local_filesystem/test_transfer_checks.py @@ -3,11 +3,12 @@ from pathlib import Path import pytest -import test_utils -from base import BaseTest from datashuttle.utils.rclone import get_local_and_central_file_differences +from ... import test_utils +from ...base import BaseTest + class TestTransferChecks(BaseTest): @pytest.mark.parametrize( diff --git a/tests/tests_transfers/local_filesystem/test_transfer_special_arguments.py b/tests/tests_transfers/local_filesystem/test_transfer_special_arguments.py new file mode 100644 index 000000000..71ca45d9e --- /dev/null +++ b/tests/tests_transfers/local_filesystem/test_transfer_special_arguments.py @@ -0,0 +1,111 @@ +""" """ + +import shutil +from pathlib import Path + +import pytest + +from ... import test_utils +from ..base_transfer import BaseTransfer + +PARAM_SUBS = [ + ["all"], + ["all_sub"], + ["all_non_sub"], + ["sub-001"], + ["sub-003_date-20231201"], + ["sub-002", "all_non_sub"], +] +PARAM_SES = [ + ["all"], + ["all_non_ses"], + ["all_ses"], + ["ses-001"], + ["ses-002_random-key"], + ["all_non_ses", "ses-001"], +] +PARAM_DATATYPE = [ + ["all"], + ["all_non_datatype"], + ["all_datatype"], + ["behav"], + ["ephys"], + ["anat"], + ["funcimg"], + ["anat", "behav", "all_non_datatype"], +] + + +class TestFileTransfer(BaseTransfer): + # ---------------------------------------------------------------------------------- + # Test File Transfer - All Options + # ---------------------------------------------------------------------------------- + + @pytest.mark.parametrize("sub_names", PARAM_SUBS) + @pytest.mark.parametrize("ses_names", PARAM_SES) + @pytest.mark.parametrize("datatype", PARAM_DATATYPE) + @pytest.mark.parametrize("upload_or_download", ["upload", "download"]) + def test_combinations_filesystem_transfer( + self, + pathtable_and_project, + sub_names, + ses_names, + datatype, + upload_or_download, + ): + """ + Test many combinations of possible file transfer commands. The + entire test project is created in the original `local_path` + and subset of it is uploaded and tested against. To test + upload vs. download, the `local_path` and `central_path` + locations are swapped. + """ + pathtable, project = pathtable_and_project + + # Transfer the data, swapping the paths to move a subset of + # files from the already set up directory to a new directory + # using upload or download. + transfer_function = test_utils.handle_upload_or_download( + project, + upload_or_download, + transfer_method="custom", + swap_last_folder_only=False, + )[0] + + transfer_function( + "rawdata", sub_names, ses_names, datatype, init_log=False + ) + + if upload_or_download == "download": + test_utils.swap_local_and_central_paths( + project, swap_last_folder_only=False + ) + + expected_transferred_paths = self.get_expected_transferred_paths( + pathtable, sub_names, ses_names, datatype + ) + + # Check what paths were actually moved + # (through the local filesystem), and test + path_to_search = ( + self.central_from_local(project.cfg["local_path"]) / "rawdata" + ) + all_transferred = path_to_search.glob("**/*") + + paths_to_transferred_files = list( + filter(Path.is_file, all_transferred) + ) + + paths_to_transferred_files = self.remove_path_before_rawdata( + paths_to_transferred_files + ) + + assert sorted(paths_to_transferred_files) == sorted( + expected_transferred_paths + ) + + # Teardown here, because we have session scope. + try: + shutil.rmtree(self.central_from_local(project.cfg["local_path"])) + except FileNotFoundError: + pass diff --git a/tests/tests_transfers/ssh/__init__.py b/tests/tests_transfers/ssh/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_transfers/ssh/base_ssh.py b/tests/tests_transfers/ssh/base_ssh.py new file mode 100644 index 000000000..0a8c52b97 --- /dev/null +++ b/tests/tests_transfers/ssh/base_ssh.py @@ -0,0 +1,76 @@ +""" """ + +import os +import platform +import subprocess +from pathlib import Path + +import pytest + +from ..base_transfer import BaseTransfer +from . import ssh_test_utils + +# Choose port 3306 for running on GH actions +# suggested in https://github.com/orgs/community/discussions/25550 +PORT = 3306 +os.environ["DS_SSH_PORT"] = str(PORT) + + +class BaseSSHTransfer(BaseTransfer): + """ + Class holding fixtures and methods for testing the + custom transfers with keys (e.g. all_non_sub). + """ + + @pytest.fixture( + scope="class", + ) + def setup_ssh_container(self): + """ + Set up the Dockerfile container for SSH tests and + delete it on teardown. + """ + container_name = "datashuttle_ssh_tests" + + assert ssh_test_utils.docker_is_running(), ( + "docker is not running, " + "this should be checked at the top of test script" + ) + + image_path = Path(__file__).parent / "ssh_test_images" + os.chdir(image_path) + + if platform.system() != "Windows": + build_command = "sudo docker build -t ssh_server ." + run_command = ( + f"sudo docker run -d -p {PORT}:22 " + f"--name {container_name} ssh_server" + ) + else: + build_command = "docker build -t ssh_server ." + run_command = f"docker run -d -p {PORT}:22 --name {container_name} ssh_server" + + build_output = subprocess.run( + build_command, + shell=True, + capture_output=True, + ) + assert build_output.returncode == 0, ( + f"docker build failed with: STDOUT-{build_output.stdout} " + f"STDERR-{build_output.stderr}" + ) + + run_output = subprocess.run( + run_command, + shell=True, + capture_output=True, + ) + + assert run_output.returncode == 0, ( + f"docker run failed with: STDOUT-{run_output.stdout} " + f"STDERR-{run_output.stderr}" + ) + + yield + + subprocess.run(f"docker rm -f {container_name}", shell=True) diff --git a/tests/tests_transfers/ssh/ssh_test_images/Dockerfile b/tests/tests_transfers/ssh/ssh_test_images/Dockerfile new file mode 100644 index 000000000..474c8ecb3 --- /dev/null +++ b/tests/tests_transfers/ssh/ssh_test_images/Dockerfile @@ -0,0 +1,25 @@ +# Use a base image with the desired OS (e.g., Ubuntu, Debian, etc.) +FROM ubuntu:latest + +# Install SSH server +RUN apt-get update && \ + apt-get upgrade -y +RUN apt-get install openssh-server -y supervisor +RUN apt-get install nano + +# Create an SSH user +RUN useradd -rm -d /home/sshuser -s /bin/bash -g root -G sudo sshuser + +# Set the SSH user's password (replace "password" with your desired password) +RUN echo "sshuser:password" | chpasswd + +# Allow SSH access +RUN mkdir /var/run/sshd + +RUN /usr/bin/ssh-keygen -A + +# Expose the SSH port +EXPOSE 22 + +# Start SSH server on container startup +CMD ["/usr/sbin/sshd", "-D"] diff --git a/tests/tests_transfers/ssh/ssh_test_utils.py b/tests/tests_transfers/ssh/ssh_test_utils.py new file mode 100644 index 000000000..5cbf141b4 --- /dev/null +++ b/tests/tests_transfers/ssh/ssh_test_utils.py @@ -0,0 +1,138 @@ +import builtins +import copy +import stat +import subprocess +import sys + +import paramiko + +from datashuttle.utils import rclone, ssh + + +def setup_project_for_ssh( + project, +): + """ + Set up the project configs to use + SSH connection to central. The settings + set up a connection to the Dockerfile image + found in /ssh_test_images. + """ + project.update_config_file( + connection_method="ssh", + central_path=f"/home/sshuser/datashuttle/{project.project_name}", + central_host_id="localhost", + central_host_username="sshuser", + ) + + +def setup_ssh_connection(project, setup_ssh_key_pair=True): + """ + Convenience function to verify the server hostkey and ssh + key pairs to the Dockerfile image for ssh tests. + + This requires monkeypatching a number of functions involved + in the SSH setup process. `input()` is patched to always + return the required hostkey confirmation "y". `getpass()` is + patched to always return the password for the container in which + SSH tests are run. `isatty()` is patched because when running this + for some reason it appears to be in a TTY - this might be a + container thing. + """ + # Monkeypatch + orig_builtin = copy.deepcopy(builtins.input) + builtins.input = lambda _: "y" # type: ignore + + orig_getpass = copy.deepcopy(ssh.getpass.getpass) + ssh.getpass.getpass = lambda _: "password" # type: ignore + + orig_isatty = copy.deepcopy(sys.stdin.isatty) + sys.stdin.isatty = lambda: True + + # Run setup + verified = ssh.verify_ssh_central_host( + project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True + ) + + if setup_ssh_key_pair: + ssh.setup_ssh_key(project.cfg, log=False) + + # Restore functions + builtins.input = orig_builtin + ssh.getpass.getpass = orig_getpass + sys.stdin.isatty = orig_isatty + + rclone.setup_rclone_config_for_ssh( + project.cfg, + project.cfg.get_rclone_config_name("ssh"), + project.cfg.ssh_key_path, + ) + + return verified + + +def recursive_search_central(project): + """ + A convenience function to recursively search a + project for files through SSH, used during testing + across an SSH connection to collected names of + files that were transferred. + """ + with paramiko.SSHClient() as client: + ssh.connect_client_core(client, project.cfg) + + sftp = client.open_sftp() + + all_filenames = [] + + sftp_recursive_file_search( + sftp, + (project.cfg["central_path"] / "rawdata").as_posix(), + all_filenames, + ) + return all_filenames + + +def sftp_recursive_file_search(sftp, path_, all_filenames): + """ + Append all filenames found within a folder, + when searching over a sftp connection. + """ + try: + sftp.stat(path_) + except FileNotFoundError: + return + + for file_or_folder in sftp.listdir_attr(path_): + if stat.S_ISDIR(file_or_folder.st_mode): + sftp_recursive_file_search( + sftp, + path_ + "/" + file_or_folder.filename, + all_filenames, + ) + else: + all_filenames.append(path_ + "/" + file_or_folder.filename) + + +def docker_is_running(): + if not is_docker_installed(): + return False + + is_running = check_sys_command_returns_0("docker stats --no-stream") + return is_running + + +def is_docker_installed(): + return check_sys_command_returns_0("docker -v") + + +def check_sys_command_returns_0(command): + return ( + subprocess.run( + command, + shell=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ).returncode + == 0 + ) diff --git a/tests/tests_integration/test_ssh_setup.py b/tests/tests_transfers/ssh/test_ssh_setup.py similarity index 57% rename from tests/tests_integration/test_ssh_setup.py rename to tests/tests_transfers/ssh/test_ssh_setup.py index 31c1fa72e..69e70e527 100644 --- a/tests/tests_integration/test_ssh_setup.py +++ b/tests/tests_transfers/ssh/test_ssh_setup.py @@ -1,33 +1,40 @@ -"""SSH configs are set in conftest.py . The password -should be stored in a file called test_ssh_password.txt located -in the same folder as test_ssh.py. -""" +import builtins +import copy +import platform import pytest -import ssh_test_utils -import test_utils -from pytest import ssh_config from datashuttle.utils import ssh +from ... import test_utils +from . import ssh_test_utils +from .base_ssh import BaseSSHTransfer -@pytest.mark.skipif(ssh_config.TEST_SSH is False, reason="TEST_SSH is false") -class TestSSH: +TEST_SSH = ssh_test_utils.docker_is_running() + + +@pytest.mark.skipif( + platform.system == "Darwin", reason="Docker set up is not robust on macOS." +) +@pytest.mark.skipif( + not TEST_SSH, + reason="SSH tests are not run as docker is either not installed, " + "running or current user is not in the docker group.", +) +class TestSSH(BaseSSHTransfer): @pytest.fixture(scope="function") - def project(test, tmp_path): - """Make a project as per usual, but now add - in test ssh configurations. + def project(test, tmp_path, setup_ssh_container): + """Set up a project with configs for SSH into + the test Dockerfile image. """ tmp_path = tmp_path / "test with space" test_project_name = "test_ssh" + project = test_utils.setup_project_fixture(tmp_path, test_project_name) ssh_test_utils.setup_project_for_ssh( project, - ssh_config.FILESYSTEM_PATH, - ssh_config.CENTRAL_HOST_ID, - ssh_config.USERNAME, ) yield project @@ -41,18 +48,13 @@ def project(test, tmp_path): def test_verify_ssh_central_host_do_not_accept( self, capsys, project, input_ ): - """Use the main function to test this. Test the sub-function - when accepting, because this main function will also - call setup ssh key pairs which we don't want to do yet. - - This should only accept for "y" so try some random strings - including "n" and check they all do not make the connection. - """ - orig_builtin = ssh_test_utils.setup_mock_input(input_) + """Test that host not accepted if input is not "y".""" + orig_builtin = copy.deepcopy(builtins.input) + builtins.input = lambda _: input_ # type: ignore project.setup_ssh_connection() - ssh_test_utils.restore_mock_input(orig_builtin) + builtins.input = orig_builtin captured = capsys.readouterr() @@ -63,22 +65,22 @@ def test_verify_ssh_central_host_accept(self, capsys, project): and check hostkey is successfully accepted and written to configs. """ test_utils.clear_capsys(capsys) - orig_builtin = ssh_test_utils.setup_mock_input(input_="y") - verified = ssh.verify_ssh_central_host( - project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True + verified = ssh_test_utils.setup_ssh_connection( + project, setup_ssh_key_pair=False ) - ssh_test_utils.restore_mock_input(orig_builtin) - assert verified captured = capsys.readouterr() + assert captured.out == "Host accepted.\n" with open(project.cfg.hostkeys_path) as file: hostkey = file.readlines()[0] - assert f"{project.cfg['central_host_id']} ssh-ed25519 " in hostkey + assert ( + f"[{project.cfg['central_host_id']}]:3306 ssh-ed25519 " in hostkey + ) def test_generate_and_write_ssh_key(self, project): """Check ssh key for passwordless connection is written diff --git a/tests/tests_transfers/ssh/test_ssh_transfer.py b/tests/tests_transfers/ssh/test_ssh_transfer.py new file mode 100644 index 000000000..f086462ef --- /dev/null +++ b/tests/tests_transfers/ssh/test_ssh_transfer.py @@ -0,0 +1,152 @@ +import platform +import shutil + +import paramiko +import pytest + +from datashuttle.utils import ssh + +from . import ssh_test_utils +from .base_ssh import BaseSSHTransfer + +TEST_SSH = ssh_test_utils.docker_is_running() + + +@pytest.mark.skipif( + platform.system == "Darwin", reason="Docker set up is not robust on macOS." +) +@pytest.mark.skipif( + not TEST_SSH, + reason="SSH tests are not run as docker is either not installed, " + "running or current user is not in the docker group.", +) +class TestSSHTransfer(BaseSSHTransfer): + @pytest.fixture( + scope="class", + ) + def ssh_setup(self, pathtable_and_project, setup_ssh_container): + """ + After initial project setup (in `pathtable_and_project`) + setup a container and the project's SSH connection to the container. + Then upload the test project to the `central_path`. + """ + pathtable, project = pathtable_and_project + + ssh_test_utils.setup_project_for_ssh( + project, + ) + ssh_test_utils.setup_ssh_connection(project) + + project.upload_rawdata() + + return [pathtable, project] + + # ----------------------------------------------------------------- + # Test Setup SSH Connection + # ----------------------------------------------------------------- + + @pytest.mark.parametrize( + "sub_names", [["all"], ["all_non_sub", "sub-002"]] + ) + @pytest.mark.parametrize( + "ses_names", [["all"], ["ses-002_random-key"], ["all_non_ses"]] + ) + @pytest.mark.parametrize( + "datatype", [["all"], ["anat", "all_non_datatype"]] + ) + def test_combinations_ssh_transfer( + self, + ssh_setup, + sub_names, + ses_names, + datatype, + ): + """ + Test a subset of argument combinations while testing over SSH connection + to a container. This is very slow, due to the rclone ssh transfer (which + is performed twice in this test, once for upload, once for download), around + 8 seconds per parameterization. + + In test setup, the entire project is created in the `local_path` and + is uploaded to `central_path`. So we only need to set up once per test, + upload and download is to temporary folders and these temporary folders + are cleaned at the end of each parameterization. + """ + pathtable, project = ssh_setup + + # Upload data from the setup local project to a temporary + # central directory. + true_central_path = project.cfg["central_path"] + tmp_central_path = ( + project.cfg["central_path"] / "tmp" / project.project_name + ) + self.remake_logging_path(project) + + project.update_config_file(central_path=tmp_central_path) + + project.upload_custom( + "rawdata", sub_names, ses_names, datatype, init_log=False + ) + + expected_transferred_paths = self.get_expected_transferred_paths( + pathtable, sub_names, ses_names, datatype + ) + + # Search the paths that were transferred and tidy them up, + # then check against the paths that were expected to be transferred. + transferred_files = ssh_test_utils.recursive_search_central(project) + paths_to_transferred_files = self.remove_path_before_rawdata( + transferred_files + ) + + assert sorted(paths_to_transferred_files) == sorted( + expected_transferred_paths + ) + + # Now, move data from the central path where the project is + # setup, to a temp local folder to test download. + true_local_path = project.cfg["local_path"] + tmp_local_path = ( + project.cfg["local_path"] / "tmp" / project.project_name + ) + tmp_local_path.mkdir(exist_ok=True, parents=True) + + project.update_config_file(local_path=tmp_local_path) + project.update_config_file(central_path=true_central_path) + + project.download_custom( + "rawdata", sub_names, ses_names, datatype, init_log=False + ) + + # Find the transferred paths, tidy them up + # and check expected paths were transferred. + all_transferred = list((tmp_local_path / "rawdata").glob("**/*")) + all_transferred = [ + path_ for path_ in all_transferred if path_.is_file() + ] + + paths_to_transferred_files = self.remove_path_before_rawdata( + all_transferred + ) + + assert sorted(paths_to_transferred_files) == sorted( + expected_transferred_paths + ) + + # Clean up, removing the temp directories and + # resetting the project paths. + with paramiko.SSHClient() as client: + ssh.connect_client_core(client, project.cfg) + client.exec_command(f"rm -rf {(tmp_central_path).as_posix()}") + + shutil.rmtree(tmp_local_path) + + self.remake_logging_path(project) + project.update_config_file(local_path=true_local_path) + + def remake_logging_path(self, project): + """ + Need to do this to compensate for switching + local_path location in the test environment. + """ + project.get_logging_path().mkdir(parents=True, exist_ok=True) diff --git a/tests/tests_tui/__init__.py b/tests/tests_tui/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_tui/test_local_only_project.py b/tests/tests_tui/test_local_only_project.py index 8a06ccfa7..a9d47594c 100644 --- a/tests/tests_tui/test_local_only_project.py +++ b/tests/tests_tui/test_local_only_project.py @@ -1,8 +1,9 @@ import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiLocalOnlyProject(TuiBase): @pytest.mark.asyncio diff --git a/tests/tests_tui/test_tui_configs.py b/tests/tests_tui/test_tui_configs.py index c2e110b65..7a367f002 100644 --- a/tests/tests_tui/test_tui_configs.py +++ b/tests/tests_tui/test_tui_configs.py @@ -3,8 +3,6 @@ from time import monotonic import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import load_configs from datashuttle.tui.app import TuiApp @@ -13,6 +11,9 @@ ) from datashuttle.tui.screens.project_manager import ProjectManagerScreen +from .. import test_utils +from .tui_base import TuiBase + class TestTuiConfigs(TuiBase): # ------------------------------------------------------------------------- diff --git a/tests/tests_tui/test_tui_create_folders.py b/tests/tests_tui/test_tui_create_folders.py index b7d3822ee..8a837abe8 100644 --- a/tests/tests_tui/test_tui_create_folders.py +++ b/tests/tests_tui/test_tui_create_folders.py @@ -1,8 +1,6 @@ import re import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import canonical_configs from datashuttle.tui.app import TuiApp @@ -11,6 +9,9 @@ ) from datashuttle.tui.screens.project_manager import ProjectManagerScreen +from .. import test_utils +from .tui_base import TuiBase + class TestTuiCreateFolders(TuiBase): # ------------------------------------------------------------------------- diff --git a/tests/tests_tui/test_tui_datatypes.py b/tests/tests_tui/test_tui_datatypes.py index 76a438f35..eafd33a89 100644 --- a/tests/tests_tui/test_tui_datatypes.py +++ b/tests/tests_tui/test_tui_datatypes.py @@ -1,10 +1,11 @@ import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import canonical_configs from datashuttle.tui.app import TuiApp +from .. import test_utils +from .tui_base import TuiBase + class TestDatatypesTUI(TuiBase): """Test the datatype selection screen for the Create and Transfer tab.""" diff --git a/tests/tests_tui/test_tui_directorytree.py b/tests/tests_tui/test_tui_directorytree.py index 2bcd33ec8..a5ba24af1 100644 --- a/tests/tests_tui/test_tui_directorytree.py +++ b/tests/tests_tui/test_tui_directorytree.py @@ -2,10 +2,11 @@ import pyperclip import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + try: pyperclip.paste() HAS_GUI = True diff --git a/tests/tests_tui/test_tui_get_help.py b/tests/tests_tui/test_tui_get_help.py index aa8986b8c..c6e4a0587 100644 --- a/tests/tests_tui/test_tui_get_help.py +++ b/tests/tests_tui/test_tui_get_help.py @@ -1,8 +1,9 @@ import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiSettings(TuiBase): """Test that the 'Get Help' page from the main menu. diff --git a/tests/tests_tui/test_tui_logging.py b/tests/tests_tui/test_tui_logging.py index 33aa07978..350a574a3 100644 --- a/tests/tests_tui/test_tui_logging.py +++ b/tests/tests_tui/test_tui_logging.py @@ -1,10 +1,11 @@ import pytest -import test_utils -from tui_base import TuiBase from datashuttle.tui.app import TuiApp from datashuttle.tui.tabs.logging import RichLogScreen +from .. import test_utils +from .tui_base import TuiBase + class TestTuiLogging(TuiBase): @pytest.mark.asyncio diff --git a/tests/tests_tui/test_tui_selectdirectorytree.py b/tests/tests_tui/test_tui_selectdirectorytree.py index 1885ddc34..a7e488fc1 100644 --- a/tests/tests_tui/test_tui_selectdirectorytree.py +++ b/tests/tests_tui/test_tui_selectdirectorytree.py @@ -1,11 +1,12 @@ import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp from datashuttle.tui.screens.modal_dialogs import ( SelectDirectoryTreeScreen, ) +from .tui_base import TuiBase + class TestSelectTree(TuiBase): @pytest.mark.asyncio diff --git a/tests/tests_tui/test_tui_settings.py b/tests/tests_tui/test_tui_settings.py index 8b9bcaa4d..d460adf91 100644 --- a/tests/tests_tui/test_tui_settings.py +++ b/tests/tests_tui/test_tui_settings.py @@ -1,8 +1,9 @@ import pytest -from tui_base import TuiBase from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiSettings(TuiBase): """Test the 'Settings' screen accessible from the Main Menu.""" diff --git a/tests/tests_tui/test_tui_transfer.py b/tests/tests_tui/test_tui_transfer.py index af26264f6..03b036508 100644 --- a/tests/tests_tui/test_tui_transfer.py +++ b/tests/tests_tui/test_tui_transfer.py @@ -1,10 +1,11 @@ import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import canonical_configs from datashuttle.tui.app import TuiApp +from .. import test_utils +from .tui_base import TuiBase + class TestTuiTransfer(TuiBase): """Test transferring through the TUI (entire project, top diff --git a/tests/tests_tui/test_tui_validate.py b/tests/tests_tui/test_tui_validate.py index 9bf4f7737..f9e0931c5 100644 --- a/tests/tests_tui/test_tui_validate.py +++ b/tests/tests_tui/test_tui_validate.py @@ -2,11 +2,12 @@ import pytest import textual -from tui_base import TuiBase import datashuttle from datashuttle.tui.app import TuiApp +from .tui_base import TuiBase + class TestTuiValidate(TuiBase): @pytest.mark.asyncio diff --git a/tests/tests_tui/test_tui_widgets_and_defaults.py b/tests/tests_tui/test_tui_widgets_and_defaults.py index e37117f31..13d9511a0 100644 --- a/tests/tests_tui/test_tui_widgets_and_defaults.py +++ b/tests/tests_tui/test_tui_widgets_and_defaults.py @@ -2,8 +2,6 @@ from typing import Union import pytest -import test_utils -from tui_base import TuiBase from datashuttle.configs import canonical_configs from datashuttle.tui.app import TuiApp @@ -12,6 +10,9 @@ ) from datashuttle.tui.screens.new_project import NewProjectScreen +from .. import test_utils +from .tui_base import TuiBase + class TestTuiWidgets(TuiBase): """Performs fundamental checks on the default display diff --git a/tests/tests_tui/tui_base.py b/tests/tests_tui/tui_base.py index b5b7e88d8..2701f758b 100644 --- a/tests/tests_tui/tui_base.py +++ b/tests/tests_tui/tui_base.py @@ -1,11 +1,12 @@ import pytest_asyncio -import test_utils from textual.widgets._tabbed_content import ContentTab from datashuttle.configs import canonical_configs from datashuttle.tui.screens.project_manager import ProjectManagerScreen from datashuttle.tui.screens.project_selector import ProjectSelectorScreen +from .. import test_utils + class TuiBase: """Contains fixtuers and helper functions for TUI tests.""" diff --git a/tests/tests_unit/__init__.py b/tests/tests_unit/__init__.py new file mode 100644 index 000000000..e69de29bb