From 085397d12e24d5af3b25bd62541e7fbc4bfa1c82 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Mon, 22 Apr 2024 14:46:05 +0100
Subject: [PATCH] Tear down image at end of SSH tests, factor out SSH tests.

---
 tests/ssh_test_utils.py                   |  45 ++++
 tests/tests_integration/base.py           | 107 +-------
 tests/tests_integration/base_transfer.py  | 166 ++++++++++++
 tests/tests_integration/test_ssh.py       | 246 ++++++++++++++++++
 .../test_ssh_file_transfer.py             | 243 +----------------
 tests/tests_integration/test_ssh_setup.py |  97 -------
 6 files changed, 460 insertions(+), 444 deletions(-)
 create mode 100644 tests/tests_integration/base_transfer.py
 create mode 100644 tests/tests_integration/test_ssh.py
 delete mode 100644 tests/tests_integration/test_ssh_setup.py

diff --git a/tests/ssh_test_utils.py b/tests/ssh_test_utils.py
index 7391d853..e5555559 100644
--- a/tests/ssh_test_utils.py
+++ b/tests/ssh_test_utils.py
@@ -4,15 +4,21 @@
 import builtins
 import copy
+import os
+import platform
 import stat
 import subprocess
 import sys
 import warnings
+from pathlib import Path

 import paramiko

 from datashuttle.utils import rclone, ssh

+PORT = 3306  # https://github.com/orgs/community/discussions/25550
+os.environ["DS_SSH_PORT"] = str(PORT)
+

 def setup_project_for_ssh(
     project, central_path, central_host_id, central_host_username
 ):
@@ -89,6 +95,45 @@ def setup_ssh_connection(project, setup_ssh_key_pair=True):
     return verified


+def setup_ssh_container(container_name):
+    """Build the test SSH server image and run it in a named container."""
+    assert docker_is_running(), (
+        "docker is not running, "
+        "this should be checked at the top of test script"
+    )
+
+    image_path = Path(__file__).parent / "ssh_test_images"
+    os.chdir(image_path)
+
+    if platform.system() == "Linux":
+        build_command = "sudo docker build -t ssh_server ."
+        run_command = f"sudo docker run -d -p {PORT}:22 --name {container_name} ssh_server"
+    else:
+        build_command = "docker build -t ssh_server ."
+        run_command = (
+            f"docker run -d -p {PORT}:22 --name {container_name} ssh_server"
+        )
+
+    build_output = subprocess.run(
+        build_command,
+        shell=True,
+        capture_output=True,
+    )
+    assert (
+        build_output.returncode == 0
+    ), f"docker build failed with: STDOUT-{build_output.stdout} STDERR-{build_output.stderr}"
+
+    run_output = subprocess.run(
+        run_command,
+        shell=True,
+        capture_output=True,
+    )
+
+    assert (
+        run_output.returncode == 0
+    ), f"docker run failed with: STDOUT-{run_output.stdout} STDERR-{run_output.stderr}"
+
+
 def sftp_recursive_file_search(sftp, path_, all_filenames):
     try:
         sftp.stat(path_)
diff --git a/tests/tests_integration/base.py b/tests/tests_integration/base.py
index 479a1e5d..6364c8e2 100644
--- a/tests/tests_integration/base.py
+++ b/tests/tests_integration/base.py
@@ -1,13 +1,8 @@
 import os
-import platform
-import subprocess
 import warnings
-from pathlib import Path

 import pytest
-import ssh_test_utils
 import test_utils
-from file_conflicts_pathtable import get_pathtable

 from datashuttle.datashuttle import DataShuttle

@@ -15,6 +10,7 @@


 class BaseTest:
