Skip to content

Commit 699c22e

Browse files
committed
Tidy ups.
1 parent 7aee25e commit 699c22e

3 files changed

Lines changed: 90 additions & 42 deletions

File tree

tests/ssh_test_utils.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import builtins
22
import copy
3+
import platform
34
import stat
4-
5+
import subprocess
6+
import sys
7+
import warnings
8+
import os
9+
import subprocess
10+
from pathlib import Path
511
import paramiko
612

713
from datashuttle.utils import rclone, ssh
@@ -49,33 +55,54 @@ def restore_mock_input(orig_builtin):
4955
builtins.input = orig_builtin
5056

5157

52-
def setup_hostkeys(project):
58+
def setup_hostkeys(project, setup_ssh_key_pair=True): # TODO: RENAME FUNCTION
5359
"""
5460
Convenience function to verify the server hostkey.
61+
62+
This requires monkeypatching a number of functions involved
63+
in the SSH setup process. `input()` is patched to always
64+
return the required hostkey confirmation "y". `getpass()` is
65+
patched to allways return the password for the container in which
66+
SSH tests are run. `isatty()` is patched because when running this
67+
for some reason it appears to be in a TTY - this might be a
68+
container thing.
5569
"""
70+
# Monkeypatch
5671
orig_builtin = setup_mock_input(input_="y")
57-
ssh.verify_ssh_central_host(
58-
project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True
59-
)
60-
restore_mock_input(orig_builtin)
6172

6273
orig_getpass = copy.deepcopy(ssh.getpass.getpass)
6374
ssh.getpass.getpass = lambda _: "password" # type: ignore
6475

65-
ssh.setup_ssh_key(project.cfg, log=False)
76+
orig_isatty = copy.deepcopy(sys.stdin.isatty)
77+
sys.stdin.isatty = lambda: True
78+
79+
# Run setup
80+
verified = ssh.verify_ssh_central_host(
81+
project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True
82+
)
83+
84+
if setup_ssh_key_pair:
85+
ssh.setup_ssh_key(project.cfg, log=False)
86+
87+
# Restore functions
88+
restore_mock_input(orig_builtin)
6689
ssh.getpass.getpass = orig_getpass
90+
sys.stdin.isatty = orig_isatty
91+
92+
return verified
6793

6894

6995
def build_docker_image(project):
70-
import os
71-
import subprocess
72-
from pathlib import Path
96+
""""""
97+
container_software = is_docker_or_singularity_installed()
98+
assert container_software is not False, ("docker or singularity not installed, "
99+
"this should be checked at the top of test script")
73100

74101
image_path = Path(__file__).parent / "ssh_test_images"
75102
os.chdir(image_path)
76-
subprocess.run("docker build -t ssh_server .", shell=True)
103+
subprocess.run(f"{container_software} build -t ssh_server .", shell=True)
77104
subprocess.run(
78-
"docker run -d -p 22:22 ssh_server", shell=True
105+
f"{container_software} run -d -p 22:22 ssh_server", shell=True
79106
) # ; docker build -t ssh_server .", shell=True) # ;docker run -p 22:22 ssh_server
80107

81108
setup_project_for_ssh(
@@ -118,3 +145,32 @@ def recursive_search_central(project):
118145
all_filenames,
119146
)
120147
return all_filenames
148+
149+
150+
def get_test_ssh():
151+
""""""
152+
if is_docker_or_singularity_installed():
153+
test_ssh = True
154+
else:
155+
warnings.warn("SSH tests are not run as docker (Windows, macOS) "
156+
"or singularity (Linux) is not installed.")
157+
test_ssh = False
158+
159+
return test_ssh
160+
161+
162+
def is_docker_or_singularity_installed(): # TODO: need to test
163+
""""""
164+
check_install = lambda command: subprocess.run(
165+
command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
166+
).returncode == 0
167+
168+
installed = False
169+
if platform.system() == "Linux":
170+
if check_install("singularity version"):
171+
installed = "singularity"
172+
else:
173+
if check_install("docker -v"):
174+
installed = "docker"
175+
176+
return installed

tests/tests_integration/test_ssh_file_transfer.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111
import test_utils
1212
from file_conflicts_pathtable import get_pathtable
1313

14-
# from pytest import ssh_config
1514
from datashuttle.utils import ssh
1615

17-
TEST_SSH = True # TODO: base on whether docker / singularity is installed.
16+
TEST_SSH = ssh_test_utils.get_test_ssh()
1817

1918

