diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index b65ad6c66..617cd3e77 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -39,9 +39,14 @@ def get_canonical_configs() -> dict: canonical_configs = { "local_path": Union[str, Path], "central_path": Optional[Union[str, Path]], - "connection_method": Optional[Literal["ssh", "local_filesystem"]], + "connection_method": Optional[ + Literal["ssh", "local_filesystem", "aws", "gdrive"] + ], "central_host_id": Optional[str], "central_host_username": Optional[str], + "aws_bucket_name": Optional[str], + "aws_region": Optional[str], + "gdrive_folder_id": Optional[str], } return canonical_configs @@ -128,6 +133,25 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ConfigError, ) + # Check AWS S3 settings + if config_dict["connection_method"] == "aws" and ( + not config_dict["aws_bucket_name"] or not config_dict["aws_region"] + ): + utils.log_and_raise_error( + "'aws_bucket_name' and 'aws_region' are required if 'connection_method' is 'aws'.", + ConfigError, + ) + + # Check Google Drive settings + if ( + config_dict["connection_method"] == "gdrive" + and not config_dict["gdrive_folder_id"] + ): + utils.log_and_raise_error( + "'gdrive_folder_id' is required if 'connection_method' is 'gdrive'.", + ConfigError, + ) + # Initialise the local project folder utils.print_message_to_user( f"Making project folder at: {config_dict['local_path']}" diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 5fd70fcad..9c082f19e 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -59,6 +59,8 @@ def __init__( self.logging_path: Path self.hostkeys_path: Path self.ssh_key_path: Path + self.aws_key_path: Path + self.gdrive_fo_path: Path self.project_metadata_path: Path def setup_after_load(self) -> None: @@ -236,6 +238,12 @@ def init_paths(self) -> None: self.ssh_key_path = datashuttle_path / f"{self.project_name}_ssh_key" + self.aws_key_path = datashuttle_path / f"{self.project_name}_aws_key" + + self.gdrive_key_path = ( + datashuttle_path / f"{self.project_name}_gdrive_key" + ) + self.hostkeys_path = datashuttle_path / "hostkeys" self.logging_path = self.make_and_get_logging_path() diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 7535e2664..6c2586ce3 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -36,9 +36,11 @@ from datashuttle.configs.config_class import Configs from datashuttle.datashuttle_functions import _format_top_level_folder from datashuttle.utils import ( + aws, ds_logger, folders, formatting, + gdrive, getters, rclone, ssh, @@ -53,6 +55,8 @@ from datashuttle.utils.decorators import ( # noqa check_configs_set, check_is_not_local_project, + requires_aws_configs, + requires_gdrive_configs, requires_ssh_configs, ) @@ -903,47 +907,44 @@ def make_config_file( connection_method: str | None = None, central_host_id: Optional[str] = None, central_host_username: Optional[str] = None, + aws_bucket_name: Optional[str] = None, + aws_region: Optional[str] = "us-east-1", + gdrive_folder_id: Optional[str] = None, ) -> None: """ Initialise the configurations for datashuttle to use on the local machine. Once initialised, these settings will be - used each time the datashuttle is opened. This method - can also be used to completely overwrite existing configs. - - These settings are stored in a config file on the - datashuttle path (not in the project folder) - on the local machine. Use get_config_path() to - get the full path to the saved config file. - - Use update_config_file() to selectively update settings. + used each time datashuttle is opened. Parameters ---------- - local_path : - path to project folder on local machine + Path to project folder on local machine. central_path : - Filepath to central project. - If this is local (i.e. connection_method = "local_filesystem"), - this is the full path on the local filesystem - Otherwise, if this is via ssh (i.e. connection method = "ssh"), - this is the path to the project folder on central machine. - This should be a full path to central folder i.e. this cannot - include ~ home folder syntax, must contain the full path - (e.g. /nfs/nhome/live/jziminski) + Filepath to central project (local or SSH). connection_method : - The method used to connect to the central project filesystem, - e.g. "local_filesystem" (e.g. mounted drive) or "ssh" + The method used to connect to the central project filesystem: + - "local_filesystem" (mounted drive) + - "ssh" (remote connection) + - "aws_s3" (Amazon S3 cloud storage) + - "google_drive" (Google Drive cloud storage) central_host_id : - server address for central host for ssh connection - e.g. "ssh.swc.ucl.ac.uk" + Server address for SSH connection. central_host_username : - username for which to log in to central host. - e.g. "jziminski" + Username for SSH login. + + aws_bucket_name : + Name of the AWS S3 bucket (required for AWS). + + aws_region : + AWS region (default: "us-east-1"). + + google_drive_folder_id : + Folder ID for Google Drive (required for Google Drive). """ self._start_log( "make-config-file", @@ -967,6 +968,9 @@ def make_config_file( "connection_method": connection_method, "central_host_id": central_host_id, "central_host_username": central_host_username, + "aws_bucket_name": aws_bucket_name, + "aws_region": aws_region, + "gdrive_folder_id": gdrive_folder_id, }, ) @@ -1464,6 +1468,22 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: self.cfg.get_rclone_config_name("local_filesystem"), ) + def _setup_rclone_central_aws_config(self, log: bool) -> None: + rclone.setup_rclone_config_for_aws( + self.cfg, + self.cfg.get_rclone_config_name("aws"), + self.cfg["aws_region"], + log=log, + ) + + def _setup_rclone_central_gdrive_config(self, log: bool) -> None: + rclone.setup_rclone_config_for_gdrive( + self.cfg, + self.cfg.get_rclone_config_name("gdrive"), + self.cfg["gdrive_folder_id"], + log=log, + ) + # Persistent settings # ------------------------------------------------------------------------- @@ -1565,3 +1585,54 @@ def _check_top_level_folder(self, top_level_folder): f"{canonical_top_level_folders}", ValueError, ) + + # ------------------------------------------------------------------------- + # AWS S3 and Google Drive + # ------------------------------------------------------------------------- + + @requires_aws_configs + @check_is_not_local_project + def setup_aws_connection(self) -> None: + """ + Setup a connection to AWS S3 using RClone. + + Assumes the aws_bucket_name and aws_region are set in configs. + First, prompt the user to verify the AWS bucket as trusted, + then create the RClone config for AWS. + """ + self._start_log("setup-aws-connection", local_vars=locals()) + + verified = aws.verify_aws_remote( + self.cfg["aws_bucket_name"], + self.cfg["aws_region"], + self.cfg.aws_key_path, + log=True, + ) + + if verified: + self._setup_rclone_central_aws_config(log=True) + + ds_logger.close_log_filehandler() + + @requires_gdrive_configs + @check_is_not_local_project + def setup_gdrive_connection(self) -> None: + """ + Setup a connection to Google Drive using RClone. + + Assumes the gdrive_folder_id is set in configs. + First, prompt the user to verify and trust the folder ID, + then create the RClone config for Google Drive. + """ + self._start_log("setup-gdrive-connection", local_vars=locals()) + + verified = gdrive.verify_gdrive_remote( + self.cfg["gdrive_folder_id"], + self.cfg.gdrive_key_path, + log=True, + ) + + if verified: + self._setup_rclone_central_gdrive_config(log=True) + + ds_logger.close_log_filehandler() diff --git a/datashuttle/tui/configs.py b/datashuttle/tui/configs.py index 974ee08af..ce164450e 100644 --- a/datashuttle/tui/configs.py +++ b/datashuttle/tui/configs.py @@ -25,7 +25,12 @@ from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface -from datashuttle.tui.screens import modal_dialogs, setup_ssh +from datashuttle.tui.screens import ( + modal_dialogs, + setup_aws, + setup_gdrive, + setup_ssh, +) from datashuttle.tui.tooltips import get_tooltip @@ -58,13 +63,19 @@ def __init__( self.parent_class = parent_class self.interface = interface self.config_ssh_widgets: List[Any] = [] + self.config_aws_widgets: List[Any] = [] + self.config_gdrive_widgets: List[Any] = [] def compose(self) -> ComposeResult: """ `self.config_ssh_widgets` are SSH-setup related widgets that are only required when the user selects the SSH connection method. These are displayed / hidden based on the - `connection_method` + `connection_method`. + + `self.config_aws_widgets` are AWS-related widgets. + + `self.config_gdrive_widgets` are Google Drive-related widgets. `config_screen_widgets` are core config-related widgets that are always displayed. @@ -90,6 +101,32 @@ def compose(self) -> ComposeResult: ), ] + self.config_aws_widgets = [ + Label("AWS Bucket Name", id="configs_aws_bucket_name_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="e.g. my-bucket-name", + id="configs_aws_bucket_name_input", + ), + Label("AWS Region", id="configs_aws_region_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="e.g. us-east-1", + id="configs_aws_region_input", + ), + ] + + self.config_gdrive_widgets = [ + Label( + "Google Drive Folder ID", id="configs_gdrive_folder_id_label" + ), + ClickableInput( + self.parent_class.mainwindow, + placeholder="e.g. 1A2B3C4D5E6F7G8H", + id="configs_gdrive_folder_id_input", + ), + ] + config_screen_widgets = [ Label("Local Path", id="configs_local_path_label"), Horizontal( @@ -108,6 +145,8 @@ def compose(self) -> ComposeResult: id="configs_local_filesystem_radiobutton", ), RadioButton("SSH", id="configs_ssh_radiobutton"), + RadioButton("AWS S3", id="configs_aws_radiobutton"), + RadioButton("Google Drive", id="configs_gdrive_radiobutton"), RadioButton( "No connection (local only)", id="configs_local_only_radiobutton", @@ -115,6 +154,8 @@ def compose(self) -> ComposeResult: id="configs_connect_method_radioset", ), *self.config_ssh_widgets, + *self.config_aws_widgets, + *self.config_gdrive_widgets, Label("Central Path", id="configs_central_path_label"), Horizontal( ClickableInput( @@ -131,8 +172,14 @@ def compose(self) -> ComposeResult: "Setup SSH Connection", id="configs_setup_ssh_connection_button", ), - # Below button is always hidden when accessing - # configs from project manager screen + Button( + "Setup AWS Connection", + id="configs_setup_aws_connection_button", + ), + Button( + "Setup Google Drive Connection", + id="configs_setup_gdrive_connection_button", + ), Button( "Go to Project Screen", id="configs_go_to_project_screen_button", @@ -170,18 +217,11 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: """ - When we have mounted the widgets, the following logic depends on whether - we are setting up a new project (`self.project is `None`) or have - an instantiated project. - - If we have a project, then we want to fill the widgets with the existing - configs. Otherwise, we set to some reasonable defaults, required to - determine the display of SSH widgets. "overwrite_files_checkbox" - should be off by default anyway if `value` is not set, but we set here - anyway as it is critical this is not on by default. + When widgets are mounted, initialize based on whether this is a new + project or editing an existing one. """ - # Setup display widget defaults self.query_one("#configs_go_to_project_screen_button").visible = False + if self.interface: self.fill_widgets_with_project_configs() else: @@ -189,87 +229,124 @@ def on_mount(self) -> None: True ) self.switch_ssh_widgets_display(display_ssh=False) + self.switch_aws_widgets_display(display_aws=False) + self.switch_gdrive_widgets_display(display_gdrive=False) self.query_one("#configs_setup_ssh_connection_button").visible = ( False ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False # Setup tooltips if not self.interface: id = "#configs_name_input" self.query_one(id).tooltip = get_tooltip(id) - - # Assumes 'local_filesystem' is default if no project set. assert ( self.query_one("#configs_local_filesystem_radiobutton").value is True ) - self.set_central_path_input_tooltip(display_ssh=False) + self.set_central_path_input_tooltip("local_filesystem") else: - display_ssh = ( - self.interface.project.cfg["connection_method"] == "ssh" - ) - self.set_central_path_input_tooltip(display_ssh) + method = self.interface.project.cfg["connection_method"] + self.set_central_path_input_tooltip(method) for id in [ "#configs_local_path_input", "#configs_connect_method_label", "#configs_local_filesystem_radiobutton", "#configs_ssh_radiobutton", + "#configs_aws_radiobutton", + "#configs_gdrive_radiobutton", "#configs_local_only_radiobutton", "#configs_central_host_username_input", "#configs_central_host_id_input", + "#configs_aws_bucket_name_input", + "#configs_aws_region_input", + "#configs_gdrive_folder_id_input", ]: self.query_one(id).tooltip = get_tooltip(id) def on_radio_set_changed(self, event: RadioSet.Changed) -> None: """ - Update the displayed SSH widgets when the `connection_method` - radiobuttons are changed. - - When SSH is set, ssh config-setters are shown. Otherwise, these - are hidden. + Update the displayed widgets and config state when the + `connection_method` radiobuttons are changed. - When mode is `No connection`, the `central_path` is cleared and - disabled. + Supports SSH, AWS S3, Google Drive, Local Filesystem, and + No Connection modes. """ label = str(event.pressed.label) assert label in [ "SSH", + "AWS S3", + "Google Drive", "Local Filesystem", "No connection (local only)", ], "Unexpected label." - if label == "No connection (local only)": - self.query_one("#configs_central_path_input").value = "" - self.query_one("#configs_central_path_input").disabled = True - self.query_one("#configs_central_path_select_button").disabled = ( - True - ) - display_ssh = False - else: - self.query_one("#configs_central_path_input").disabled = False - self.query_one("#configs_central_path_select_button").disabled = ( - False - ) - display_ssh = True if label == "SSH" else False + is_ssh = label == "SSH" + is_aws = label == "AWS S3" + is_gdrive = label == "Google Drive" + is_local = label == "Local Filesystem" + is_none = label == "No connection (local only)" + + central_input = self.query_one("#configs_central_path_input") + select_button = self.query_one("#configs_central_path_select_button") - self.switch_ssh_widgets_display(display_ssh) - self.set_central_path_input_tooltip(display_ssh) + # Disable fields if no connection + if is_none: + central_input.value = "" + central_input.disabled = True + select_button.disabled = True + else: + central_input.disabled = False + select_button.disabled = False + + # Toggle widget groups + if is_ssh: + self.switch_ssh_widgets_display(True) + self.switch_aws_widgets_display(False) + self.switch_gdrive_widgets_display(False) + elif is_aws: + self.switch_ssh_widgets_display(False) + self.switch_aws_widgets_display(True) + self.switch_gdrive_widgets_display(False) + elif is_gdrive: + self.switch_ssh_widgets_display(False) + self.switch_aws_widgets_display(False) + self.switch_gdrive_widgets_display(True) + else: # Local Filesystem or No Connection + self.switch_ssh_widgets_display(False) + self.switch_aws_widgets_display(False) + self.switch_gdrive_widgets_display(False) + + # Tooltip update + if is_ssh: + self.set_central_path_input_tooltip("ssh") + elif is_aws: + self.set_central_path_input_tooltip("aws") + elif is_gdrive: + self.set_central_path_input_tooltip("gdrive") + else: + self.set_central_path_input_tooltip("local_filesystem") - def set_central_path_input_tooltip(self, display_ssh: bool) -> None: + def set_central_path_input_tooltip(self, mode: str) -> None: """ - Use a different tooltip depending on whether connection method - is ssh or local filesystem. + Use a different tooltip depending on the selected connection mode. + `mode` must be one of: 'ssh', 'aws', 'gdrive', 'local_filesystem' """ id = "#configs_central_path_input" - if display_ssh: - self.query_one(id).tooltip = get_tooltip( - "config_central_path_input_mode-ssh" - ) - else: - self.query_one(id).tooltip = get_tooltip( - "config_central_path_input_mode-local_filesystem" - ) + tooltip_id = { + "ssh": "config_central_path_input_mode-ssh", + "aws": "config_central_path_input_mode-aws", + "gdrive": "config_central_path_input_mode-gdrive", + "local_filesystem": "config_central_path_input_mode-local_filesystem", + }.get(mode, "config_central_path_input_mode-local_filesystem") + + self.query_one(id).tooltip = get_tooltip(tooltip_id) def get_platform_dependent_example_paths( self, local_or_central: Literal["local", "central"], ssh: bool = False @@ -327,11 +404,87 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: placeholder ) + def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: + """ + Show or hide Google Drive-related configs based on whether the current + `connection_method` widget is "Google Drive". + + Parameters + ---------- + display_gdrive : bool + If `True`, display the Google Drive-related widgets. + """ + for widget in self.config_gdrive_widgets: + widget.display = display_gdrive + + # Hide local filesystem selector button when GDrive is selected + self.query_one("#configs_central_path_select_button").display = ( + not display_gdrive + ) + + # Show or hide GDrive setup button based on interface and mode + if self.interface is None: + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False + else: + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = display_gdrive + + # Set placeholder if empty + if not self.query_one("#configs_central_path_input").value: + if display_gdrive: + placeholder = "e.g. gdrive://project-folder-id" + else: + placeholder = f"e.g. {self.get_platform_dependent_example_paths('central', ssh=False)}" + self.query_one("#configs_central_path_input").placeholder = ( + placeholder + ) + + def switch_aws_widgets_display(self, display_aws: bool) -> None: + """ + Show or hide AWS S3-related configs based on whether the current + `connection_method` widget is "AWS S3". + + Parameters + ---------- + display_aws : bool + If `True`, display the AWS-related widgets. + """ + for widget in self.config_aws_widgets: + widget.display = display_aws + + # Hide local filesystem selector button when AWS is selected + self.query_one("#configs_central_path_select_button").display = ( + not display_aws + ) + + # Show or hide AWS setup button based on interface and mode + if self.interface is None: + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + else: + self.query_one("#configs_setup_aws_connection_button").visible = ( + display_aws + ) + + # Set placeholder if empty + if not self.query_one("#configs_central_path_input").value: + if display_aws: + placeholder = "e.g. s3://bucket-name/project-path" + else: + placeholder = f"e.g. {self.get_platform_dependent_example_paths('central', ssh=False)}" + self.query_one("#configs_central_path_input").placeholder = ( + placeholder + ) + def on_button_pressed(self, event: Button.Pressed) -> None: """ - Enables the Create Folders button to read out current input values - and use these to call project.create_folders(). + Handle button presses in the configuration screen. """ + if event.button.id == "configs_save_configs_button": if not self.interface: self.setup_configs_for_a_new_project() @@ -341,6 +494,12 @@ def on_button_pressed(self, event: Button.Pressed) -> None: elif event.button.id == "configs_setup_ssh_connection_button": self.setup_ssh_connection() + elif event.button.id == "configs_setup_aws_connection_button": + self.setup_aws_connection() + + elif event.button.id == "configs_setup_gdrive_connection_button": + self.setup_gdrive_connection() + elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -409,6 +568,40 @@ def setup_ssh_connection(self) -> None: setup_ssh.SetupSshScreen(self.interface) ) + def setup_gdrive_connection(self) -> None: + """ + Set up the `SetupGdriveScreen` screen. + """ + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." + ) + return + + self.parent_class.mainwindow.push_screen( + setup_gdrive.SetupGdriveScreen(self.interface) + ) + + def setup_aws_connection(self) -> None: + """ + Set up the `SetupAwsScreen` screen. + """ + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." + ) + return + + self.parent_class.mainwindow.push_screen( + setup_aws.SetupAwsScreen(self.interface) + ) + def widget_configs_match_saved_configs(self): """ Check that the configs currently stored in the widgets @@ -460,8 +653,6 @@ def setup_configs_for_a_new_project(self) -> None: True ) - # Could not find a neater way to combine the push screen - # while initiating the callback in one case but not the other. if cfg_kwargs["connection_method"] == "ssh": self.query_one( @@ -478,6 +669,38 @@ def setup_configs_for_a_new_project(self) -> None: "able to create and transfer project folders." ) + elif cfg_kwargs["connection_method"] == "aws": + + self.query_one( + "#configs_setup_aws_connection_button" + ).visible = True + self.query_one( + "#configs_setup_aws_connection_button" + ).disabled = False + + message = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the AWS connection. Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." + ) + + elif cfg_kwargs["connection_method"] == "gdrive": + + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = True + self.query_one( + "#configs_setup_gdrive_connection_button" + ).disabled = False + + message = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the Google Drive connection. Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." + ) + else: message = ( "A datashuttle project has now been created.\n\n " @@ -503,12 +726,51 @@ def setup_configs_for_an_existing_project(self) -> None: """ assert self.interface is not None, "type narrow flexible `interface`" - # Handle the edge case where connection method is changed after - # saving on the 'Make New Project' screen. - self.query_one("#configs_setup_ssh_connection_button").visible = True - cfg_kwargs = self.get_datashuttle_inputs_from_widgets() + # Show relevant setup button depending on selected method + connection_method = cfg_kwargs.get("connection_method", "") + if connection_method == "ssh": + self.query_one("#configs_setup_ssh_connection_button").visible = ( + True + ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False + elif connection_method == "aws": + self.query_one("#configs_setup_aws_connection_button").visible = ( + True + ) + self.query_one("#configs_setup_ssh_connection_button").visible = ( + False + ) + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False + elif connection_method == "gdrive": + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = True + self.query_one("#configs_setup_ssh_connection_button").visible = ( + False + ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + else: + self.query_one("#configs_setup_ssh_connection_button").visible = ( + False + ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False + success, output = self.interface.set_configs_on_existing_project( cfg_kwargs ) @@ -558,6 +820,10 @@ def fill_widgets_with_project_configs(self) -> None: cfg_to_load["connection_method"] == "local_filesystem", "configs_local_only_radiobutton": cfg_to_load["connection_method"] is None, + "configs_aws_radiobutton": + cfg_to_load["connection_method"] == "aws", + "configs_gdrive_radiobutton": + cfg_to_load["connection_method"] == "gdrive", } # fmt: on @@ -568,6 +834,14 @@ def fill_widgets_with_project_configs(self) -> None: display_ssh=what_radiobuton_is_on["configs_ssh_radiobutton"] ) + self.switch_aws_widgets_display( + display_aws=what_radiobuton_is_on["configs_aws_radiobutton"] + ) + + self.switch_gdrive_widgets_display( + display_gdrive=what_radiobuton_is_on["configs_gdrive_radiobutton"] + ) + # Central Host ID input = self.query_one("#configs_central_host_id_input") value = ( @@ -586,6 +860,33 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value + # AWS Bucket Name + input = self.query_one("#configs_aws_bucket_name_input") + value = ( + "" + if cfg_to_load.get("aws_bucket_name") is None + else cfg_to_load["aws_bucket_name"] + ) + input.value = value + + # AWS Region + input = self.query_one("#configs_aws_region_input") + value = ( + "" + if cfg_to_load.get("aws_region") is None + else cfg_to_load["aws_region"] + ) + input.value = value + + # GDrive Folder ID + input = self.query_one("#configs_gdrive_folder_id_input") + value = ( + "" + if cfg_to_load.get("gdrive_folder_id") is None + else cfg_to_load["gdrive_folder_id"] + ) + input.value = value + def get_datashuttle_inputs_from_widgets(self) -> Dict: """ Get the configs to pass to `make_config_file()` from @@ -611,6 +912,12 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: elif self.query_one("#configs_local_filesystem_radiobutton").value: connection_method = "local_filesystem" + elif self.query_one("#configs_aws_radiobutton").value: + connection_method = "aws" + + elif self.query_one("#configs_gdrive_radiobutton").value: + connection_method = "gdrive" + elif self.query_one("#configs_local_only_radiobutton").value: connection_method = None @@ -626,9 +933,25 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: central_host_username = self.query_one( "#configs_central_host_username_input" ).value - cfg_kwargs["central_host_username"] = ( None if central_host_username == "" else central_host_username ) + aws_bucket_name = self.query_one( + "#configs_aws_bucket_name_input" + ).value + cfg_kwargs["aws_bucket_name"] = ( + None if aws_bucket_name == "" else aws_bucket_name + ) + + aws_region = self.query_one("#configs_aws_region_input").value + cfg_kwargs["aws_region"] = None if aws_region == "" else aws_region + + gdrive_folder_id = self.query_one( + "#configs_gdrive_folder_id_input" + ).value + cfg_kwargs["gdrive_folder_id"] = ( + None if gdrive_folder_id == "" else gdrive_folder_id + ) + return cfg_kwargs diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index c9f35d60c..2f0dc4a6c 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -58,6 +58,12 @@ SetupSshScreen { align: center middle; } +SetupAwsScreen { + align: center middle; +} +SetupGdriveScreen { + align: center middle; +} SettingsScreen { align: center middle; } @@ -77,6 +83,24 @@ GetHelpScreen { #setup_ssh_screen_container { border: thick $panel-darken-3; } +#setup_aws_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} +#setup_aws_screen_container { + border: thick $panel-darken-3; +} +#setup_gdrive_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} +#setup_gdrive_screen_container { + border: thick $panel-darken-3; +} #messagebox_buttons_horizontal { align: center middle; } @@ -163,9 +187,15 @@ MessageBox:light > #messagebox_top_container { color: $success; /* unsure about this */ } +#configs_setup_aws_connection_button { + margin: 2 1 0 0; +} #configs_setup_ssh_connection_button { margin: 2 1 0 0; } +#configs_setup_gdrive_connection_button { + margin: 2 1 0 0; +} #configs_go_to_project_screen_button { margin: 2 1 0 0; } @@ -194,6 +224,15 @@ MessageBox:light > #messagebox_top_container { #configs_central_path_input { width: 70%; } +#configs_aws_region_input { + width: 70%; +} +#configs_aws_bucket_name_input { + width: 70%; +} +#configs_google_drive_folder_id_input { + width: 70%; +} #configs_local_path_select_button{ width: auto; } diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index 1cf397c59..8b3f64856 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -11,7 +11,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import ssh +from datashuttle.utils import aws, gdrive, ssh class Interface: @@ -183,30 +183,51 @@ def validate_names( # Transfer # ---------------------------------------------------------------------------------- - def transfer_entire_project(self, upload: bool) -> InterfaceOutput: """ Transfer the entire project (all canonical top-level folders). Parameters ---------- - upload : bool Upload from local to central if `True`, otherwise download from central to remote. """ try: - if upload: - transfer_func = self.project.upload_entire_project - else: - transfer_func = self.project.download_entire_project + connection_method = self.project.cfg["connection_method"] - transfer_func( - overwrite_existing_files=self.tui_settings[ - "overwrite_existing_files" - ], - dry_run=self.tui_settings["dry_run"], - ) + if connection_method in ["ssh", "local_filesystem"]: + transfer_func = ( + self.project.upload_entire_project + if upload + else self.project.download_entire_project + ) + transfer_func( + overwrite_existing_files=self.tui_settings[ + "overwrite_existing_files" + ], + dry_run=self.tui_settings["dry_run"], + ) + + elif connection_method == "aws": + remote_path = self.project.cfg["aws_bucket_name"] + success, message = self.project.transfer_aws_files( + upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message + + elif connection_method == "gdrive": + remote_path = self.project.cfg["google_drive_folder_id"] + success, message = self.project.transfer_gdrive_files( + upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message return True, None @@ -271,7 +292,6 @@ def transfer_custom_selection( Parameters ---------- - selected_top_level_folder : str The top level folder selected in the TUI for this transfer window. @@ -289,21 +309,44 @@ def transfer_custom_selection( from central to remote. """ try: - if upload: - transfer_func = self.project.upload_custom - else: - transfer_func = self.project.download_custom + connection_method = self.project.cfg["connection_method"] - transfer_func( - selected_top_level_folder, - sub_names=sub_names, - ses_names=ses_names, - datatype=datatype, - overwrite_existing_files=self.tui_settings[ - "overwrite_existing_files" - ], - dry_run=self.tui_settings["dry_run"], - ) + if connection_method in ["ssh", "local_filesystem"]: + transfer_func = ( + self.project.upload_custom + if upload + else self.project.download_custom + ) + transfer_func( + selected_top_level_folder, + sub_names=sub_names, + ses_names=ses_names, + datatype=datatype, + overwrite_existing_files=self.tui_settings[ + "overwrite_existing_files" + ], + dry_run=self.tui_settings["dry_run"], + ) + + elif connection_method == "aws": + remote_path = f"{self.project.cfg['aws_bucket_name']}/{selected_top_level_folder}" + success, message = self.project.transfer_aws_files( + upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message + + elif connection_method == "grive": + remote_path = f"{self.project.cfg['google_drive_folder_id']}/{selected_top_level_folder}" + success, message = self.project.transfer_gdrive_files( + upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message return True, None @@ -455,3 +498,100 @@ def setup_key_pair_and_rclone_config( except BaseException as e: return False, str(e) + + # Setup AWS S3 + # ---------------------------------------------------------------------------------- + + def get_aws_bucket_name(self) -> str: + """ + Retrieve the AWS bucket name from project configs. + """ + return self.project.cfg["aws_bucket_name"] + + def get_aws_hostkey(self) -> InterfaceOutput: + """ + Retrieve AWS host credentials for verification. + """ + try: + key = aws.get_remote_aws_key(self.project.cfg["aws_bucket_name"]) + return True, key + except BaseException as e: + return False, str(e) + + def save_aws_key_locally(self, key: str) -> InterfaceOutput: + """ + Save the AWS access key locally for secure storage. + """ + try: + aws.save_aws_key_locally( + key, + self.project.cfg["aws_bucket_name"], + self.project.cfg.aws_key_path, + ) + return True, None + + except BaseException as e: + return False, str(e) + + def setup_aws_bucket_and_rclone_config( + self, bucket_name: str, region: str + ) -> InterfaceOutput: + try: + self.project.cfg["aws_bucket_name"] = bucket_name + self.project.cfg["aws_region"] = region + + self.project._setup_rclone_central_aws_config(log=False) + + return True, None + + except BaseException as e: + return False, str(e) + + # Setup Google Drive + # ---------------------------------------------------------------------------------- + + def get_gdrive_folder_id(self) -> str: + """ + Retrieve the Google Drive folder ID from project configs. + """ + return self.project.cfg["gdrive_folder_id"] + + def get_gdrive_hostkey(self) -> InterfaceOutput: + """ + Retrieve Google Drive credentials for verification. + """ + try: + key = gdrive.get_remote_gdrive_key( + self.project.cfg["gdrive_folder_id"] + ) + return True, key + except BaseException as e: + return False, str(e) + + def save_gdrive_key_locally(self, key: str) -> InterfaceOutput: + """ + Save the Google Drive credentials locally for secure storage. + """ + try: + gdrive.save_gdrive_key_locally( + key, + self.project.cfg["gdrive_folder_id"], + self.project.cfg.gdrive_key_path, + ) + return True, None + + except BaseException as e: + return False, str(e) + + def setup_gdrive_folder_and_rclone_config( + self, folder_id: str + ) -> InterfaceOutput: + try: + self.project.cfg["gdrive_folder_id"] = folder_id + + self.project._setup_rclone_central_gdrive_config(log=False) + + return True, None + + except BaseException as e: + return False, str(e) diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py new file mode 100644 index 000000000..3b65568d7 --- /dev/null +++ b/datashuttle/tui/screens/setup_aws.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import ComposeResult + + from datashuttle.tui.interface import Interface + +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import ( + Button, + Input, + Static, +) + + +class SetupAwsScreen(ModalScreen): + """ + This dialog window handles the TUI equivalent of API's + setup_aws_connection(). This asks to confirm AWS credentials + and takes an access key for setup. + + the TUI cannot simply wrap the API because + the logic flow requires user input (AWS credentials and key). + """ + + def __init__(self, interface: Interface) -> None: + super(SetupAwsScreen, self).__init__() + + self.interface = interface + self.stage = 0 + self.bucket_name: str = "" + self.aws_region: str = "" + self.failed_attempts = 1 + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup AWS S3. Enter bucket name and region, then press OK.", + id="messagebox_message_label", + ), + id="messagebox_message_container", + ), + Input(placeholder="AWS Bucket Name", id="setup_aws_bucket_input"), + Input(placeholder="AWS Region", id="setup_aws_region_input"), + Horizontal( + Button("OK", id="setup_aws_ok_button"), + Button("Cancel", id="setup_aws_cancel_button"), + id="messagebox_buttons_horizontal", + ), + id="setup_aws_screen_container", + ) + + def on_mount(self) -> None: + """Hide region input until the bucket name is verified.""" + self.query_one("#setup_aws_bucket_input").visible = False + self.query_one("#setup_aws_region_input").visible = False + + def on_button_pressed(self, event: Button.pressed) -> None: + """ + Handle button presses for each stage: + 1. Confirm AWS credentials. + 2. Save credentials and prompt for key. + 3. Use the key to finalize AWS setup. + """ + if event.button.id == "setup_aws_cancel_button": + self.dismiss() + + if event.button.id == "setup_aws_ok_button": + if self.stage == 0: + self.ask_user_to_accept_aws_bucket() + + elif self.stage == 1: + self.save_aws_bucket_and_prompt_region_input() + + elif self.stage == 2: + self.use_aws_bucket_and_region_to_setup_aws_connection() + + elif self.stage == 3: + self.dismiss() + + def ask_user_to_accept_aws_bucket(self) -> None: + """ + Verify that the AWS S3 bucket is accessible. + Ask the user to confirm trusting this bucket for future use. + """ + success, output = self.interface.get_aws_hostkey() + + if success: + self.bucket_name = output + + message = ( + f"The AWS bucket '{self.bucket_name}' is accessible.\n\n" + "If you trust this bucket and want to use it for transfers, press OK." + ) + else: + message = ( + "Could not verify the AWS bucket.\nCheck the connection and the bucket name.\n\n" + f"Traceback: {output}" + ) + self.query_one("#setup_aws_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + def save_aws_bucket_and_prompt_region_input(self) -> None: + """ + Once the AWS bucket is accepted, prompt the user for the region. + When 'OK' is pressed, we go straight to 'use_region_to_setup_aws_connection'. + """ + success, output = self.interface.save_aws_key_locally(self.bucket_name) + + if success: + message = ( + "AWS bucket verified.\n\nNext, enter your AWS region below to complete setup. " + "Press OK to confirm." + ) + self.query_one("#setup_aws_region_input").visible = True + else: + message = ( + f"Could not store AWS bucket name. Check permissions " + f"for: \n\n {self.interface.get_configs().aws_key_path}.\n\n Traceback: {output}" + ) + self.query_one("#setup_aws_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + def use_aws_bucket_and_region_to_setup_aws_connection(self) -> None: + """ + Use the AWS bucket name and region to complete the setup. + If successful, the OK button changes to 'Finish'. + Otherwise, prompt for another attempt. + """ + bucket_name = self.query_one("#setup_aws_bucket_name_input").value + region = self.query_one("#setup_aws_region_input").value + + success, output = self.interface.setup_aws_bucket_and_rclone_config( + bucket_name, region + ) + + if success: + message = ( + f"AWS setup successful! Config saved to " + f"{self.interface.get_configs().aws_credentials_path}" + ) + self.query_one("#setup_aws_ok_button").label = "Finish" + self.query_one("#setup_aws_cancel_button").disabled = True + self.stage += 1 + + else: + message = ( + "AWS setup failed. Check that your bucket name and region are correct and try again." + f"\n\n{self.failed_attempts} failed attempts." + f"\n\nTraceback: {output}" + ) + self.failed_attempts += 1 + + self.query_one("#messagebox_message_label").update(message) diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py new file mode 100644 index 000000000..32c5c9692 --- /dev/null +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import ComposeResult + + from datashuttle.tui.interface import Interface + +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import ( + Button, + Input, + Static, +) + + +class SetupGdriveScreen(ModalScreen): + """ + This dialog window handles the TUI equivalent of API's + setup_gdrive_connection(). This asks to confirm Google Drive + credentials and takes an OAuth folder_id for setup. + + the TUI cannot simply wrap the API because + the logic flow requires user input (OAuth folder_id and confirmation). + """ + + def __init__(self, interface: Interface) -> None: + super(SetupGdriveScreen, self).__init__() + + self.interface = interface + self.stage = 0 + self.folder_id: str = "" + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup Google Drive. Enter folder ID, then press OK.", + id="messagebox_message_label", + ), + id="messagebox_message_container", + ), + Input( + placeholder="Google Drive Folder ID", + id="setup_gdrive_folder_input", + ), + Horizontal( + Button("OK", id="setup_gdrive_ok_button"), + Button("Cancel", id="setup_gdrive_cancel_button"), + id="messagebox_buttons_horizontal", + ), + id="setup_gdrive_screen_container", + ) + + def on_mount(self) -> None: + """Ensure UI is clean on start.""" + self.query_one("#setup_gdrive_folder_input").visible = False + + def on_button_pressed(self, event: Button.pressed) -> None: + """ + Handle button presses for each stage: + 1. Confirm Google Drive credentials. + 2. Save credentials and prompt for OAuth folder_id. + 3. Use the folder_id to finalize Google Drive setup. + """ + if event.button.id == "setup_gdrive_cancel_button": + self.dismiss() + + if event.button.id == "setup_gdrive_ok_button": + if self.stage == 0: + self.ask_user_to_accept_gdrive_folder() + + elif self.stage == 1: + self.save_gdrive_folder_and_prompt_setup() + + elif self.stage == 2: + self.use_folder_id_to_setup_gdrive_connection() + + elif self.stage == 3: + self.dismiss() + + def ask_user_to_accept_gdrive_folder(self) -> None: + """ + Verify that the Google Drive folder ID is accessible. + Ask the user to confirm trusting this folder for future use. + """ + success, output = self.interface.get_gdrive_hostkey() + + if success: + self.folder_id = output + + message = ( + f"The Google Drive folder ID '{self.folder_id}' is accessible.\n\n" + "If you trust this folder and want to use it for transfers, press OK." + ) + else: + message = ( + "Could not verify the Google Drive folder.\nCheck the connection and the folder ID.\n\n" + f"Traceback: {output}" + ) + self.query_one("#setup_gdrive_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + def save_gdrive_folder_and_prompt_setup(self) -> None: + """ + Once the Google Drive folder ID is accepted, confirm setup. + No additional credentials are needed after this step. + """ + success, output = self.interface.save_gdrive_key_locally( + self.folder_id + ) + + if success: + message = ( + "Google Drive folder ID verified.\n\nSetup is now complete. " + "You can proceed to transfer files." + ) + else: + message = ( + f"Could not store Google Drive folder ID. Check permissions " + f"for: \n\n {self.interface.get_configs().gdrive_key_path}.\n\n Traceback: {output}" + ) + self.query_one("#setup_gdrive_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + def use_folder_id_to_setup_gdrive_connection(self) -> None: + """ + Use the OAuth folder_id to complete the Google Drive setup. + If successful, the OK button changes to 'Finish'. + Otherwise, prompt for another attempt. + """ + folder_id = self.query_one("#setup_gdrive_folder_id_input").value + + success, output = self.interface.setup_gdrive_folder_and_rclone_config( + folder_id + ) + + if success: + message = ( + f"Google Drive setup successful! Credentials saved to " + f"{self.interface.get_configs().gdrive_credentials_path}" + ) + self.query_one("#setup_gdrive_ok_button").label = "Finish" + self.query_one("#setup_gdrive_cancel_button").disabled = True + self.stage += 1 + + else: + message = ( + "Google Drive setup failed. Check that your OAuth folder_id is correct and try again." + f"\n\n{self.failed_attempts} failed attempts." + f"\n\n Traceback: {output}" + ) + self.failed_attempts += 1 + + self.query_one("#messagebox_message_label").update(message) diff --git a/datashuttle/tui/tooltips.py b/datashuttle/tui/tooltips.py index 61cb9f70d..9583eb89e 100644 --- a/datashuttle/tui/tooltips.py +++ b/datashuttle/tui/tooltips.py @@ -33,10 +33,40 @@ def get_tooltip(id: str) -> str: elif id == "#configs_ssh_radiobutton": tooltip = "Use SSH when planning to connect with the central data storage via SSH protocol." + # AWS S3 radiobutton + elif id == "#configs_aws_radiobutton": + tooltip = ( + "Use AWS S3 when planning to connect with your AWS storage bucket." + ) + + # Google Drive radiobutton + elif id == "#configs_gdrive_radiobutton": + tooltip = "Use Google Drive when planning to connect with your Google Drive folder." + # No connection (local only) radiobutton elif id == "#configs_local_only_radiobutton": tooltip = "No connection to a central project is made.\nTransfer functionality will not be available." + # AWS S3 Inputs + elif id == "#configs_aws_bucket_name_input": + tooltip = ( + "Name of your AWS S3 bucket where project data will be stored.\n\n" + "Ensure the bucket already exists before attempting to connect." + ) + + elif id == "#configs_aws_region_input": + tooltip = ( + "AWS region where your S3 bucket is located.\n\n" + "Example regions: us-east-1, eu-west-1, etc." + ) + + # Google Drive Inputs + elif id == "#configs_gdrive_folder_id_input": + tooltip = ( + "The folder ID of your Google Drive folder for project data storage.\n\n" + "Find the folder ID in your Google Drive folder URL." + ) + # central host input elif id == "#configs_central_host_id_input": tooltip = "The hostname or IP address of the server." @@ -51,6 +81,20 @@ def get_tooltip(id: str) -> str: "The path to the project folder on the central machine (or it's parent folder).\n\n" "With 'SSH', this path is relative to the server e.g. /nhome/users/myusername" ) + # AWS S3 Tooltip + elif id == "config_central_path_input_mode-aws": + tooltip = ( + "The path to the project folder on AWS S3.\n\n" + "Provide the full bucket path. For example:\n" + "s3://my-bucket-name/project_folder" + ) + + # Google Drive Tooltip + elif id == "config_central_path_input_mode-gdrive": + tooltip = ( + "The path to the project folder on Google Drive.\n\n" + "Provide the folder ID associated with your Google Drive project folder." + ) elif id == "config_central_path_input_mode-local_filesystem": tooltip = ( diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py new file mode 100644 index 000000000..477212f9d --- /dev/null +++ b/datashuttle/utils/aws.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import subprocess +from pathlib import Path +from typing import TYPE_CHECKING, Any, List, Tuple + +if TYPE_CHECKING: + from datashuttle.configs.config_class import Configs + + +import fnmatch + +from datashuttle.utils import utils + + +def get_remote_aws_key(bucket_name: str) -> Tuple[bool, str]: + """ + Attempt to list contents of the AWS S3 bucket to check access. + """ + remote_path = f"aws_remote:{bucket_name}" + try: + subprocess.run( + ["rclone", "lsf", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return True, "" + except subprocess.CalledProcessError as e: + return False, e.stderr.decode() + + +def save_aws_key_locally( + bucket_name: str, aws_region: str, central_path: Path +) -> None: + """ + Save the AWS bucket name and region in the central path, following SSH-style storage. + """ + central_path.parent.mkdir(parents=True, exist_ok=True) + + with open(central_path, "w") as file: + file.write(f"Bucket: {bucket_name}\nRegion: {aws_region}") + + +def connect_aws_with_logging( + cfg: Configs, + message_on_sucessful_connection: bool = True, +) -> None: + """ + Connect to AWS S3 using rclone by testing access to the remote. + This assumes rclone has already been configured properly. + """ + remote = cfg.get_rclone_config_name("AWS S3") + + try: + subprocess.run( + ["rclone", "lsf", f"{remote}:"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + if message_on_sucessful_connection: + utils.print_message_to_user( + f"Connection to AWS S3 remote '{remote}' made successfully." + ) + + except Exception as e: + utils.log_and_raise_error( + f"Could not connect to AWS S3. Ensure that:\n" + f"1) You have run setup_aws_connection()\n" + f"2) Your rclone remote '{remote}' is correctly configured\n" + f"3) The bucket exists and credentials are valid.\n\n" + f"Error:\n{e}", + ConnectionError, + ) + + +def search_aws_remote_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Search for the search prefix in the search path over AWS S3. + Returns the list of matching folders, files are filtered out. + + Parameters + ----------- + + search_path : path to search for folders in + + search_prefix : search prefix for folder names e.g. "sub-*" + + cfg : project config object (provides bucket and credentials) + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. + """ + all_folder_names, all_filenames = get_list_of_folder_names_over_aws( + cfg, + search_path, + search_prefix, + verbose, + return_full_path, + ) + + return all_folder_names, all_filenames + + +def get_list_of_folder_names_over_aws( + cfg: Configs, + search_path: Path, + search_prefix: str, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Use rclone to search a path over AWS S3 for folders. + Return the folder names. + + Parameters + ---------- + + cfg : datashuttle project config object + + search_path : path to search for folders in (inside the bucket) + + search_prefix : prefix (can include wildcards) to search folder names + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. + + return_full_path : If `True`, return full rclone remote path, + else return only folder name + """ + remote_path = ( + f"{cfg.get_rclone_config_name('AWS S3')}:{search_path.as_posix()}" + ) + + all_folder_names = [] + all_filenames = [] + + try: + result = subprocess.run( + ["rclone", "lsjson", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + entries = json.loads(result.stdout) + + for entry in entries: + name = entry["Name"] + is_dir = entry.get("IsDir", False) + + if fnmatch.fnmatch(name, search_prefix): + to_append = ( + f"{remote_path}/{name}" if return_full_path else name + ) + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except subprocess.CalledProcessError as e: + if verbose: + utils.log_and_message( + f"No file found at {remote_path}\n{e.stderr}" + ) + + return all_folder_names, all_filenames + + +def verify_aws_remote( + bucket_name: str, aws_key_path: Path, log: bool = True +) -> bool: + """ + Prompt user to trust and save an AWS S3 bucket for future use. + """ + success, _ = get_remote_aws_key(bucket_name) + if not success: + utils.print_message_to_user( + "Unable to access the AWS S3 bucket. Make sure it exists and is accessible." + ) + return False + + message = ( + f"You're about to trust this AWS S3 bucket: {bucket_name}\n" + "If you trust it, type 'y' to save and proceed: " + ) + input_ = utils.get_user_input(message) + + if input_ == "y": + save_aws_key_locally(bucket_name, aws_key_path) + if log: + utils.log( + f"AWS S3 bucket {bucket_name} trusted and saved at {aws_key_path}" + ) + utils.print_message_to_user("AWS bucket accepted.") + return True + else: + utils.print_message_to_user("Bucket not accepted. No connection made.") + return False diff --git a/datashuttle/utils/decorators.py b/datashuttle/utils/decorators.py index cacf54914..b257e319d 100644 --- a/datashuttle/utils/decorators.py +++ b/datashuttle/utils/decorators.py @@ -28,6 +28,46 @@ def wrapper(*args, **kwargs): return wrapper +def requires_aws_configs(func): + """ + Decorator to check AWS configs are loaded before running the function. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if not args[0].cfg["aws_bucket_name"] or not args[0].cfg["aws_region"]: + log_and_raise_error( + "Cannot setup AWS connection, 'aws_bucket_name', " + "'aws_access_key', or 'aws_secret_key' is not set in " + "the configuration file.", + ConfigError, + ) + else: + return func(*args, **kwargs) + + return wrapper + + +def requires_gdrive_configs(func): + """ + Decorator to check Google Drive configs are loaded before running the function. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if not args[0].cfg["gdrive_folder_id"]: + log_and_raise_error( + "Cannot setup Google Drive connection, 'gdrive_folder_id', " + "'gdrive_client_id', or 'gdrive_client_secret' is not set in " + "the configuration file.", + ConfigError, + ) + else: + return func(*args, **kwargs) + + return wrapper + + def check_configs_set(func): """ Check that configs have been loaded (i.e. diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index df6cdb061..ddc1d2649 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -20,7 +20,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import ssh, utils, validation +from datashuttle.utils import aws, gdrive, ssh, utils, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -514,14 +514,52 @@ def search_for_folders( verbose : If `True`, when a search folder cannot be found, a message will be printed with the missing path. """ - if local_or_central == "central" and cfg["connection_method"] == "ssh": - all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, - ) + if local_or_central == "central": + if cfg["connection_method"] == "ssh": + all_folder_names, all_filenames = ( + ssh.search_ssh_central_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + ) + elif cfg["connection_method"] == "AWS S3": + all_folder_names, all_filenames = ( + aws.search_aws_remote_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + ) + elif cfg["connection_method"] == "Google Drive": + all_folder_names, all_filenames = ( + gdrive.search_gdrive_remote_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + ) + else: + # Default to filesystem search if no valid method found + if not search_path.exists(): + if verbose: + utils.log_and_message( + f"No file found at {search_path.as_posix()}" + ) + return [], [] + + all_folder_names, all_filenames = ( + search_filesystem_path_for_folders( + search_path / search_prefix, return_full_path + ) + ) + else: if not search_path.exists(): if verbose: @@ -533,6 +571,7 @@ def search_for_folders( all_folder_names, all_filenames = search_filesystem_path_for_folders( search_path / search_prefix, return_full_path ) + return all_folder_names, all_filenames diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py new file mode 100644 index 000000000..ca3beb3c8 --- /dev/null +++ b/datashuttle/utils/gdrive.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import subprocess +from pathlib import Path +from typing import TYPE_CHECKING, Any, List, Tuple + +if TYPE_CHECKING: + from datashuttle.configs.config_class import Configs +import fnmatch + +from datashuttle.utils import utils + + +def get_remote_gdrive_key(folder_id: str) -> Tuple[bool, str]: + """ + Attempt to list contents of the Google Drive folder to check access. + """ + remote_path = f"gdrive_remote:{folder_id}" + try: + subprocess.run( + ["rclone", "lsf", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return True, "" + except subprocess.CalledProcessError as e: + return False, e.stderr.decode() + + +def save_gdrive_key_locally( + folder_id: str, remote_name: str, central_path: Path +) -> None: + """ + Save the trusted Google Drive folder ID and remote name in the central path. + """ + central_path.parent.mkdir(parents=True, exist_ok=True) + + with open(central_path, "w") as file: + file.write(f"Folder ID: {folder_id}\nRemote: {remote_name}") + + +def connect_gdrive_with_logging( + cfg: Configs, + message_on_sucessful_connection: bool = True, +) -> None: + """ + Connect to Google Drive using rclone by testing access to the remote. + This assumes rclone has already been configured properly. + """ + remote = cfg.get_rclone_config_name("Google Drive") + + try: + subprocess.run( + ["rclone", "lsf", f"{remote}:"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + if message_on_sucessful_connection: + utils.print_message_to_user( + f"Connection to Google Drive remote '{remote}' made successfully." + ) + + except Exception as e: + utils.log_and_raise_error( + f"Could not connect to Google Drive. Ensure that:\n" + f"1) You have run setup_gdrive_connection()\n" + f"2) Your rclone remote '{remote}' is correctly configured\n" + f"3) The folder ID exists and access is authorized.\n\n" + f"Error:\n{e}", + ConnectionError, + ) + + +def search_gdrive_remote_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Search for the search prefix in the search path over Google Drive. + Returns the list of matching folders, files are filtered out. + + Parameters + ----------- + + search_path : path to search for folders in + + search_prefix : search prefix for folder names e.g. "sub-*" + + cfg : project config object (provides remote and credentials) + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. + """ + all_folder_names, all_filenames = get_list_of_folder_names_over_gdrive( + cfg, + search_path, + search_prefix, + verbose, + return_full_path, + ) + + return all_folder_names, all_filenames + + +def get_list_of_folder_names_over_gdrive( + cfg: Configs, + search_path: Path, + search_prefix: str, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Use rclone to search a path over Google Drive for folders. + Return the folder names. + + Parameters + ---------- + + cfg : datashuttle project config object + + search_path : path to search for folders in (inside the GDrive folder) + + search_prefix : prefix (can include wildcards) to search folder names + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. + + return_full_path : If `True`, return full rclone remote path, + else return only folder name + """ + remote_path = f"{cfg.get_rclone_config_name('Google Drive')}:{search_path.as_posix()}" + + all_folder_names = [] + all_filenames = [] + + try: + result = subprocess.run( + ["rclone", "lsjson", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + entries = json.loads(result.stdout) + + for entry in entries: + name = entry["Name"] + is_dir = entry.get("IsDir", False) + + if fnmatch.fnmatch(name, search_prefix): + to_append = ( + f"{remote_path}/{name}" if return_full_path else name + ) + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except subprocess.CalledProcessError as e: + if verbose: + utils.log_and_message( + f"No file found at {remote_path}\n{e.stderr}" + ) + + return all_folder_names, all_filenames + + +def verify_gdrive_remote( + folder_id: str, gdrive_key_path: Path, log: bool = True +) -> bool: + """ + Prompt user to trust and save a GDrive folder ID for future use. + """ + success, _ = get_remote_gdrive_key(folder_id) + if not success: + utils.print_message_to_user( + "Unable to access the Google Drive folder. Make sure it's shared and reachable." + ) + return False + + message = ( + f"You're about to trust this Google Drive folder ID: {folder_id}\n" + "If you trust it, type 'y' to save and proceed: " + ) + input_ = utils.get_user_input(message) + + if input_ == "y": + save_gdrive_key_locally(folder_id, gdrive_key_path) + if log: + utils.log( + f"Google Drive folder ID {folder_id} trusted and saved at {gdrive_key_path}" + ) + utils.print_message_to_user("Google Drive folder accepted.") + return True + else: + utils.print_message_to_user("Folder not accepted. No connection made.") + return False diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 9644c103c..5e400093a 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -108,6 +108,83 @@ def setup_rclone_config_for_ssh( log_rclone_config_output() +def setup_rclone_config_for_gdrive( + cfg: Configs, + rclone_config_name: str, + gdrive_folder_id: str, + log: bool = True, +): + """ + RClone sets remote targets in a config file that are + used at transfer. For Google Drive, this sets the root folder. + The relative path is supplied later during transfer. + + Parameters + ---------- + cfg : Configs + datashuttle configs UserDict. + + rclone_config_name : str + Canonical config name, generated by + datashuttle.cfg.get_rclone_config_name() + + gdrive_folder_id : str + The Google Drive Folder ID to use as the root directory. + + log : bool + Whether to log, if True logger must already be initialised. + """ + call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"scope drive " + f"root_folder_id {gdrive_folder_id}", + pipe_std=True, + ) + + if log: + log_rclone_config_output() + + +def setup_rclone_config_for_gdrive( + cfg: Configs, + rclone_config_name: str, + gdrive_folder_id: str, + log: bool = True, +): + """ + RClone sets remote targets in a config file that are + used for Google Drive. The relative path is supplied at transfer time. + + Parameters + ---------- + cfg : Configs + datashuttle configs UserDict. + + rclone_config_name : str + Canonical config name, generated by + datashuttle.cfg.get_rclone_config_name(). + + gdrive_folder_id : str + The Google Drive Folder ID where files will be stored. + + log : bool + Whether to log, if True logger must already be initialized. + """ + call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"scope drive " + f"root_folder_id {gdrive_folder_id}", + pipe_std=True, + ) + + if log: + log_rclone_config_output() + + def log_rclone_config_output(): output = call_rclone("config file", pipe_std=True) utils.log( @@ -180,27 +257,44 @@ def transfer_data( "download", ], "must be 'upload' or 'download'" - local_filepath = cfg.get_base_folder("local", top_level_folder).as_posix() + connection_method = cfg["connection_method"] + local_filepath = cfg.get_base_folder("local", top_level_folder).as_posix() central_filepath = cfg.get_base_folder( "central", top_level_folder ).as_posix() - extra_arguments = handle_rclone_arguments(rclone_options, include_list) + # AWS S3 Path Formatting + if connection_method == "AWS S3": + central_filepath = ( + f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + ) + + # Google Drive Path Formatting + elif connection_method == "Google Drive": + central_filepath = ( + f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + ) + + # Default (SSH or Local Filesystem) + else: + central_filepath = f"{cfg.get_rclone_config_name()}:{central_filepath}" + + extra_arguments = handle_rclone_arguments( + rclone_options, include_list, connection_method + ) if upload_or_download == "upload": output = call_rclone( f"{rclone_args('copy')} " - f'"{local_filepath}" "{cfg.get_rclone_config_name()}:' - f'{central_filepath}" {extra_arguments}', + f'"{local_filepath}" "{central_filepath}" {extra_arguments}', pipe_std=True, ) elif upload_or_download == "download": output = call_rclone( f"{rclone_args('copy')} " - f'"{cfg.get_rclone_config_name()}:' - f'{central_filepath}" "{local_filepath}" {extra_arguments}', + f'"{central_filepath}" "{local_filepath}" {extra_arguments}', pipe_std=True, ) @@ -288,10 +382,10 @@ def perform_rclone_check( cfg: Configs, top_level_folder: TopLevelFolder ) -> str: """ - Use Rclone's `check` command to build a list of files that - are the same ("="), different ("*"), found in local only ("+") - or central only ("-"). The output is formatted as " \n". + Perform RClone check with AWS S3 and Google Drive support. """ + connection_method = cfg["connection_method"] + local_filepath = cfg.get_base_folder( "local", top_level_folder ).parent.as_posix() @@ -299,10 +393,21 @@ def perform_rclone_check( "central", top_level_folder ).parent.as_posix() + if connection_method == "AWS S3": + central_filepath = ( + f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + ) + elif connection_method == "Google Drive": + central_filepath = ( + f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + ) + else: + central_filepath = f"{cfg.get_rclone_config_name()}:{central_filepath}" + output = call_rclone( f'{rclone_args("check")} ' f'"{local_filepath}" ' - f'"{cfg.get_rclone_config_name()}:{central_filepath}"' + f'"{central_filepath}"' f" --combined -", pipe_std=True, )