+
     @pytest.fixture(scope="function")
     def no_cfg_project(test):
         """
@@ -61,104 +57,3 @@ def clean_project_name(self):
         test_utils.delete_project_if_it_exists(project_name)
         yield project_name
         test_utils.delete_project_if_it_exists(project_name)
-
-    @pytest.fixture(
-        scope="class",
-    )
-    def pathtable_and_project(self, tmpdir_factory):
-        """
-        Create a new test project with a test project folder
-        and file structure (see `get_pathtable()` for definition).
-        """
-        tmp_path = tmpdir_factory.mktemp("test")
-
-        base_path = tmp_path / "test with space"
-        test_project_name = "test_file_conflicts"
-
-        project, cwd = test_utils.setup_project_fixture(
-            base_path, test_project_name
-        )
-
-        pathtable = get_pathtable(project.cfg["local_path"])
-
-        self.create_all_pathtable_files(pathtable)
-
-        yield [pathtable, project]
-
-        test_utils.teardown_project(cwd, project)
-
-    @pytest.fixture(
-        scope="session",
-    )
-    def setup_ssh_container(self):
-        """"""
-        PORT = 3306  # https://github.com/orgs/community/discussions/25550
-        os.environ["DS_SSH_PORT"] = str(PORT)
-
-        assert ssh_test_utils.docker_is_running(), (
-            "docker is not running, "
-            "this should be checked at the top of test script"
-        )
-
-        image_path = Path(__file__).parent.parent / "ssh_test_images"
-        os.chdir(image_path)
-
-        if platform.system() == "Linux":
-            build_command = "sudo docker build -t ssh_server ."
-            run_command = f"sudo docker run -d -p {PORT}:22 ssh_server"
-        else:
-            build_command = "docker build ."
-            run_command = f"docker run -d -p {PORT}:22 ssh_server"
-
-        build_output = subprocess.run(
-            build_command,
-            shell=True,
-            capture_output=True,
-        )
-        assert build_output.returncode == 0, (
-            f"docker build failed with: STDOUT-{build_output.stdout} STDERR-"
-            f"{build_output.stderr}"
-        )
-
-        run_output = subprocess.run(
-            run_command,
-            shell=True,
-            capture_output=True,
-        )
-
-        assert run_output.returncode == 0, (
-            f"docker run failed with: STDOUT-{run_output.stdout} STDE"
-            f"RR-{run_output.stderr}"
-        )
-
-        # setup_project_for_ssh(
-        #     project,
-        #     central_path=f"/home/sshuser/datashuttle/{project.project_name}",
-        #     central_host_id="localhost",
-        #     central_host_username="sshuser",
-        # )
-
-    @pytest.fixture(
-        scope="class",
-    )
-    def ssh_setup(self, pathtable_and_project, setup_ssh_container):
-        """
-        After initial project setup (in `pathtable_and_project`)
-        setup a container and the project's SSH connection to the container.
-        Then upload the test project to the `central_path`.
-        """
-        pathtable, project = pathtable_and_project
-
-        ssh_test_utils.setup_project_for_ssh(
-            project,
-            central_path=f"/home/sshuser/datashuttle/{project.project_name}",
-            central_host_id="localhost",
-            central_host_username="sshuser",
-        )
-
-        # ssh_test_utils.setup_project_and_container_for_ssh(project)
-        ssh_test_utils.setup_ssh_connection(project)
-
-        project.upload_rawdata()
-
-        return [pathtable, project]
diff --git a/tests/tests_integration/base_transfer.py b/tests/tests_integration/base_transfer.py
new file mode 100644
index 00000000..3af49def
--- /dev/null
+++ b/tests/tests_integration/base_transfer.py
@@ -0,0 +1,166 @@
+"""Fixtures and helpers shared by the filesystem and SSH transfer tests.
+"""
+
+import copy
+from pathlib import Path
+
+import pandas as pd
+import pytest
+import test_utils
+from base import BaseTest
+from file_conflicts_pathtable import get_pathtable
+
+
+class BaseTransfer(BaseTest):
+
+    # ----------------------------------------------------------------------------------
+    # Test File Transfer - All Options
+    # ----------------------------------------------------------------------------------
+
+    @pytest.fixture(
+        scope="class",
+    )
+    def pathtable_and_project(self, tmpdir_factory):
+        """
+        Create a new test project with a test project folder
+        and file structure (see `get_pathtable()` for definition).
+        """
+        tmp_path = tmpdir_factory.mktemp("test")
+
+        base_path = tmp_path / "test with space"
+        test_project_name = "test_file_conflicts"
+
+        project, cwd = test_utils.setup_project_fixture(
+            base_path, test_project_name
+        )
+
+        pathtable = get_pathtable(project.cfg["local_path"])
+
+        self.create_all_pathtable_files(pathtable)
+
+        yield [pathtable, project]
+
+        test_utils.teardown_project(cwd, project)
+
+    def get_expected_transferred_paths(
+        self, pathtable, sub_names, ses_names, datatype
+    ):
+        """
+        Process the expected files that are transferred using the logic in
+        `make_pathtable_search_filter()` to filter the pathtable.
+        """
+        parsed_sub_names = self.parse_arguments(pathtable, sub_names, "sub")
+        parsed_ses_names = self.parse_arguments(pathtable, ses_names, "ses")
+        parsed_datatype = self.parse_arguments(pathtable, datatype, "datatype")
+
+        # Filter pathtable to get files that were expected to be transferred
+        (
+            sub_ses_dtype_arguments,
+            extra_arguments,
+        ) = self.make_pathtable_search_filter(
+            parsed_sub_names, parsed_ses_names, parsed_datatype
+        )
+
+        datatype_folders = self.query_table(pathtable, sub_ses_dtype_arguments)
+        extra_folders = self.query_table(pathtable, extra_arguments)
+
+        expected_paths = pd.concat([datatype_folders, extra_folders])
+        expected_paths = expected_paths.drop_duplicates(subset="path")
+
+        expected_paths = self.remove_path_before_rawdata(expected_paths.path)
+
+        return expected_paths
+
+    def make_pathtable_search_filter(self, sub_names, ses_names, datatype):
+        """
+        Create a string of arguments to pass to pd.query() that will
+        create the table of only transferred sub, ses and datatype.
+
+        Two arguments must be created, one of all sub / ses / datatypes
+        and the other of all non sub/ non ses / non datatype
+        folders. These must be handled separately as they are
+        mutually exclusive.
+        """
+        sub_ses_dtype_arguments = []
+        extra_arguments = []
+
+        for sub in sub_names:
+            if sub == "all_non_sub":
+                extra_arguments += ["is_non_sub == True"]
+            else:
+                for ses in ses_names:
+                    if ses == "all_non_ses":
+                        extra_arguments += [
+                            f"(parent_sub == '{sub}' & is_non_ses == True)"
+                        ]
+                    else:
+                        for dtype in datatype:
+                            if dtype == "all_non_datatype":
+                                extra_arguments += [
+                                    f"(parent_sub == '{sub}' & parent_ses == '{ses}' "
+                                    f"& is_ses_level_non_datatype == True)"
+                                ]
+                            else:
+                                sub_ses_dtype_arguments += [
+                                    f"(parent_sub == '{sub}' & parent_ses == '{ses}' "
+                                    f"& (parent_datatype == '{dtype}' "
+                                    f"| parent_datatype == '{dtype}'))"
+                                ]
+
+        return sub_ses_dtype_arguments, extra_arguments
+
+    def remove_path_before_rawdata(self, list_of_paths):
+        """
+        Remove the path to project files before the "rawdata" so
+        they can be compared no matter where the project was stored
+        (e.g. on a central server vs. local filesystem).
+        """
+        cut_paths = []
+        for path_ in list_of_paths:
+            parts = Path(path_).parts
+            cut_paths.append(Path(*parts[parts.index("rawdata") :]))
+        return cut_paths
+
+    def query_table(self, pathtable, arguments):
+        """
+        Search the table for arguments, return empty
+        if arguments empty.
+        """
+        if any(arguments):
+            folders = pathtable.query(" | ".join(arguments))
+        else:
+            folders = pd.DataFrame()
+        return folders
+
+    def parse_arguments(self, pathtable, list_of_names, field):
+        """
+        Replicate datashuttle name formatting by parsing
+        "all" arguments and turning them into a list of all names
+        (subject or session) taken from the pathtable.
+        """
+        if list_of_names in [["all"], [f"all_{field}"]]:
+            entries = pathtable.query(f"parent_{field} != False")[
+                f"parent_{field}"
+            ]
+            entries = list(set(entries))
+            if list_of_names == ["all"]:
+                entries += (
+                    [f"all_non_{field}"]
+                    if field != "datatype"
+                    else ["all_non_datatype"]
+                )
+            list_of_names = entries
+        return list_of_names
+
+    def create_all_pathtable_files(self, pathtable):
+        """
+        Create the entire test project in the defined
+        location (usually project's `local_path`).
+        """
+        for i in range(pathtable.shape[0]):
+            filepath = pathtable["base_folder"][i] / pathtable["path"][i]
+            filepath.parents[0].mkdir(parents=True, exist_ok=True)
+            test_utils.write_file(filepath, contents="test_entry")
+
+    def central_from_local(self, path_):
+        return Path(str(copy.copy(path_)).replace("local", "central"))
diff --git a/tests/tests_integration/test_ssh.py b/tests/tests_integration/test_ssh.py
new file mode 100644
index 00000000..d79faf72
--- /dev/null
+++ b/tests/tests_integration/test_ssh.py
@@ -0,0 +1,246 @@
+import shutil
+import subprocess
+
+import paramiko
+import pytest
+import ssh_test_utils
+import test_utils
+from base_transfer import BaseTransfer
+
+# from pytest import ssh_config
+from datashuttle.utils import ssh
+
+TEST_SSH = ssh_test_utils.get_test_ssh()
+
+
+@pytest.mark.skipif("not TEST_SSH", reason="TEST_SSH is false")
+class TestSSH(BaseTransfer):
+
+    @pytest.fixture(
+        scope="session",
+    )
+    def setup_ssh_container(self):
+        # Annoyingly, session scope does not seem to actually work
+        container_name = "running_ssh_tests"
+        ssh_test_utils.setup_ssh_container(container_name)
+        yield
+        subprocess.run(f"docker stop {container_name}", shell=True)
+        subprocess.run(f"docker rm {container_name}", shell=True)
+
+    @pytest.fixture(
+        scope="class",
+    )
+    def ssh_setup(self, pathtable_and_project, setup_ssh_container):
+        """
+        After initial project setup (in `pathtable_and_project`),
+        set up a container and the project's SSH connection to the container.
+        Then upload the test project to the `central_path`.
+        """
+        pathtable, project = pathtable_and_project
+
+        ssh_test_utils.setup_project_for_ssh(
+            project,
+            central_path=f"/home/sshuser/datashuttle/{project.project_name}",
+            central_host_id="localhost",
+            central_host_username="sshuser",
+        )
+
+        ssh_test_utils.setup_ssh_connection(project)
+
+        project.upload_rawdata()
+
+        return [pathtable, project]
+
+    @pytest.fixture(scope="function")
+    def project(self, tmp_path, setup_ssh_container):
+        """
+        Make a project as usual, but now add
+        in test SSH configurations.
+        """
+        tmp_path = tmp_path / "test with space"
+
+        test_project_name = "test_ssh"
+        project, cwd = test_utils.setup_project_fixture(
+            tmp_path, test_project_name
+        )
+        ssh_test_utils.setup_project_for_ssh(
+            project,
+            central_path=f"/home/sshuser/datashuttle/{project.project_name}",
+            central_host_id="localhost",
+            central_host_username="sshuser",
+        )
+        yield project
+        test_utils.teardown_project(cwd, project)
+
+    # -----------------------------------------------------------------
+    # Test Setup SSH Connection
+    # -----------------------------------------------------------------
+
+    @pytest.mark.parametrize("input_", ["n", "o", "@"])
+    def test_verify_ssh_central_host_do_not_accept(
+        self, capsys, project, input_
+    ):
+        """
+        Use the main function to test this. Test the sub-function
+        when accepting, because the main function will also
+        set up SSH key pairs, which we don't want to do yet.
+
+        This should only accept "y", so try some random strings
+        including "n" and check they all do not make the connection.
+        """
+        orig_builtin = ssh_test_utils.setup_mock_input(input_)
+
+        project.setup_ssh_connection()
+
+        ssh_test_utils.restore_mock_input(orig_builtin)
+
+        captured = capsys.readouterr()
+
+        assert "Host not accepted. No connection made.\n" in captured.out
+
+    def test_verify_ssh_central_host_accept(self, capsys, project):
+        """
+        User is asked to accept the server hostkey. Mock this here
+        and check the hostkey is successfully accepted and written to configs.
+        """
+        test_utils.clear_capsys(capsys)
+
+        verified = ssh_test_utils.setup_ssh_connection(
+            project, setup_ssh_key_pair=False
+        )
+
+        assert verified
+        captured = capsys.readouterr()
+
+        assert captured.out == "Host accepted.\n"
+
+        with open(project.cfg.hostkeys_path, "r") as file:
+            hostkey = file.readlines()[0]
+
+        assert (
+            f"[{project.cfg['central_host_id']}]:3306 ssh-ed25519 " in hostkey
+        )
+
+    def test_generate_and_write_ssh_key(self, project):
+        """
+        Check the SSH key for passwordless connection is written
+        to file.
+        """
+        path_to_save = project.cfg["local_path"] / "test"
+        ssh.generate_and_write_ssh_key(path_to_save)
+
+        with open(path_to_save, "r") as file:
+            first_line = file.readlines()[0]
+
+        assert first_line == "-----BEGIN RSA PRIVATE KEY-----\n"
+
+    # -----------------------------------------------------------------
+    # Test SSH File Transfer
+    # -----------------------------------------------------------------
+
+    @pytest.mark.skipif("not TEST_SSH", reason="TEST_SSH is false")
+    @pytest.mark.parametrize(
+        "sub_names", [["all"], ["all_non_sub", "sub-002"]]
+    )
+    @pytest.mark.parametrize(
+        "ses_names", [["all"], ["ses-002_random-key"], ["all_non_ses"]]
+    )
+    @pytest.mark.parametrize(
+        "datatype", [["all"], ["anat", "all_non_datatype"]]
+    )
+    def test_combinations_ssh_transfer(
+        self,
+        ssh_setup,
+        sub_names,
+        ses_names,
+        datatype,
+    ):
+        """
+        Test a subset of argument combinations while testing over an SSH connection
+        to a container. This is very slow, due to the rclone SSH transfer (which
+        is performed twice in this test, once for upload, once for download), around
+        8 seconds per parameterization.
+
+        In test setup, the entire project is created in the `local_path` and
+        is uploaded to `central_path`. So we only need to set up once per test;
+        upload and download are to temporary folders and these temporary folders
+        are cleaned at the end of each parameterization.
+        """
+        pathtable, project = ssh_setup
+
+        # Upload data from the setup local project to a temporary
+        # central directory.
+        true_central_path = project.cfg["central_path"]
+        tmp_central_path = (
+            project.cfg["central_path"] / "tmp" / project.project_name
+        )
+        project.get_logging_path().mkdir(
+            parents=True, exist_ok=True
+        )  # TODO: why is this necessary
+
+        project.update_config_file(central_path=tmp_central_path)
+
+        project.upload_custom(
+            "rawdata", sub_names, ses_names, datatype, init_log=False
+        )
+
+        expected_transferred_paths = self.get_expected_transferred_paths(
+            pathtable, sub_names, ses_names, datatype
+        )
+
+        # Search the paths that were transferred and tidy them up,
+        # then check against the paths that were expected to be transferred.
+        transferred_files = ssh_test_utils.recursive_search_central(project)
+        paths_to_transferred_files = self.remove_path_before_rawdata(
+            transferred_files
+        )
+
+        assert sorted(paths_to_transferred_files) == sorted(
+            expected_transferred_paths
+        )
+
+        # Now, move data from the central path where the project is
+        # set up, to a temp local folder to test download.
+        true_local_path = project.cfg["local_path"]
+        tmp_local_path = (
+            project.cfg["local_path"] / "tmp" / project.project_name
+        )
+        tmp_local_path.mkdir(exist_ok=True, parents=True)
+
+        project.update_config_file(local_path=tmp_local_path)
+        project.update_config_file(central_path=true_central_path)
+
+        project.download_custom(
+            "rawdata", sub_names, ses_names, datatype, init_log=False
+        )
+
+        # Find the transferred paths, tidy them up
+        # and check expected paths were transferred.
+        all_transferred = list((tmp_local_path / "rawdata").glob("**/*"))
+        all_transferred = [
+            path_ for path_ in all_transferred if path_.is_file()
+        ]
+
+        paths_to_transferred_files = self.remove_path_before_rawdata(
+            all_transferred
+        )
+
+        assert sorted(paths_to_transferred_files) == sorted(
+            expected_transferred_paths
+        )
+
+        # Clean up, removing the temp directories and
+        # resetting the project paths.
+        with paramiko.SSHClient() as client:
+            ssh.connect_client_core(client, project.cfg)
+            client.exec_command(f"rm -rf {(tmp_central_path).as_posix()}")
+
+        shutil.rmtree(tmp_local_path)
+
+        project.get_logging_path().mkdir(
+            parents=True, exist_ok=True
+        )  # TODO: why is this necessary
+        project.update_config_file(local_path=true_local_path)
+        project.get_logging_path().mkdir(
+            parents=True, exist_ok=True
+        )  # TODO: why is this necessary
diff --git a/tests/tests_integration/test_ssh_file_transfer.py b/tests/tests_integration/test_ssh_file_transfer.py
index 4601cfd3..9aab1d39 100644
--- a/tests/tests_integration/test_ssh_file_transfer.py
+++ b/tests/tests_integration/test_ssh_file_transfer.py
@@ -1,18 +1,13 @@
 """
 """

-import copy
 import shutil
 from pathlib import Path

-import pandas as pd
-import paramiko
 import pytest
 import ssh_test_utils
 import test_utils
-from base import BaseTest
-
-from datashuttle.utils import ssh
+from base_transfer import BaseTransfer

 TEST_SSH = ssh_test_utils.get_test_ssh()

@@ -44,7 +39,7 @@
 ]


-class TestFileTransfer(BaseTest):
+class TestFileTransfer(BaseTransfer):

     # ----------------------------------------------------------------------------------
     # Test File Transfer - All Options
@@ -118,237 +113,3 @@ def test_combinations_filesystem_transfer(
             shutil.rmtree(self.central_from_local(project.cfg["local_path"]))
         except FileNotFoundError:
             pass
-
-    @pytest.mark.skipif("not TEST_SSH", reason="TEST_SSH is false")
-    @pytest.mark.parametrize(
-        "sub_names", [["all"], ["all_non_sub", "sub-002"]]
-    )
-    @pytest.mark.parametrize(
-        "ses_names", [["all"], ["ses-002_random-key"], ["all_non_ses"]]
-    )
-    @pytest.mark.parametrize(
-        "datatype", [["all"], ["anat", "all_non_datatype"]]
-    )
-    def test_combinations_ssh_transfer(
-        self,
-        ssh_setup,
-        sub_names,
-        ses_names,
-        datatype,
-    ):
-        """
-        Test a subset of argument combinations while testing over SSH connection
-        to a container. This is very slow, due to the rclone ssh transfer (which
-        is performed twice in this test, once for upload, once for download), around
-        8 seconds per parameterization.
-
-        In test setup, the entire project is created in the `local_path` and
-        is uploaded to `central_path`. So we only need to set up once per test,
-        upload and download is to temporary folders and these temporary folders
-        are cleaned at the end of each parameterization.
-        """
-        pathtable, project = ssh_setup
-
-        # Upload data from the setup local project to a temporary
-        # central directory.
-        true_central_path = project.cfg["central_path"]
-        tmp_central_path = (
-            project.cfg["central_path"] / "tmp" / project.project_name
-        )
-        project.get_logging_path().mkdir(
-            parents=True, exist_ok=True
-        )  # TODO: why is this necessary
-
-        project.update_config_file(central_path=tmp_central_path)
-
-        project.upload_custom(
-            "rawdata", sub_names, ses_names, datatype, init_log=False
-        )
-
-        expected_transferred_paths = self.get_expected_transferred_paths(
-            pathtable, sub_names, ses_names, datatype
-        )
-
-        # Search the paths that were transferred and tidy them up,
-        # then check against the paths that were expected to be transferred.
-        transferred_files = ssh_test_utils.recursive_search_central(project)
-        paths_to_transferred_files = self.remove_path_before_rawdata(
-            transferred_files
-        )
-
-        assert sorted(paths_to_transferred_files) == sorted(
-            expected_transferred_paths
-        )
-
-        # Now, move data from the central path where the project is
-        # setup, to a temp local folder to test download.
-        true_local_path = project.cfg["local_path"]
-        tmp_local_path = (
-            project.cfg["local_path"] / "tmp" / project.project_name
-        )
-        tmp_local_path.mkdir(exist_ok=True, parents=True)
-
-        project.update_config_file(local_path=tmp_local_path)
-        project.update_config_file(central_path=true_central_path)
-
-        project.download_custom(
-            "rawdata", sub_names, ses_names, datatype, init_log=False
-        )
-
-        # Find the transferred paths, tidy them up
-        # and check expected paths were transferred.
-        all_transferred = list((tmp_local_path / "rawdata").glob("**/*"))
-        all_transferred = [
-            path_ for path_ in all_transferred if path_.is_file()
-        ]
-
-        paths_to_transferred_files = self.remove_path_before_rawdata(
-            all_transferred
-        )
-
-        assert sorted(paths_to_transferred_files) == sorted(
-            expected_transferred_paths
-        )
-
-        # Clean up, removing the temp directories and
-        # resetting the project paths.
-        with paramiko.SSHClient() as client:
-            ssh.connect_client_core(client, project.cfg)
-            client.exec_command(f"rm -rf {(tmp_central_path).as_posix()}")
-
-        shutil.rmtree(tmp_local_path)
-
-        project.get_logging_path().mkdir(
-            parents=True, exist_ok=True
-        )  # TODO: why is this necessary
-        project.update_config_file(local_path=true_local_path)
-        project.get_logging_path().mkdir(
-            parents=True, exist_ok=True
-        )  # TODO: why is this necessary
-
-    # ----------------------------------------------------------------------------------
-    # Utils
-    # ----------------------------------------------------------------------------------
-
-    def get_expected_transferred_paths(
-        self, pathtable, sub_names, ses_names, datatype
-    ):
-        """
-        Process the expected files that are transferred using the logic in
-        `make_pathtable_search_filter()` to
-        """
-        parsed_sub_names = self.parse_arguments(pathtable, sub_names, "sub")
-        parsed_ses_names = self.parse_arguments(pathtable, ses_names, "ses")
-        parsed_datatype = self.parse_arguments(pathtable, datatype, "datatype")
-
-        # Filter pathtable to get files that were expected to be transferred
-        (
-            sub_ses_dtype_arguments,
-            extra_arguments,
-        ) = self.make_pathtable_search_filter(
-            parsed_sub_names, parsed_ses_names, parsed_datatype
-        )
-
-        datatype_folders = self.query_table(pathtable, sub_ses_dtype_arguments)
-        extra_folders = self.query_table(pathtable, extra_arguments)
-
-        expected_paths = pd.concat([datatype_folders, extra_folders])
-        expected_paths = expected_paths.drop_duplicates(subset="path")
-
-        expected_paths = self.remove_path_before_rawdata(expected_paths.path)
-
-        return expected_paths
-
-    def make_pathtable_search_filter(self, sub_names, ses_names, datatype):
-        """
-        Create a string of arguments to pass to pd.query() that will
-        create the table of only transferred sub, ses and datatype.
-
-        Two arguments must be created, one of all sub / ses / datatypes
-        and the other of all non sub/ non ses / non datatype
-        folders. These must be handled separately as they are
-        mutually exclusive.
-        """
-        sub_ses_dtype_arguments = []
-        extra_arguments = []
-
-        for sub in sub_names:
-            if sub == "all_non_sub":
-                extra_arguments += ["is_non_sub == True"]
-            else:
-                for ses in ses_names:
-                    if ses == "all_non_ses":
-                        extra_arguments += [
-                            f"(parent_sub == '{sub}' & is_non_ses == True)"
-                        ]
-                    else:
-                        for dtype in datatype:
-                            if dtype == "all_non_datatype":
-                                extra_arguments += [
-                                    f"(parent_sub == '{sub}' & parent_ses == '{ses}' "
-                                    f"& is_ses_level_non_datatype == True)"
-                                ]
-                            else:
-                                sub_ses_dtype_arguments += [
-                                    f"(parent_sub == '{sub}' & parent_ses == '{ses}' "
-                                    f"& (parent_datatype == '{dtype}' "
-                                    f"| parent_datatype == '{dtype}'))"
-                                ]
-
-        return sub_ses_dtype_arguments, extra_arguments
-
-    def remove_path_before_rawdata(self, list_of_paths):
-        """
-        Remove the path to project files before the "rawdata" so
-        they can be compared no matter where the project was stored
-        (e.g. on a central server vs. local filesystem).
-        """
-        cut_paths = []
-        for path_ in list_of_paths:
-            parts = Path(path_).parts
-            cut_paths.append(Path(*parts[parts.index("rawdata") :]))
-        return cut_paths
-
-    def query_table(self, pathtable, arguments):
-        """
-        Search the table for arguments, return empty
-        if arguments empty
-        """
-        if any(arguments):
-            folders = pathtable.query(" | ".join(arguments))
-        else:
-            folders = pd.DataFrame()
-        return folders
-
-    def parse_arguments(self, pathtable, list_of_names, field):
-        """
-        Replicate datashuttle name formatting by parsing
-        "all" arguments and turning them into a list of all names,
-        (subject or session), taken from the pathtable.
-        """
-        if list_of_names in [["all"], [f"all_{field}"]]:
-            entries = pathtable.query(f"parent_{field} != False")[
-                f"parent_{field}"
-            ]
-            entries = list(set(entries))
-            if list_of_names == ["all"]:
-                entries += (
-                    [f"all_non_{field}"]
-                    if field != "datatype"
-                    else ["all_non_datatype"]
-                )
-            list_of_names = entries
-        return list_of_names
-
-    def create_all_pathtable_files(self, pathtable):
-        """
-        Create the entire test project in the defined
-        location (usually project's `local_path`).
-        """
-        for i in range(pathtable.shape[0]):
-            filepath = pathtable["base_folder"][i] / pathtable["path"][i]
-            filepath.parents[0].mkdir(parents=True, exist_ok=True)
-            test_utils.write_file(filepath, contents="test_entry")
-
-    def central_from_local(self, path_):
-        return Path(str(copy.copy(path_)).replace("local", "central"))
diff --git a/tests/tests_integration/test_ssh_setup.py b/tests/tests_integration/test_ssh_setup.py
deleted file mode 100644
index e2f032bd..00000000
--- a/tests/tests_integration/test_ssh_setup.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import pytest
-import ssh_test_utils
-import test_utils
-from base import BaseTest
-
-# from pytest import ssh_config
-from datashuttle.utils import ssh
-
-TEST_SSH = ssh_test_utils.get_test_ssh()
-
-
-@pytest.mark.skipif("not TEST_SSH", reason="TEST_SSH is false")
-class TestSSH(BaseTest):
-    @pytest.fixture(scope="function")
-    def project(test, tmp_path, setup_ssh_container):
-        """
-        Make a project as per usual, but now add
-        in test ssh configurations
-        """
-        tmp_path = tmp_path / "test with space"
-
-        test_project_name = "test_ssh"
-        project, cwd = test_utils.setup_project_fixture(
-            tmp_path, test_project_name
-        )
-
-        # ssh_test_utils.setup_project_and_container_for_ssh(project)
-
-        ssh_test_utils.setup_project_for_ssh(
-            project,
-            central_path=f"/home/sshuser/datashuttle/{project.project_name}",
-            central_host_id="localhost",
-            central_host_username="sshuser",
-        )
-
-        yield project
-        test_utils.teardown_project(cwd, project)
-
-    # -----------------------------------------------------------------
-    # Test Setup SSH Connection
-    # -----------------------------------------------------------------
-
-    @pytest.mark.parametrize("input_", ["n", "o", "@"])
-    def test_verify_ssh_central_host_do_not_accept(
-        self, capsys, project, input_
-    ):
-        """
-        Use the main function to test this. Test the sub-function
-        when accepting, because this main function will also
-        call setup ssh key pairs which we don't want to do yet
-
-        This should only accept for "y" so try some random strings
-        including "n" and check they all do not make the connection.
-        """
-        orig_builtin = ssh_test_utils.setup_mock_input(input_)
-
-        project.setup_ssh_connection()
-
-        ssh_test_utils.restore_mock_input(orig_builtin)
-
-        captured = capsys.readouterr()
-
-        assert "Host not accepted. No connection made.\n" in captured.out
-
-    def test_verify_ssh_central_host_accept(self, capsys, project):
-        """
-        User is asked to accept the server hostkey. Mock this here
-        and check hostkey is successfully accepted and written to configs.
-        """
-        test_utils.clear_capsys(capsys)
-
-        verified = ssh_test_utils.setup_ssh_connection(
-            project, setup_ssh_key_pair=False
-        )
-
-        assert verified
-        captured = capsys.readouterr()
-
-        assert captured.out == "Host accepted.\n"
-
-        with open(project.cfg.hostkeys_path, "r") as file:
-            hostkey = file.readlines()[0]
-
-        assert f"{project.cfg['central_host_id']} ssh-ed25519 " in hostkey
-
-    def test_generate_and_write_ssh_key(self, project):
-        """
-        Check ssh key for passwordless connection is written
-        to file
-        """
-        path_to_save = project.cfg["local_path"] / "test"
-        ssh.generate_and_write_ssh_key(path_to_save)
-
-        with open(path_to_save, "r") as file:
-            first_line = file.readlines()[0]
-
-        assert first_line == "-----BEGIN RSA PRIVATE KEY-----\n"