Skip to content

Commit df2eebc

Browse files
committed
lots of changes, sort out.
1 parent 5c7d754 commit df2eebc

5 files changed

Lines changed: 159 additions & 116 deletions

File tree

tests/conftest.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

tests/ssh_test_images/Dockerfile

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Use a base image with the desired OS (e.g., Ubuntu, Debian, etc.)
2+
FROM ubuntu:latest
3+
# Install SSH server
4+
RUN apt-get update && \
5+
apt-get install -y openssh-server
6+
RUN apt-get install nano
7+
# Create an SSH user
8+
RUN useradd -rm -d /home/sshuser -s /bin/bash -g root -G sudo -u 1000 sshuser
9+
# Set the SSH user's password (replace "password" with your desired password)
10+
RUN echo 'sshuser:password' | chpasswd
11+
# Allow SSH access
12+
RUN mkdir /var/run/sshd
13+
# Expose the SSH port
14+
EXPOSE 22
15+
# Start SSH server on container startup
16+
CMD ["/usr/sbin/sshd", "-D"]

tests/ssh_test_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,29 @@ def setup_hostkeys(project):
5555
project.cfg["central_host_id"], project.cfg.hostkeys_path, log=True
5656
)
5757
restore_mock_input(orig_builtin)
58+
59+
orig_getpass = copy.deepcopy(ssh.getpass.getpass)
60+
ssh.getpass.getpass = lambda _: "password"
61+
62+
ssh.setup_ssh_key(project.cfg, log=False)
63+
ssh.getpass.getpass = orig_getpass
64+
65+
66+
def build_docker_image(project):
67+
import os
68+
import subprocess
69+
from pathlib import Path
70+
71+
image_path = Path(__file__).parent / "ssh_test_images"
72+
os.chdir(image_path)
73+
subprocess.run("docker build -t ssh_server .", shell=True)
74+
subprocess.run(
75+
"docker run -d -p 22:22 ssh_server", shell=True
76+
) # ; docker build -t ssh_server .", shell=True) # ;docker run -p 22:22 ssh_server
77+
78+
setup_project_for_ssh(
79+
project,
80+
central_path=f"/home/sshuser/datashuttle/{project.project_name}",
81+
central_host_id="localhost",
82+
central_host_username="sshuser",
83+
)

tests/tests_integration/test_ssh_file_transfer.py

Lines changed: 113 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,37 @@
11
"""
22
"""
33
import copy
4-
import glob
54
import shutil
6-
import time
5+
import stat
76
from pathlib import Path
87

98
import pandas as pd
9+
import paramiko
1010
import pytest
1111
import ssh_test_utils
1212
import test_utils
13-
from pytest import ssh_config
13+
14+
# from pytest import ssh_config
1415
from test_file_conflicts_pathtable import get_pathtable
1516

17+
from datashuttle.utils import ssh
18+
19+
TEST_SSH = True # TODO: base on whether docker / singularity is installed.
1620

