Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion datashuttle/configs/canonical_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@ def get_canonical_configs() -> dict:
canonical_configs = {
"local_path": Union[str, Path],
"central_path": Optional[Union[str, Path]],
"connection_method": Optional[Literal["ssh", "local_filesystem"]],
"connection_method": Optional[
Literal["ssh", "local_filesystem", "aws", "gdrive"]
],
"central_host_id": Optional[str],
"central_host_username": Optional[str],
"aws_bucket_name": Optional[str],
"aws_region": Optional[str],
"gdrive_folder_id": Optional[str],
}

return canonical_configs
Expand Down Expand Up @@ -128,6 +133,25 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None:
ConfigError,
)

# Check AWS S3 settings
if config_dict["connection_method"] == "aws" and (
not config_dict["aws_bucket_name"] or not config_dict["aws_region"]
):
utils.log_and_raise_error(
"'aws_bucket_name' and 'aws_region' are required if 'connection_method' is 'aws'.",
ConfigError,
)

# Check Google Drive settings
if (
config_dict["connection_method"] == "gdrive"
and not config_dict["gdrive_folder_id"]
):
utils.log_and_raise_error(
"'gdrive_folder_id' is required if 'connection_method' is 'gdrive'.",
ConfigError,
)

# Initialise the local project folder
utils.print_message_to_user(
f"Making project folder at: {config_dict['local_path']}"
Expand Down
8 changes: 8 additions & 0 deletions datashuttle/configs/config_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __init__(
self.logging_path: Path
self.hostkeys_path: Path
self.ssh_key_path: Path
self.aws_key_path: Path
self.gdrive_fo_path: Path
self.project_metadata_path: Path

def setup_after_load(self) -> None:
Expand Down Expand Up @@ -236,6 +238,12 @@ def init_paths(self) -> None:

self.ssh_key_path = datashuttle_path / f"{self.project_name}_ssh_key"

self.aws_key_path = datashuttle_path / f"{self.project_name}_aws_key"

self.gdrive_key_path = (
datashuttle_path / f"{self.project_name}_gdrive_key"
)

self.hostkeys_path = datashuttle_path / "hostkeys"

self.logging_path = self.make_and_get_logging_path()
Expand Down
121 changes: 96 additions & 25 deletions datashuttle/datashuttle_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
from datashuttle.configs.config_class import Configs
from datashuttle.datashuttle_functions import _format_top_level_folder
from datashuttle.utils import (
aws,
ds_logger,
folders,
formatting,
gdrive,
getters,
rclone,
ssh,
Expand All @@ -53,6 +55,8 @@
from datashuttle.utils.decorators import ( # noqa
check_configs_set,
check_is_not_local_project,
requires_aws_configs,
requires_gdrive_configs,
requires_ssh_configs,
)

Expand Down Expand Up @@ -903,47 +907,44 @@ def make_config_file(
connection_method: str | None = None,
central_host_id: Optional[str] = None,
central_host_username: Optional[str] = None,
aws_bucket_name: Optional[str] = None,
aws_region: Optional[str] = "us-east-1",
gdrive_folder_id: Optional[str] = None,
) -> None:
"""
Initialise the configurations for datashuttle to use on the
local machine. Once initialised, these settings will be
used each time the datashuttle is opened. This method
can also be used to completely overwrite existing configs.

These settings are stored in a config file on the
datashuttle path (not in the project folder)
on the local machine. Use get_config_path() to
get the full path to the saved config file.

Use update_config_file() to selectively update settings.
used each time datashuttle is opened.

Parameters
----------

local_path :
path to project folder on local machine
Path to project folder on local machine.

central_path :
Filepath to central project.
If this is local (i.e. connection_method = "local_filesystem"),
this is the full path on the local filesystem
Otherwise, if this is via ssh (i.e. connection method = "ssh"),
this is the path to the project folder on central machine.
This should be a full path to central folder i.e. this cannot
include ~ home folder syntax, must contain the full path
(e.g. /nfs/nhome/live/jziminski)
Filepath to central project (local or SSH).

connection_method :
The method used to connect to the central project filesystem,
e.g. "local_filesystem" (e.g. mounted drive) or "ssh"
The method used to connect to the central project filesystem:
- "local_filesystem" (mounted drive)
- "ssh" (remote connection)
- "aws_s3" (Amazon S3 cloud storage)
- "google_drive" (Google Drive cloud storage)

central_host_id :
server address for central host for ssh connection
e.g. "ssh.swc.ucl.ac.uk"
Server address for SSH connection.

central_host_username :
username for which to log in to central host.
e.g. "jziminski"
Username for SSH login.

aws_bucket_name :
Name of the AWS S3 bucket (required for AWS).

aws_region :
AWS region (default: "us-east-1").

google_drive_folder_id :
Folder ID for Google Drive (required for Google Drive).
"""
self._start_log(
"make-config-file",
Expand All @@ -967,6 +968,9 @@ def make_config_file(
"connection_method": connection_method,
"central_host_id": central_host_id,
"central_host_username": central_host_username,
"aws_bucket_name": aws_bucket_name,
"aws_region": aws_region,
"gdrive_folder_id": gdrive_folder_id,
},
)

Expand Down Expand Up @@ -1464,6 +1468,22 @@ def _setup_rclone_central_local_filesystem_config(self) -> None:
self.cfg.get_rclone_config_name("local_filesystem"),
)

def _setup_rclone_central_aws_config(self, log: bool) -> None:
rclone.setup_rclone_config_for_aws(
self.cfg,
self.cfg.get_rclone_config_name("aws"),
self.cfg["aws_region"],
log=log,
)

def _setup_rclone_central_gdrive_config(self, log: bool) -> None:
rclone.setup_rclone_config_for_gdrive(
self.cfg,
self.cfg.get_rclone_config_name("gdrive"),
self.cfg["gdrive_folder_id"],
log=log,
)

# Persistent settings
# -------------------------------------------------------------------------

Expand Down Expand Up @@ -1565,3 +1585,54 @@ def _check_top_level_folder(self, top_level_folder):
f"{canonical_top_level_folders}",
ValueError,
)

# -------------------------------------------------------------------------
# AWS S3 and Google Drive
# -------------------------------------------------------------------------

@requires_aws_configs
@check_is_not_local_project
def setup_aws_connection(self) -> None:
"""
Setup a connection to AWS S3 using RClone.

Assumes the aws_bucket_name and aws_region are set in configs.
First, prompt the user to verify the AWS bucket as trusted,
then create the RClone config for AWS.
"""
self._start_log("setup-aws-connection", local_vars=locals())

verified = aws.verify_aws_remote(
self.cfg["aws_bucket_name"],
self.cfg["aws_region"],
self.cfg.aws_key_path,
log=True,
)

if verified:
self._setup_rclone_central_aws_config(log=True)

ds_logger.close_log_filehandler()

@requires_gdrive_configs
@check_is_not_local_project
def setup_gdrive_connection(self) -> None:
"""
Setup a connection to Google Drive using RClone.

Assumes the gdrive_folder_id is set in configs.
First, prompt the user to verify and trust the folder ID,
then create the RClone config for Google Drive.
"""
self._start_log("setup-gdrive-connection", local_vars=locals())

verified = gdrive.verify_gdrive_remote(
self.cfg["gdrive_folder_id"],
self.cfg.gdrive_key_path,
log=True,
)

if verified:
self._setup_rclone_central_gdrive_config(log=True)

ds_logger.close_log_filehandler()
Loading