2019
PARAM_SUBS = [
@@ -142,24 +141,28 @@ def test_combinations_filesystem_transfer(
142141
except FileNotFoundError:
143142
pass
144143

145-
@pytest.mark.parametrize("sub_names", PARAM_SUBS)
146-
@pytest.mark.parametrize("ses_names", PARAM_SES)
147-
@pytest.mark.parametrize("datatype", PARAM_DATATYPE)
144+
@pytest.mark.skipif("not TEST_SSH", reason="TEST_SSH is false")
145+
@pytest.mark.parametrize("sub_names", [["all"], ["all_non_sub", "sub-002"]])
146+
@pytest.mark.parametrize("ses_names", [["all"], ["ses-002_random-key"], ["all_non_ses"]])
147+
@pytest.mark.parametrize("datatype", [["all"], ["anat", "all_ses_level_non_datatype"]])
148148
def test_combinations_ssh_transfer(
149149
self,
150150
ssh_setup,
151151
sub_names,
152152
ses_names,
153153
datatype,
154154
):
155-
""" """
155+
"""
156+
This is very slow. maybe 8 s per test. Nearly all in the
157+
upload() and download() part so unavoidable. Most code is shared,
158+
this should be okay though my paranoid
159+
"""
156160
pathtable, project = ssh_setup
157161

158162
true_central_path = project.cfg["central_path"]
159-
tmp_central_path = project.cfg["central_path"] / "tmp"
163+
tmp_central_path = project.cfg["central_path"] / "tmp" / project.project_name
160164
project.update_config("central_path", tmp_central_path)
161165

162-
breakpoint()
163166
project.upload(sub_names, ses_names, datatype, init_log=False)
164167

165168
expected_transferred_paths = self.get_expected_transferred_paths(
@@ -184,17 +187,18 @@ def test_combinations_ssh_transfer(
184187
ssh.connect_client(client, project.cfg)
185188
client.exec_command(
186189
f"rm -rf {(tmp_central_path).as_posix()}"
187-
) # TODO: own function as need to do on teardown)
190+
)
188191

189192
true_local_path = project.cfg["local_path"]
190-
tmp_local_path = project.cfg["local_path"] / "tmp"
191-
tmp_local_path.mkdir()
193+
tmp_local_path = project.cfg["local_path"] / "tmp" / project.project_name
194+
tmp_local_path.mkdir(exist_ok=True, parents=True)
195+
192196
project.update_config("local_path", tmp_local_path)
193197
project.update_config("central_path", true_central_path)
194198

195199
project.download(
196200
sub_names, ses_names, datatype, init_log=False
197-
) # TODO: why is this connecting so many times? [during search - make issue]
201+
)
198202

199203
all_transferred = list((tmp_local_path / "rawdata").glob("**/*"))
200204
all_transferred = [

tests/tests_integration/test_ssh_setup.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
1-
"""
2-
SSH configs are set in conftest.py . The password
3-
should be stored in a file called test_ssh_password.txt located
4-
in the same folder as test_ssh.py
5-
"""
61
import pytest
72
import ssh_test_utils
83
import test_utils
94

105
# from pytest import ssh_config
116
from datashuttle.utils import ssh
127

13-
TEST_SSH = True
8+
TEST_SSH = ssh_test_utils.get_test_ssh()
149

1510

16-
@pytest.mark.skipif(TEST_SSH is False, reason="TEST_SSH is false")
11+
@pytest.mark.skipif("not TEST_SSH", reason="TEST_SSH is false")
1712
class TestSSH:
1813
@pytest.fixture(scope="function")
1914
def project(test, tmp_path):
@@ -28,12 +23,7 @@ def project(test, tmp_path):
2823
tmp_path, test_project_name
2924
)
3025

31-
ssh_test_utils.setup_project_for_ssh(
32-
project,
33-
central_path=f"/home/sshuser/datashuttle/{project.project_name}", # TODO: centralise these
34-
central_host_id="localhost",
35-
central_host_username="sshuser",
36-
)
26+
ssh_test_utils.build_docker_image(project) # TODO: rename function
3727

3828
yield project
3929
test_utils.teardown_project(cwd, project)
@@ -70,16 +60,14 @@ def test_verify_ssh_central_host_accept(self, capsys, project):
7060
and check hostkey is successfully accepted and written to configs.
7161
"""
7262
test_utils.clear_capsys(capsys)
73-
orig_builtin = ssh_test_utils.setup_mock_input(input_="y")
7463

75-
verified = ssh.verify_ssh_central_host(
76-
project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True
64+
verified = ssh_test_utils.setup_hostkeys(
65+
project, setup_ssh_key_pair=False
7766
)
7867

79-
ssh_test_utils.restore_mock_input(orig_builtin)
80-
8168
assert verified
8269
captured = capsys.readouterr()
70+
8371
assert captured.out == "Host accepted.\n"
8472

8573
with open(project.cfg.hostkeys_path, "r") as file:

0 commit comments

Comments
 (0)