1721
class TestFileTransfer:
1822
@pytest.fixture(
1923
scope="class",
20-
params=[ # Set running SSH or local filesystem (see docstring).
21-
False,
24+
params=[
25+
# False,
2226
pytest.param(
2327
True,
2428
marks=pytest.mark.skipif(
25-
ssh_config.TEST_SSH is False,
26-
reason="TEST_SSH is set to False.",
29+
TEST_SSH is False, reason="TEST_SSH is set to False."
2730
),
2831
),
2932
],
3033
)
31-
def pathtable_and_project(self, request, tmpdir_factory):
34+
def project_and_test_information(self, request, tmpdir_factory):
3235
"""
3336
Create a project for SSH testing. Setup
3437
the project as normal, and switch configs
@@ -72,44 +75,24 @@ def pathtable_and_project(self, request, tmpdir_factory):
7275
testing_ssh = request.param
7376
tmp_path = tmpdir_factory.mktemp("test")
7477

75-
if testing_ssh:
76-
base_path = ssh_config.FILESYSTEM_PATH
77-
central_path = ssh_config.SERVER_PATH
78-
else:
79-
base_path = tmp_path / "test with space"
80-
central_path = base_path
78+
base_path = tmp_path / "test with space"
8179
test_project_name = "test_file_conflicts"
8280

8381
project, cwd = test_utils.setup_project_fixture(
8482
base_path, test_project_name
8583
)
8684

8785
if testing_ssh:
88-
ssh_test_utils.setup_project_for_ssh(
89-
project,
90-
test_utils.make_test_path(
91-
central_path, test_project_name, "central"
92-
),
93-
ssh_config.CENTRAL_HOST_ID,
94-
ssh_config.USERNAME,
95-
)
96-
97-
# Initialise the SSH connection
86+
ssh_test_utils.build_docker_image(project)
9887
ssh_test_utils.setup_hostkeys(project)
99-
shutil.copy(ssh_config.SSH_KEY_PATH, project.cfg.file_path.parent)
10088

10189
pathtable = get_pathtable(project.cfg["local_path"])
10290
self.create_all_pathtable_files(pathtable)
103-
project.testing_ssh = testing_ssh
10491

105-
yield [pathtable, project]
92+
yield [pathtable, project, testing_ssh]
10693

10794
test_utils.teardown_project(cwd, project)
10895

109-
if testing_ssh:
110-
for result in glob.glob(ssh_config.FILESYSTEM_PATH):
111-
shutil.rmtree(result)
112-
11396
# -------------------------------------------------------------------------
11497
# Utils
11598
# -------------------------------------------------------------------------
@@ -156,14 +139,14 @@ def central_from_local(self, path_):
156139
["histology", "behav", "all_ses_level_non_datatype"],
157140
],
158141
)
159-
@pytest.mark.parametrize("upload_or_download", ["upload", "download"])
142+
# @pytest.mark.parametrize("upload_or_download", ["upload", "download"])
160143
def test_all_data_transfer_options(
161144
self,
162-
pathtable_and_project,
145+
project_and_test_information,
163146
sub_names,
164147
ses_names,
165148
datatype,
166-
upload_or_download,
149+
# upload_or_download,
167150
):
168151
"""
169152
Parse the arguments to filter the pathtable, getting
@@ -175,31 +158,31 @@ def test_all_data_transfer_options(
175158
on setting up and swapping local / central paths for
176159
upload / download tests.
177160
"""
178-
pathtable, project = pathtable_and_project
161+
pathtable, project, testing_ssh = project_and_test_information
179162

180-
transfer_function = test_utils.handle_upload_or_download(
181-
project,
182-
upload_or_download,
183-
swap_last_folder_only=project.testing_ssh,
184-
)[0]
163+
# transfer_function = test_utils.handle_upload_or_download(
164+
# project,
165+
# upload_or_download,
166+
# swap_last_folder_only=testing_ssh,
167+
# )[0]
185168

186-
transfer_function(sub_names, ses_names, datatype, init_log=False)
169+
project.upload(sub_names, ses_names, datatype, init_log=False)
170+
# transfer_function(sub_names, ses_names, datatype, init_log=False)
187171

188-
if upload_or_download == "download":
189-
test_utils.swap_local_and_central_paths(
190-
project, swap_last_folder_only=project.testing_ssh
191-
)
172+
# if upload_or_download == "download":
173+
# test_utils.swap_local_and_central_paths(
174+
# project, swap_last_folder_only=testing_ssh
175+
# )
192176

193-
sub_names = self.parse_arguments(pathtable, sub_names, "sub")
194-
ses_names = self.parse_arguments(pathtable, ses_names, "ses")
195-
datatype = self.parse_arguments(pathtable, datatype, "datatype")
177+
parsed_sub_names = self.parse_arguments(pathtable, sub_names, "sub")
178+
parsed_ses_names = self.parse_arguments(pathtable, ses_names, "ses")
179+
parsed_datatype = self.parse_arguments(pathtable, datatype, "datatype")
196180

197-
# Filter pathtable to get files that were expected
198-
# to be transferred
181+
# Filter pathtable to get files that were expected to be transferred
199182
(
200183
sub_ses_dtype_arguments,
201184
extra_arguments,
202-
) = self.make_pathtable_search_filter(sub_names, ses_names, datatype)
185+
) = self.make_pathtable_search_filter(parsed_sub_names, parsed_ses_names, parsed_datatype)
203186

204187
datatype_folders = self.query_table(pathtable, sub_ses_dtype_arguments)
205188
extra_folders = self.query_table(pathtable, extra_arguments)
@@ -214,28 +197,90 @@ def test_all_data_transfer_options(
214197

215198
# When transferring with SSH, there is a delay before
216199
# filesystem catches up
217-
if project.testing_ssh:
218-
time.sleep(0.5)
200+
# if testing_ssh:
201+
# time.sleep(0.5)
219202

220203
# Check what paths were actually moved
221204
# (through the local filesystem), and test
222-
path_to_search = (
223-
self.central_from_local(project.cfg["local_path"]) / "rawdata"
224-
)
225-
all_transferred = path_to_search.glob("**/*")
226-
paths_to_transferred_files = list(
227-
filter(Path.is_file, all_transferred)
228-
)
205+
def sftp_recursive_search(sftp, path_, all_filenames):
206+
try:
207+
sftp.stat(path_)
208+
except FileNotFoundError:
209+
return
210+
211+
for file_or_folder in sftp.listdir_attr(path_):
212+
if stat.S_ISDIR(file_or_folder.st_mode):
213+
sftp_recursive_search(
214+
sftp,
215+
path_ + "/" + file_or_folder.filename,
216+
all_filenames,
217+
)
218+
else:
219+
all_filenames.append(path_ + "/" + file_or_folder.filename)
220+
221+
with paramiko.SSHClient() as client:
222+
ssh.connect_client(client, project.cfg)
223+
224+
sftp = client.open_sftp()
225+
226+
all_filenames = []
227+
228+
sftp_recursive_search(
229+
sftp,
230+
(project.cfg["central_path"] / "rawdata").as_posix(),
231+
all_filenames,
232+
)
229233

230-
assert sorted(paths_to_transferred_files) == sorted(
231-
expected_transferred_paths
232-
)
234+
paths_to_transferred_files = []
235+
for path_ in all_filenames:
236+
parts = Path(path_).parts
237+
paths_to_transferred_files.append(
238+
Path(*parts[parts.index("rawdata") :])
239+
)
240+
241+
expected_transferred_paths_ = []
242+
for path_ in expected_transferred_paths:
243+
parts = Path(path_).parts
244+
expected_transferred_paths_.append(
245+
Path(*parts[parts.index("rawdata") :])
246+
)
247+
248+
assert sorted(paths_to_transferred_files) == sorted(
249+
expected_transferred_paths_
250+
)
251+
252+
project.upload_all()
253+
shutil.rmtree(project.cfg["local_path"] / "rawdata") # TOOD: var
254+
255+
breakpoint()
256+
257+
true_local_path = project.cfg["local_path"]
258+
tmp_local_path = project.cfg["local_path"] / "tmp_local"
259+
tmp_local_path.mkdirs()
260+
project.update_config("local_path", tmp_local_path)
261+
262+
project.download(sub_names, ses_names, datatype, init_log=False) # TODO: why is this connecting so many times?
263+
264+
all_transferred = list((project.cfg["local_path"] / "rawdata").glob("**/*"))
265+
all_transferred = [path_ for path_ in all_transferred if path_.is_file()]
266+
267+
paths_to_transferred_files = []
268+
for path_ in all_transferred: # TODO: rename all filenames
269+
parts = Path(path_).parts
270+
paths_to_transferred_files.append(
271+
Path(*parts[parts.index("rawdata"):])
272+
)
273+
274+
assert sorted(paths_to_transferred_files) == sorted(expected_transferred_paths_)
275+
276+
shutil.rmtree(project.cfg["local_path"]) # TOOD: var
277+
278+
project.update_config("local_path", true_local_path)
279+
280+
with paramiko.SSHClient() as client:
281+
ssh.connect_client(client, project.cfg)
233282

234-
# Teardown here, because we have session scope.
235-
try:
236-
shutil.rmtree(self.central_from_local(project.cfg["local_path"]))
237-
except FileNotFoundError:
238-
pass
283+
client.exec_command(f"rm -rf {(project.cfg['central_path'] / 'rawdata').as_posix()}") # TODO: own function as need to do on teardown)
239284

240285
# ---------------------------------------------------------------------------------------------------------------
241286
# Utils

tests/tests_integration/test_ssh_setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
import pytest
77
import ssh_test_utils
88
import test_utils
9-
from pytest import ssh_config
109

10+
# from pytest import ssh_config
1111
from datashuttle.utils import ssh
1212

13+
TEST_SSH = False
1314

14-
@pytest.mark.skipif(ssh_config.TEST_SSH is False, reason="TEST_SSH is false")
15+
16+
@pytest.mark.skipif(TEST_SSH is False, reason="TEST_SSH is false")
1517
class TestSSH:
1618
@pytest.fixture(scope="function")
1719
def project(test, tmp_path):

0 commit comments

Comments
 (0)