From 37d21ffff9ebe0352d51e35b965afb94cbdae9c9 Mon Sep 17 00:00:00 2001 From: Ijv3-0 Date: Fri, 21 Mar 2025 22:38:05 +0530 Subject: [PATCH 1/2] feat: Add support for cloud --- datashuttle/configs/canonical_configs.py | 21 +- datashuttle/configs/config_class.py | 6 + datashuttle/datashuttle_class.py | 122 +++++-- datashuttle/tui/configs.py | 444 +++++++++++++++++------ datashuttle/tui/css/tui_menu.tcss | 39 ++ datashuttle/tui/interface.py | 196 ++++++++-- datashuttle/tui/screens/setup_aws.py | 161 ++++++++ datashuttle/tui/screens/setup_gdrive.py | 155 ++++++++ datashuttle/tui/tooltips.py | 42 +++ datashuttle/utils/aws.py | 192 ++++++++++ datashuttle/utils/decorators.py | 39 ++ datashuttle/utils/folders.py | 49 ++- datashuttle/utils/gdrive.py | 190 ++++++++++ datashuttle/utils/rclone.py | 129 ++++++- 14 files changed, 1601 insertions(+), 184 deletions(-) create mode 100644 datashuttle/tui/screens/setup_aws.py create mode 100644 datashuttle/tui/screens/setup_gdrive.py create mode 100644 datashuttle/utils/aws.py create mode 100644 datashuttle/utils/gdrive.py diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index b65ad6c66..28af03bb2 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -39,9 +39,12 @@ def get_canonical_configs() -> dict: canonical_configs = { "local_path": Union[str, Path], "central_path": Optional[Union[str, Path]], - "connection_method": Optional[Literal["ssh", "local_filesystem"]], + "connection_method": Optional[Literal["ssh", "local_filesystem", "aws", "gdrive"]], "central_host_id": Optional[str], "central_host_username": Optional[str], + "aws_bucket_name": Optional[str], + "aws_region": Optional[str], + "gdrive_folder_id": Optional[str], } return canonical_configs @@ -128,6 +131,22 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ConfigError, ) + # Check AWS S3 settings + if config_dict["connection_method"] == "aws" and ( + not 
config_dict["aws_bucket_name"] or not config_dict["aws_region"] + ): + utils.log_and_raise_error( + "'aws_bucket_name' and 'aws_region' are required if 'connection_method' is 'aws'.", + ConfigError, + ) + + # Check Google Drive settings + if config_dict["connection_method"] == "gdrive" and not config_dict["gdrive_folder_id"]: + utils.log_and_raise_error( + "'gdrive_folder_id' is required if 'connection_method' is 'gdrive'.", + ConfigError, + ) + # Initialise the local project folder utils.print_message_to_user( f"Making project folder at: {config_dict['local_path']}" diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 5fd70fcad..91a1d094e 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -59,6 +59,8 @@ def __init__( self.logging_path: Path self.hostkeys_path: Path self.ssh_key_path: Path + self.aws_key_path: Path + self.gdrive_fo_path: Path self.project_metadata_path: Path def setup_after_load(self) -> None: @@ -236,6 +238,10 @@ def init_paths(self) -> None: self.ssh_key_path = datashuttle_path / f"{self.project_name}_ssh_key" + self.aws_key_path = datashuttle_path / f"{self.project_name}_aws_key" + + self.gdrive_key_path = datashuttle_path / f"{self.project_name}_gdrive_key" + self.hostkeys_path = datashuttle_path / "hostkeys" self.logging_path = self.make_and_get_logging_path() diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 7535e2664..35e0c853f 100644 --- a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -42,6 +42,8 @@ getters, rclone, ssh, + gdrive, + aws, utils, validation, ) @@ -54,6 +56,8 @@ check_configs_set, check_is_not_local_project, requires_ssh_configs, + requires_aws_configs, + requires_gdrive_configs, ) # ----------------------------------------------------------------------------- @@ -903,47 +907,44 @@ def make_config_file( connection_method: str | None = None, central_host_id: Optional[str] = None, 
central_host_username: Optional[str] = None, + aws_bucket_name: Optional[str] = None, + aws_region: Optional[str] = "us-east-1", + gdrive_folder_id: Optional[str] = None, ) -> None: """ Initialise the configurations for datashuttle to use on the local machine. Once initialised, these settings will be - used each time the datashuttle is opened. This method - can also be used to completely overwrite existing configs. - - These settings are stored in a config file on the - datashuttle path (not in the project folder) - on the local machine. Use get_config_path() to - get the full path to the saved config file. - - Use update_config_file() to selectively update settings. + used each time datashuttle is opened. Parameters ---------- - local_path : - path to project folder on local machine + Path to project folder on local machine. central_path : - Filepath to central project. - If this is local (i.e. connection_method = "local_filesystem"), - this is the full path on the local filesystem - Otherwise, if this is via ssh (i.e. connection method = "ssh"), - this is the path to the project folder on central machine. - This should be a full path to central folder i.e. this cannot - include ~ home folder syntax, must contain the full path - (e.g. /nfs/nhome/live/jziminski) + Filepath to central project (local or SSH). connection_method : - The method used to connect to the central project filesystem, - e.g. "local_filesystem" (e.g. mounted drive) or "ssh" + The method used to connect to the central project filesystem: + - "local_filesystem" (mounted drive) + - "ssh" (remote connection) + - "aws" (Amazon S3 cloud storage) + - "gdrive" (Google Drive cloud storage) central_host_id : - server address for central host for ssh connection - e.g. "ssh.swc.ucl.ac.uk" + Server address for SSH connection. central_host_username : - username for which to log in to central host. - e.g. "jziminski" + Username for SSH login. 
+ + aws_bucket_name : + Name of the AWS S3 bucket (required for AWS). + + aws_region : + AWS region (default: "us-east-1"). + + gdrive_folder_id : + Folder ID for Google Drive (required for Google Drive). """ self._start_log( "make-config-file", @@ -967,6 +968,9 @@ def make_config_file( "connection_method": connection_method, "central_host_id": central_host_id, "central_host_username": central_host_username, + "aws_bucket_name": aws_bucket_name, + "aws_region": aws_region, + "gdrive_folder_id": gdrive_folder_id, }, ) @@ -1463,6 +1467,22 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: rclone.setup_rclone_config_for_local_filesystem( self.cfg.get_rclone_config_name("local_filesystem"), ) + + def _setup_rclone_central_aws_config(self, log: bool) -> None: + rclone.setup_rclone_config_for_aws( + self.cfg, + self.cfg.get_rclone_config_name("aws"), + self.cfg["aws_region"], + log=log, + ) + + def _setup_rclone_central_gdrive_config(self, log: bool) -> None: + rclone.setup_rclone_config_for_gdrive( + self.cfg, + self.cfg.get_rclone_config_name("gdrive"), + self.cfg["gdrive_folder_id"], + log=log, + ) # Persistent settings # ------------------------------------------------------------------------- @@ -1565,3 +1585,55 @@ def _check_top_level_folder(self, top_level_folder): f"{canonical_top_level_folders}", ValueError, ) + + # ------------------------------------------------------------------------- + # AWS S3 and Google Drive + # ------------------------------------------------------------------------- + + @requires_aws_configs + @check_is_not_local_project + def setup_aws_connection(self) -> None: + """ + Setup a connection to AWS S3 using RClone. + + Assumes the aws_bucket_name and aws_region are set in configs. + First, prompt the user to verify the AWS bucket as trusted, + then create the RClone config for AWS. 
+ """ + self._start_log("setup-aws-connection", local_vars=locals()) + + verified = aws.verify_aws_remote( + self.cfg["aws_bucket_name"], + self.cfg["aws_region"], + self.cfg.aws_key_path, + log=True, + ) + + if verified: + self._setup_rclone_central_aws_config(log=True) + + ds_logger.close_log_filehandler() + + + @requires_gdrive_configs + @check_is_not_local_project + def setup_gdrive_connection(self) -> None: + """ + Setup a connection to Google Drive using RClone. + + Assumes the gdrive_folder_id is set in configs. + First, prompt the user to verify and trust the folder ID, + then create the RClone config for Google Drive. + """ + self._start_log("setup-gdrive-connection", local_vars=locals()) + + verified = gdrive.verify_gdrive_remote( + self.cfg["gdrive_folder_id"], + self.cfg.gdrive_key_path, + log=True, + ) + + if verified: + self._setup_rclone_central_gdrive_config(log=True) + + ds_logger.close_log_filehandler() diff --git a/datashuttle/tui/configs.py b/datashuttle/tui/configs.py index 974ee08af..4f55a5a15 100644 --- a/datashuttle/tui/configs.py +++ b/datashuttle/tui/configs.py @@ -25,7 +25,7 @@ from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface -from datashuttle.tui.screens import modal_dialogs, setup_ssh +from datashuttle.tui.screens import modal_dialogs, setup_ssh, setup_aws, setup_gdrive from datashuttle.tui.tooltips import get_tooltip @@ -58,13 +58,19 @@ def __init__( self.parent_class = parent_class self.interface = interface self.config_ssh_widgets: List[Any] = [] + self.config_aws_widgets: List[Any] = [] + self.config_gdrive_widgets: List[Any] = [] def compose(self) -> ComposeResult: """ `self.config_ssh_widgets` are SSH-setup related widgets that are only required when the user selects the SSH connection method. These are displayed / hidden based on the - `connection_method` + `connection_method`. + + `self.config_aws_widgets` are AWS-related widgets. 
+ + `self.config_gdrive_widgets` are Google Drive-related widgets. `config_screen_widgets` are core config-related widgets that are always displayed. @@ -90,6 +96,30 @@ def compose(self) -> ComposeResult: ), ] + self.config_aws_widgets = [ + Label("AWS Bucket Name", id="configs_aws_bucket_name_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="e.g. my-bucket-name", + id="configs_aws_bucket_name_input", + ), + Label("AWS Region", id="configs_aws_region_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="e.g. us-east-1", + id="configs_aws_region_input", + ), + ] + + self.config_gdrive_widgets = [ + Label("Google Drive Folder ID", id="configs_gdrive_folder_id_label"), + ClickableInput( + self.parent_class.mainwindow, + placeholder="e.g. 1A2B3C4D5E6F7G8H", + id="configs_gdrive_folder_id_input", + ), + ] + config_screen_widgets = [ Label("Local Path", id="configs_local_path_label"), Horizontal( @@ -103,18 +133,16 @@ def compose(self) -> ComposeResult: ), Label("Connection Method", id="configs_connect_method_label"), RadioSet( - RadioButton( - "Local Filesystem", - id="configs_local_filesystem_radiobutton", - ), + RadioButton("Local Filesystem", id="configs_local_filesystem_radiobutton"), RadioButton("SSH", id="configs_ssh_radiobutton"), - RadioButton( - "No connection (local only)", - id="configs_local_only_radiobutton", - ), + RadioButton("AWS S3", id="configs_aws_radiobutton"), + RadioButton("Google Drive", id="configs_gdrive_radiobutton"), + RadioButton("No connection (local only)", id="configs_local_only_radiobutton"), id="configs_connect_method_radioset", ), *self.config_ssh_widgets, + *self.config_aws_widgets, + *self.config_gdrive_widgets, Label("Central Path", id="configs_central_path_label"), Horizontal( ClickableInput( @@ -127,16 +155,10 @@ def compose(self) -> ComposeResult: ), Horizontal( Button("Save", id="configs_save_configs_button"), - Button( - "Setup SSH Connection", - 
id="configs_setup_ssh_connection_button", - ), - # Below button is always hidden when accessing - # configs from project manager screen - Button( - "Go to Project Screen", - id="configs_go_to_project_screen_button", - ), + Button("Setup SSH Connection", id="configs_setup_ssh_connection_button"), + Button("Setup AWS Connection", id="configs_setup_aws_connection_button"), + Button("Setup Google Drive Connection", id="configs_setup_gdrive_connection_button"), + Button("Go to Project Screen", id="configs_go_to_project_screen_button"), id="configs_bottom_buttons_horizontal", ), ] @@ -162,114 +184,132 @@ def compose(self) -> ComposeResult: ] if not self.interface: - config_screen_widgets = ( - init_only_config_screen_widgets + config_screen_widgets - ) + config_screen_widgets = init_only_config_screen_widgets + config_screen_widgets yield Container(*config_screen_widgets, id="configs_container") def on_mount(self) -> None: """ - When we have mounted the widgets, the following logic depends on whether - we are setting up a new project (`self.project is `None`) or have - an instantiated project. - - If we have a project, then we want to fill the widgets with the existing - configs. Otherwise, we set to some reasonable defaults, required to - determine the display of SSH widgets. "overwrite_files_checkbox" - should be off by default anyway if `value` is not set, but we set here - anyway as it is critical this is not on by default. + When widgets are mounted, initialize based on whether this is a new + project or editing an existing one. 
""" - # Setup display widget defaults self.query_one("#configs_go_to_project_screen_button").visible = False + if self.interface: self.fill_widgets_with_project_configs() else: - self.query_one("#configs_local_filesystem_radiobutton").value = ( - True - ) + self.query_one("#configs_local_filesystem_radiobutton").value = True self.switch_ssh_widgets_display(display_ssh=False) - self.query_one("#configs_setup_ssh_connection_button").visible = ( - False - ) + self.switch_aws_widgets_display(display_aws=False) + self.switch_gdrive_widgets_display(display_gdrive=False) + self.query_one("#configs_setup_ssh_connection_button").visible = False + self.query_one("#configs_setup_aws_connection_button").visible = False + self.query_one("#configs_setup_gdrive_connection_button").visible = False # Setup tooltips if not self.interface: id = "#configs_name_input" self.query_one(id).tooltip = get_tooltip(id) - - # Assumes 'local_filesystem' is default if no project set. - assert ( - self.query_one("#configs_local_filesystem_radiobutton").value - is True - ) - self.set_central_path_input_tooltip(display_ssh=False) + assert self.query_one("#configs_local_filesystem_radiobutton").value is True + self.set_central_path_input_tooltip("local_filesystem") else: - display_ssh = ( - self.interface.project.cfg["connection_method"] == "ssh" - ) - self.set_central_path_input_tooltip(display_ssh) + method = self.interface.project.cfg["connection_method"] + self.set_central_path_input_tooltip(method) for id in [ "#configs_local_path_input", "#configs_connect_method_label", "#configs_local_filesystem_radiobutton", "#configs_ssh_radiobutton", + "#configs_aws_radiobutton", + "#configs_gdrive_radiobutton", "#configs_local_only_radiobutton", "#configs_central_host_username_input", "#configs_central_host_id_input", + "#configs_aws_bucket_name_input", + "#configs_aws_region_input", + "#configs_gdrive_folder_id_input", ]: self.query_one(id).tooltip = get_tooltip(id) def on_radio_set_changed(self, event: 
RadioSet.Changed) -> None: """ - Update the displayed SSH widgets when the `connection_method` - radiobuttons are changed. + Update the displayed widgets and config state when the + `connection_method` radiobuttons are changed. - When SSH is set, ssh config-setters are shown. Otherwise, these - are hidden. - - When mode is `No connection`, the `central_path` is cleared and - disabled. + Supports SSH, AWS S3, Google Drive, Local Filesystem, and + No Connection modes. """ label = str(event.pressed.label) assert label in [ "SSH", + "AWS S3", + "Google Drive", "Local Filesystem", "No connection (local only)", ], "Unexpected label." - if label == "No connection (local only)": - self.query_one("#configs_central_path_input").value = "" - self.query_one("#configs_central_path_input").disabled = True - self.query_one("#configs_central_path_select_button").disabled = ( - True - ) - display_ssh = False - else: - self.query_one("#configs_central_path_input").disabled = False - self.query_one("#configs_central_path_select_button").disabled = ( - False - ) - display_ssh = True if label == "SSH" else False + is_ssh = label == "SSH" + is_aws = label == "AWS S3" + is_gdrive = label == "Google Drive" + is_local = label == "Local Filesystem" + is_none = label == "No connection (local only)" - self.switch_ssh_widgets_display(display_ssh) - self.set_central_path_input_tooltip(display_ssh) + central_input = self.query_one("#configs_central_path_input") + select_button = self.query_one("#configs_central_path_select_button") + + # Disable fields if no connection + if is_none: + central_input.value = "" + central_input.disabled = True + select_button.disabled = True + else: + central_input.disabled = False + select_button.disabled = False + + # Toggle widget groups + if is_ssh: + self.switch_ssh_widgets_display(True) + self.switch_aws_widgets_display(False) + self.switch_gdrive_widgets_display(False) + elif is_aws: + self.switch_ssh_widgets_display(False) + 
self.switch_aws_widgets_display(True) + self.switch_gdrive_widgets_display(False) + elif is_gdrive: + self.switch_ssh_widgets_display(False) + self.switch_aws_widgets_display(False) + self.switch_gdrive_widgets_display(True) + else: # Local Filesystem or No Connection + self.switch_ssh_widgets_display(False) + self.switch_aws_widgets_display(False) + self.switch_gdrive_widgets_display(False) + + # Tooltip update + if is_ssh: + self.set_central_path_input_tooltip("ssh") + elif is_aws: + self.set_central_path_input_tooltip("aws") + elif is_gdrive: + self.set_central_path_input_tooltip("gdrive") + else: + self.set_central_path_input_tooltip("local_filesystem") - def set_central_path_input_tooltip(self, display_ssh: bool) -> None: + def set_central_path_input_tooltip(self, mode: str) -> None: """ - Use a different tooltip depending on whether connection method - is ssh or local filesystem. + Use a different tooltip depending on the selected connection mode. + `mode` must be one of: 'ssh', 'aws', 'gdrive', 'local_filesystem' """ id = "#configs_central_path_input" - if display_ssh: - self.query_one(id).tooltip = get_tooltip( - "config_central_path_input_mode-ssh" - ) - else: - self.query_one(id).tooltip = get_tooltip( - "config_central_path_input_mode-local_filesystem" - ) + tooltip_id = { + "ssh": "config_central_path_input_mode-ssh", + "aws": "config_central_path_input_mode-aws", + "gdrive": "config_central_path_input_mode-gdrive", + "local_filesystem": "config_central_path_input_mode-local_filesystem", + }.get(mode, "config_central_path_input_mode-local_filesystem") + + self.query_one(id).tooltip = get_tooltip(tooltip_id) + def get_platform_dependent_example_paths( self, local_or_central: Literal["local", "central"], ssh: bool = False @@ -327,11 +367,75 @@ def switch_ssh_widgets_display(self, display_ssh: bool) -> None: placeholder ) + def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: + """ + Show or hide Google Drive-related configs based on 
whether the current + `connection_method` widget is "Google Drive". + + Parameters + ---------- + display_gdrive : bool + If `True`, display the Google Drive-related widgets. + """ + for widget in self.config_gdrive_widgets: + widget.display = display_gdrive + + # Hide local filesystem selector button when GDrive is selected + self.query_one("#configs_central_path_select_button").display = ( + not display_gdrive + ) + + # Show or hide GDrive setup button based on interface and mode + if self.interface is None: + self.query_one("#configs_setup_gdrive_connection_button").visible = False + else: + self.query_one("#configs_setup_gdrive_connection_button").visible = display_gdrive + + # Set placeholder if empty + if not self.query_one("#configs_central_path_input").value: + if display_gdrive: + placeholder = "e.g. gdrive://project-folder-id" + else: + placeholder = f"e.g. {self.get_platform_dependent_example_paths('central', ssh=False)}" + self.query_one("#configs_central_path_input").placeholder = placeholder + + def switch_aws_widgets_display(self, display_aws: bool) -> None: + """ + Show or hide AWS S3-related configs based on whether the current + `connection_method` widget is "AWS S3". + + Parameters + ---------- + display_aws : bool + If `True`, display the AWS-related widgets. + """ + for widget in self.config_aws_widgets: + widget.display = display_aws + + # Hide local filesystem selector button when AWS is selected + self.query_one("#configs_central_path_select_button").display = ( + not display_aws + ) + + # Show or hide AWS setup button based on interface and mode + if self.interface is None: + self.query_one("#configs_setup_aws_connection_button").visible = False + else: + self.query_one("#configs_setup_aws_connection_button").visible = display_aws + + # Set placeholder if empty + if not self.query_one("#configs_central_path_input").value: + if display_aws: + placeholder = "e.g. s3://bucket-name/project-path" + else: + placeholder = f"e.g. 
{self.get_platform_dependent_example_paths('central', ssh=False)}" + self.query_one("#configs_central_path_input").placeholder = placeholder + def on_button_pressed(self, event: Button.Pressed) -> None: """ - Enables the Create Folders button to read out current input values - and use these to call project.create_folders(). + Handle button presses in the configuration screen. """ + if event.button.id == "configs_save_configs_button": if not self.interface: self.setup_configs_for_a_new_project() @@ -341,6 +445,12 @@ def on_button_pressed(self, event: Button.Pressed) -> None: elif event.button.id == "configs_setup_ssh_connection_button": self.setup_ssh_connection() + elif event.button.id == "configs_setup_aws_connection_button": + self.setup_aws_connection() + + elif event.button.id == "configs_setup_gdrive_connection_button": + self.setup_gdrive_connection() + elif event.button.id == "configs_go_to_project_screen_button": self.parent_class.dismiss(self.interface) @@ -363,6 +473,7 @@ def on_button_pressed(self, event: Button.Pressed) -> None: ), ) + def handle_input_fill_from_select_directory( self, path_: Path, local_or_central: Literal["local", "central"] ) -> None: @@ -408,6 +519,38 @@ def setup_ssh_connection(self) -> None: self.parent_class.mainwindow.push_screen( setup_ssh.SetupSshScreen(self.interface) ) + def setup_gdrive_connection(self) -> None: + """ + Set up the `SetupGdriveScreen` screen. + """ + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." + ) + return + + self.parent_class.mainwindow.push_screen( + setup_gdrive.SetupGdriveScreen(self.interface) + ) + def setup_aws_connection(self) -> None: + """ + Set up the `SetupAwsScreen` screen. 
+ """ + assert self.interface is not None, "type narrow flexible `interface`" + + if not self.widget_configs_match_saved_configs(): + self.parent_class.mainwindow.show_modal_error_dialog( + "The values set above must equal the datashuttle settings. " + "Either press 'Save' or reload this page." + ) + return + + self.parent_class.mainwindow.push_screen( + setup_aws.SetupAwsScreen(self.interface) + ) def widget_configs_match_saved_configs(self): """ @@ -456,20 +599,12 @@ def setup_configs_for_a_new_project(self) -> None: self.interface = interface - self.query_one("#configs_go_to_project_screen_button").visible = ( - True - ) + self.query_one("#configs_go_to_project_screen_button").visible = True - # Could not find a neater way to combine the push screen - # while initiating the callback in one case but not the other. if cfg_kwargs["connection_method"] == "ssh": - self.query_one( - "#configs_setup_ssh_connection_button" - ).visible = True - self.query_one( - "#configs_setup_ssh_connection_button" - ).disabled = False + self.query_one("#configs_setup_ssh_connection_button").visible = True + self.query_one("#configs_setup_ssh_connection_button").disabled = False message = ( "A datashuttle project has now been created.\n\n " @@ -478,6 +613,30 @@ def setup_configs_for_a_new_project(self) -> None: "able to create and transfer project folders." ) + elif cfg_kwargs["connection_method"] == "aws": + + self.query_one("#configs_setup_aws_connection_button").visible = True + self.query_one("#configs_setup_aws_connection_button").disabled = False + + message = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the AWS connection. Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." 
+ ) + + elif cfg_kwargs["connection_method"] == "gdrive": + + self.query_one("#configs_setup_gdrive_connection_button").visible = True + self.query_one("#configs_setup_gdrive_connection_button").disabled = False + + message = ( + "A datashuttle project has now been created.\n\n " + "Next, setup the Google Drive connection. Once complete, navigate to the " + "'Main Menu' and proceed to the project page, where you will be " + "able to create and transfer project folders." + ) + else: message = ( "A datashuttle project has now been created.\n\n " @@ -494,6 +653,7 @@ def setup_configs_for_a_new_project(self) -> None: else: self.parent_class.mainwindow.show_modal_error_dialog(output) + def setup_configs_for_an_existing_project(self) -> None: """ If the project already exists, we are on the TabbedContent @@ -503,12 +663,27 @@ def setup_configs_for_an_existing_project(self) -> None: """ assert self.interface is not None, "type narrow flexible `interface`" - # Handle the edge case where connection method is changed after - # saving on the 'Make New Project' screen. 
- self.query_one("#configs_setup_ssh_connection_button").visible = True - cfg_kwargs = self.get_datashuttle_inputs_from_widgets() + # Show relevant setup button depending on selected method + connection_method = cfg_kwargs.get("connection_method", "") + if connection_method == "ssh": + self.query_one("#configs_setup_ssh_connection_button").visible = True + self.query_one("#configs_setup_aws_connection_button").visible = False + self.query_one("#configs_setup_gdrive_connection_button").visible = False + elif connection_method == "aws": + self.query_one("#configs_setup_aws_connection_button").visible = True + self.query_one("#configs_setup_ssh_connection_button").visible = False + self.query_one("#configs_setup_gdrive_connection_button").visible = False + elif connection_method == "gdrive": + self.query_one("#configs_setup_gdrive_connection_button").visible = True + self.query_one("#configs_setup_ssh_connection_button").visible = False + self.query_one("#configs_setup_aws_connection_button").visible = False + else: + self.query_one("#configs_setup_ssh_connection_button").visible = False + self.query_one("#configs_setup_aws_connection_button").visible = False + self.query_one("#configs_setup_gdrive_connection_button").visible = False + success, output = self.interface.set_configs_on_existing_project( cfg_kwargs ) @@ -523,6 +698,7 @@ def setup_configs_for_an_existing_project(self) -> None: else: self.parent_class.mainwindow.show_modal_error_dialog(output) + def fill_widgets_with_project_configs(self) -> None: """ If a configured project already exists, we want to fill the @@ -558,6 +734,10 @@ def fill_widgets_with_project_configs(self) -> None: cfg_to_load["connection_method"] == "local_filesystem", "configs_local_only_radiobutton": cfg_to_load["connection_method"] is None, + "configs_aws_radiobutton": + cfg_to_load["connection_method"] == "aws", + "configs_gdrive_radiobutton": + cfg_to_load["connection_method"] == "gdrive", } # fmt: on @@ -568,6 +748,14 @@ def 
fill_widgets_with_project_configs(self) -> None: display_ssh=what_radiobuton_is_on["configs_ssh_radiobutton"] ) + self.switch_aws_widgets_display( + display_aws=what_radiobuton_is_on["configs_aws_radiobutton"] + ) + + self.switch_gdrive_widgets_display( + display_gdrive=what_radiobuton_is_on["configs_gdrive_radiobutton"] + ) + # Central Host ID input = self.query_one("#configs_central_host_id_input") value = ( @@ -586,6 +774,34 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value + # AWS Bucket Name + input = self.query_one("#configs_aws_bucket_name_input") + value = ( + "" + if cfg_to_load.get("aws_bucket_name") is None + else cfg_to_load["aws_bucket_name"] + ) + input.value = value + + # AWS Region + input = self.query_one("#configs_aws_region_input") + value = ( + "" + if cfg_to_load.get("aws_region") is None + else cfg_to_load["aws_region"] + ) + input.value = value + + # GDrive Folder ID + input = self.query_one("#configs_gdrive_folder_id_input") + value = ( + "" + if cfg_to_load.get("gdrive_folder_id") is None + else cfg_to_load["gdrive_folder_id"] + ) + input.value = value + + def get_datashuttle_inputs_from_widgets(self) -> Dict: """ Get the configs to pass to `make_config_file()` from @@ -611,6 +827,12 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: elif self.query_one("#configs_local_filesystem_radiobutton").value: connection_method = "local_filesystem" + elif self.query_one("#configs_aws_radiobutton").value: + connection_method = "aws" + + elif self.query_one("#configs_gdrive_radiobutton").value: + connection_method = "gdrive" + elif self.query_one("#configs_local_only_radiobutton").value: connection_method = None @@ -626,9 +848,29 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: central_host_username = self.query_one( "#configs_central_host_username_input" ).value - cfg_kwargs["central_host_username"] = ( None if central_host_username == "" else central_host_username ) + aws_bucket_name = self.query_one( + 
"#configs_aws_bucket_name_input" + ).value + cfg_kwargs["aws_bucket_name"] = ( + None if aws_bucket_name == "" else aws_bucket_name + ) + + aws_region = self.query_one( + "#configs_aws_region_input" + ).value + cfg_kwargs["aws_region"] = ( + None if aws_region == "" else aws_region + ) + + gdrive_folder_id = self.query_one( + "#configs_gdrive_folder_id_input" + ).value + cfg_kwargs["gdrive_folder_id"] = ( + None if gdrive_folder_id == "" else gdrive_folder_id + ) + return cfg_kwargs diff --git a/datashuttle/tui/css/tui_menu.tcss b/datashuttle/tui/css/tui_menu.tcss index c9f35d60c..2f0dc4a6c 100644 --- a/datashuttle/tui/css/tui_menu.tcss +++ b/datashuttle/tui/css/tui_menu.tcss @@ -58,6 +58,12 @@ SetupSshScreen { align: center middle; } +SetupAwsScreen { + align: center middle; +} +SetupGdriveScreen { + align: center middle; +} SettingsScreen { align: center middle; } @@ -77,6 +83,24 @@ GetHelpScreen { #setup_ssh_screen_container { border: thick $panel-darken-3; } +#setup_aws_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} +#setup_aws_screen_container { + border: thick $panel-darken-3; +} +#setup_gdrive_screen_container { + height: 75%; + width: 80%; + align: center middle; + border: thick $panel-lighten-1; +} +#setup_gdrive_screen_container { + border: thick $panel-darken-3; +} #messagebox_buttons_horizontal { align: center middle; } @@ -163,9 +187,15 @@ MessageBox:light > #messagebox_top_container { color: $success; /* unsure about this */ } +#configs_setup_aws_connection_button { + margin: 2 1 0 0; +} #configs_setup_ssh_connection_button { margin: 2 1 0 0; } +#configs_setup_gdrive_connection_button { + margin: 2 1 0 0; +} #configs_go_to_project_screen_button { margin: 2 1 0 0; } @@ -194,6 +224,15 @@ MessageBox:light > #messagebox_top_container { #configs_central_path_input { width: 70%; } +#configs_aws_region_input { + width: 70%; +} +#configs_aws_bucket_name_input { + width: 70%; +} 
+#configs_google_drive_folder_id_input { + width: 70%; +} #configs_local_path_select_button{ width: auto; } diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index 1cf397c59..f3b2f3828 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -11,7 +11,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import ssh +from datashuttle.utils import ssh, gdrive, aws class Interface: @@ -183,36 +183,56 @@ def validate_names( # Transfer # ---------------------------------------------------------------------------------- - def transfer_entire_project(self, upload: bool) -> InterfaceOutput: """ Transfer the entire project (all canonical top-level folders). Parameters ---------- - upload : bool Upload from local to central if `True`, otherwise download from central to remote. """ try: - if upload: - transfer_func = self.project.upload_entire_project - else: - transfer_func = self.project.download_entire_project + connection_method = self.project.cfg["connection_method"] - transfer_func( - overwrite_existing_files=self.tui_settings[ - "overwrite_existing_files" - ], - dry_run=self.tui_settings["dry_run"], - ) + if connection_method in ["ssh", "local_filesystem"]: + transfer_func = ( + self.project.upload_entire_project + if upload + else self.project.download_entire_project + ) + transfer_func( + overwrite_existing_files=self.tui_settings["overwrite_existing_files"], + dry_run=self.tui_settings["dry_run"], + ) + + elif connection_method == "aws": + remote_path = self.project.cfg["aws_bucket_name"] + success, message = self.project.transfer_aws_files( + upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message + + elif connection_method == "gdrive": + remote_path = self.project.cfg["google_drive_folder_id"] + success, message = self.project.transfer_gdrive_files( 
+ upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message return True, None except BaseException as e: return False, str(e) + def transfer_top_level_only( self, selected_top_level_folder: str, upload: bool ) -> InterfaceOutput: @@ -257,6 +277,7 @@ def transfer_top_level_only( except BaseException as e: return False, str(e) + def transfer_custom_selection( self, @@ -271,7 +292,6 @@ def transfer_custom_selection( Parameters ---------- - selected_top_level_folder : str The top level folder selected in the TUI for this transfer window. @@ -289,27 +309,47 @@ def transfer_custom_selection( from central to remote. """ try: - if upload: - transfer_func = self.project.upload_custom - else: - transfer_func = self.project.download_custom + connection_method = self.project.cfg["connection_method"] - transfer_func( - selected_top_level_folder, - sub_names=sub_names, - ses_names=ses_names, - datatype=datatype, - overwrite_existing_files=self.tui_settings[ - "overwrite_existing_files" - ], - dry_run=self.tui_settings["dry_run"], - ) + if connection_method in ["ssh", "local_filesystem"]: + transfer_func = ( + self.project.upload_custom if upload else self.project.download_custom + ) + transfer_func( + selected_top_level_folder, + sub_names=sub_names, + ses_names=ses_names, + datatype=datatype, + overwrite_existing_files=self.tui_settings["overwrite_existing_files"], + dry_run=self.tui_settings["dry_run"], + ) + + elif connection_method == "aws": + remote_path = f"{self.project.cfg['aws_bucket_name']}/{selected_top_level_folder}" + success, message = self.project.transfer_aws_files( + upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message + + elif connection_method == "grive": + remote_path = 
f"{self.project.cfg['google_drive_folder_id']}/{selected_top_level_folder}" + success, message = self.project.transfer_gdrive_files( + upload_or_download="upload" if upload else "download", + local_path=self.project.cfg["local_path"], + remote_path=remote_path, + ) + if not success: + return False, message return True, None except BaseException as e: return False, str(e) + # Setup SSH # ---------------------------------------------------------------------------------- @@ -455,3 +495,103 @@ def setup_key_pair_and_rclone_config( except BaseException as e: return False, str(e) + + # Setup AWS S3 + # ---------------------------------------------------------------------------------- + + def get_aws_bucket_name(self) -> str: + """ + Retrieve the AWS bucket name from project configs. + """ + return self.project.cfg["aws_bucket_name"] + + def get_aws_hostkey(self) -> InterfaceOutput: + """ + Retrieve AWS host credentials for verification. + """ + try: + key = aws.get_remote_aws_key( + self.project.cfg["aws_bucket_name"] + ) + return True, key + except BaseException as e: + return False, str(e) + + def save_aws_key_locally(self, key: str) -> InterfaceOutput: + """ + Save the AWS access key locally for secure storage. + """ + try: + aws.save_aws_key_locally( + key, + self.project.cfg["aws_bucket_name"], + self.project.cfg.aws_key_path, + ) + return True, None + + except BaseException as e: + return False, str(e) + + def setup_aws_bucket_and_rclone_config( + self, bucket_name: str, region: str + ) -> InterfaceOutput: + try: + self.project.cfg["aws_bucket_name"] = bucket_name + self.project.cfg["aws_region"] = region + + self.project._setup_rclone_central_aws_config(log=False) + + return True, None + + except BaseException as e: + return False, str(e) + + # Setup Google Drive + # ---------------------------------------------------------------------------------- + + def get_gdrive_folder_id(self) -> str: + """ + Retrieve the Google Drive folder ID from project configs. 
+ """ + return self.project.cfg["gdrive_folder_id"] + + def get_gdrive_hostkey(self) -> InterfaceOutput: + """ + Retrieve Google Drive credentials for verification. + """ + try: + key = gdrive.get_remote_gdrive_key( + self.project.cfg["gdrive_folder_id"] + ) + return True, key + except BaseException as e: + return False, str(e) + + def save_gdrive_key_locally(self, key: str) -> InterfaceOutput: + """ + Save the Google Drive credentials locally for secure storage. + """ + try: + gdrive.save_gdrive_key_locally( + key, + self.project.cfg["gdrive_folder_id"], + self.project.cfg.gdrive_key_path, + ) + return True, None + + except BaseException as e: + return False, str(e) + + def setup_gdrive_folder_and_rclone_config( + self, folder_id: str + ) -> InterfaceOutput: + try: + self.project.cfg["gdrive_folder_id"] = folder_id + + self.project._setup_rclone_central_gdrive_config(log=False) + + return True, None + + except BaseException as e: + return False, str(e) + diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py new file mode 100644 index 000000000..9b5859fa9 --- /dev/null +++ b/datashuttle/tui/screens/setup_aws.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import ComposeResult + from datashuttle.tui.interface import Interface + +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import ( + Button, + Input, + Static, +) + + +class SetupAwsScreen(ModalScreen): + """ + This dialog window handles the TUI equivalent of API's + setup_aws_connection(). This asks to confirm AWS credentials + and takes an access key for setup. + + the TUI cannot simply wrap the API because + the logic flow requires user input (AWS credentials and key). 
+ """ + + def __init__(self, interface: Interface) -> None: + super(SetupAwsScreen, self).__init__() + + self.interface = interface + self.stage = 0 + self.bucket_name: str = "" + self.aws_region: str = "" + self.failed_attempts = 1 + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup AWS S3. Enter bucket name and region, then press OK.", + id="messagebox_message_label", + ), + id="messagebox_message_container", + ), + Input(placeholder="AWS Bucket Name", id="setup_aws_bucket_input"), + Input(placeholder="AWS Region", id="setup_aws_region_input"), + Horizontal( + Button("OK", id="setup_aws_ok_button"), + Button("Cancel", id="setup_aws_cancel_button"), + id="messagebox_buttons_horizontal", + ), + id="setup_aws_screen_container", + ) + + def on_mount(self) -> None: + """Hide region input until the bucket name is verified.""" + self.query_one("#setup_aws_bucket_input").visible = False + self.query_one("#setup_aws_region_input").visible = False + + def on_button_pressed(self, event: Button.pressed) -> None: + """ + Handle button presses for each stage: + 1. Confirm AWS credentials. + 2. Save credentials and prompt for key. + 3. Use the key to finalize AWS setup. + """ + if event.button.id == "setup_aws_cancel_button": + self.dismiss() + + if event.button.id == "setup_aws_ok_button": + if self.stage == 0: + self.ask_user_to_accept_aws_bucket() + + elif self.stage == 1: + self.save_aws_bucket_and_prompt_region_input() + + elif self.stage == 2: + self.use_aws_bucket_and_region_to_setup_aws_connection() + + elif self.stage == 3: + self.dismiss() + + def ask_user_to_accept_aws_bucket(self) -> None: + """ + Verify that the AWS S3 bucket is accessible. + Ask the user to confirm trusting this bucket for future use. 
+ """ + success, output = self.interface.get_aws_hostkey() + + if success: + self.bucket_name = output + + message = ( + f"The AWS bucket '{self.bucket_name}' is accessible.\n\n" + "If you trust this bucket and want to use it for transfers, press OK." + ) + else: + message = ( + "Could not verify the AWS bucket.\nCheck the connection and the bucket name.\n\n" + f"Traceback: {output}" + ) + self.query_one("#setup_aws_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + + def save_aws_bucket_and_prompt_region_input(self) -> None: + """ + Once the AWS bucket is accepted, prompt the user for the region. + When 'OK' is pressed, we go straight to 'use_region_to_setup_aws_connection'. + """ + success, output = self.interface.save_aws_key_locally(self.bucket_name) + + if success: + message = ( + "AWS bucket verified.\n\nNext, enter your AWS region below to complete setup. " + "Press OK to confirm." + ) + self.query_one("#setup_aws_region_input").visible = True + else: + message = ( + f"Could not store AWS bucket name. Check permissions " + f"for: \n\n {self.interface.get_configs().aws_key_path}.\n\n Traceback: {output}" + ) + self.query_one("#setup_aws_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + + def use_aws_bucket_and_region_to_setup_aws_connection(self) -> None: + """ + Use the AWS bucket name and region to complete the setup. + If successful, the OK button changes to 'Finish'. + Otherwise, prompt for another attempt. + """ + bucket_name = self.query_one("#setup_aws_bucket_name_input").value + region = self.query_one("#setup_aws_region_input").value + + success, output = self.interface.setup_aws_bucket_and_rclone_config(bucket_name, region) + + if success: + message = ( + f"AWS setup successful! 
Config saved to " + f"{self.interface.get_configs().aws_credentials_path}" + ) + self.query_one("#setup_aws_ok_button").label = "Finish" + self.query_one("#setup_aws_cancel_button").disabled = True + self.stage += 1 + + else: + message = ( + "AWS setup failed. Check that your bucket name and region are correct and try again." + f"\n\n{self.failed_attempts} failed attempts." + f"\n\nTraceback: {output}" + ) + self.failed_attempts += 1 + + self.query_one("#messagebox_message_label").update(message) diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py new file mode 100644 index 000000000..de2889cab --- /dev/null +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import ComposeResult + from datashuttle.tui.interface import Interface + +from textual.containers import Container, Horizontal +from textual.screen import ModalScreen +from textual.widgets import ( + Button, + Input, + Static, +) + + +class SetupGdriveScreen(ModalScreen): + """ + This dialog window handles the TUI equivalent of API's + setup_gdrive_connection(). This asks to confirm Google Drive + credentials and takes an OAuth folder_id for setup. + + the TUI cannot simply wrap the API because + the logic flow requires user input (OAuth folder_id and confirmation). + """ + + def __init__(self, interface: Interface) -> None: + super(SetupGdriveScreen, self).__init__() + + self.interface = interface + self.stage = 0 + self.folder_id: str = "" + + def compose(self) -> ComposeResult: + yield Container( + Horizontal( + Static( + "Ready to setup Google Drive. 
Enter folder ID, then press OK.", + id="messagebox_message_label", + ), + id="messagebox_message_container", + ), + Input(placeholder="Google Drive Folder ID", id="setup_gdrive_folder_input"), + Horizontal( + Button("OK", id="setup_gdrive_ok_button"), + Button("Cancel", id="setup_gdrive_cancel_button"), + id="messagebox_buttons_horizontal", + ), + id="setup_gdrive_screen_container", + ) + + def on_mount(self) -> None: + """Ensure UI is clean on start.""" + self.query_one("#setup_gdrive_folder_input").visible = False + + def on_button_pressed(self, event: Button.pressed) -> None: + """ + Handle button presses for each stage: + 1. Confirm Google Drive credentials. + 2. Save credentials and prompt for OAuth folder_id. + 3. Use the folder_id to finalize Google Drive setup. + """ + if event.button.id == "setup_gdrive_cancel_button": + self.dismiss() + + if event.button.id == "setup_gdrive_ok_button": + if self.stage == 0: + self.ask_user_to_accept_gdrive_folder() + + elif self.stage == 1: + self.save_gdrive_folder_and_prompt_setup() + + elif self.stage == 2: + self.use_folder_id_to_setup_gdrive_connection() + + elif self.stage == 3: + self.dismiss() + + def ask_user_to_accept_gdrive_folder(self) -> None: + """ + Verify that the Google Drive folder ID is accessible. + Ask the user to confirm trusting this folder for future use. + """ + success, output = self.interface.get_gdrive_hostkey() + + if success: + self.folder_id = output + + message = ( + f"The Google Drive folder ID '{self.folder_id}' is accessible.\n\n" + "If you trust this folder and want to use it for transfers, press OK." 
+ ) + else: + message = ( + "Could not verify the Google Drive folder.\nCheck the connection and the folder ID.\n\n" + f"Traceback: {output}" + ) + self.query_one("#setup_gdrive_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + + def save_gdrive_folder_and_prompt_setup(self) -> None: + """ + Once the Google Drive folder ID is accepted, confirm setup. + No additional credentials are needed after this step. + """ + success, output = self.interface.save_gdrive_key_locally(self.folder_id) + + if success: + message = ( + "Google Drive folder ID verified.\n\nSetup is now complete. " + "You can proceed to transfer files." + ) + else: + message = ( + f"Could not store Google Drive folder ID. Check permissions " + f"for: \n\n {self.interface.get_configs().gdrive_key_path}.\n\n Traceback: {output}" + ) + self.query_one("#setup_gdrive_ok_button").disabled = True + + self.query_one("#messagebox_message_label").update(message) + self.stage += 1 + + + def use_folder_id_to_setup_gdrive_connection(self) -> None: + """ + Use the OAuth folder_id to complete the Google Drive setup. + If successful, the OK button changes to 'Finish'. + Otherwise, prompt for another attempt. + """ + folder_id = self.query_one("#setup_gdrive_folder_id_input").value + + success, output = self.interface.setup_gdrive_folder_and_rclone_config(folder_id) + + if success: + message = ( + f"Google Drive setup successful! Credentials saved to " + f"{self.interface.get_configs().gdrive_credentials_path}" + ) + self.query_one("#setup_gdrive_ok_button").label = "Finish" + self.query_one("#setup_gdrive_cancel_button").disabled = True + self.stage += 1 + + else: + message = ( + "Google Drive setup failed. Check that your OAuth folder_id is correct and try again." + f"\n\n{self.failed_attempts} failed attempts." 
+ f"\n\n Traceback: {output}" + ) + self.failed_attempts += 1 + + self.query_one("#messagebox_message_label").update(message) diff --git a/datashuttle/tui/tooltips.py b/datashuttle/tui/tooltips.py index 61cb9f70d..dee58d733 100644 --- a/datashuttle/tui/tooltips.py +++ b/datashuttle/tui/tooltips.py @@ -33,10 +33,38 @@ def get_tooltip(id: str) -> str: elif id == "#configs_ssh_radiobutton": tooltip = "Use SSH when planning to connect with the central data storage via SSH protocol." + # AWS S3 radiobutton + elif id == "#configs_aws_radiobutton": + tooltip = "Use AWS S3 when planning to connect with your AWS storage bucket." + + # Google Drive radiobutton + elif id == "#configs_gdrive_radiobutton": + tooltip = "Use Google Drive when planning to connect with your Google Drive folder." + # No connection (local only) radiobutton elif id == "#configs_local_only_radiobutton": tooltip = "No connection to a central project is made.\nTransfer functionality will not be available." + # AWS S3 Inputs + elif id == "#configs_aws_bucket_name_input": + tooltip = ( + "Name of your AWS S3 bucket where project data will be stored.\n\n" + "Ensure the bucket already exists before attempting to connect." + ) + + elif id == "#configs_aws_region_input": + tooltip = ( + "AWS region where your S3 bucket is located.\n\n" + "Example regions: us-east-1, eu-west-1, etc." + ) + + # Google Drive Inputs + elif id == "#configs_gdrive_folder_id_input": + tooltip = ( + "The folder ID of your Google Drive folder for project data storage.\n\n" + "Find the folder ID in your Google Drive folder URL." + ) + # central host input elif id == "#configs_central_host_id_input": tooltip = "The hostname or IP address of the server." @@ -51,6 +79,20 @@ def get_tooltip(id: str) -> str: "The path to the project folder on the central machine (or it's parent folder).\n\n" "With 'SSH', this path is relative to the server e.g. 
/nhome/users/myusername" ) + # AWS S3 Tooltip + elif id == "config_central_path_input_mode-aws": + tooltip = ( + "The path to the project folder on AWS S3.\n\n" + "Provide the full bucket path. For example:\n" + "s3://my-bucket-name/project_folder" + ) + + # Google Drive Tooltip + elif id == "config_central_path_input_mode-gdrive": + tooltip = ( + "The path to the project folder on Google Drive.\n\n" + "Provide the folder ID associated with your Google Drive project folder." + ) elif id == "config_central_path_input_mode-local_filesystem": tooltip = ( diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py new file mode 100644 index 000000000..ee3bd2803 --- /dev/null +++ b/datashuttle/utils/aws.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import subprocess +from pathlib import Path +from typing import Any, List, Optional, Tuple +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datashuttle.configs.config_class import Configs + + +from datashuttle.utils import utils + +import fnmatch +import os + +def get_remote_aws_key(bucket_name: str) -> Tuple[bool, str]: + """ + Attempt to list contents of the AWS S3 bucket to check access. + """ + remote_path = f"aws_remote:{bucket_name}" + try: + subprocess.run( + ["rclone", "lsf", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return True, "" + except subprocess.CalledProcessError as e: + return False, e.stderr.decode() + +def save_aws_key_locally(bucket_name: str, aws_region: str, central_path: Path) -> None: + """ + Save the AWS bucket name and region in the central path, following SSH-style storage. + """ + central_path.parent.mkdir(parents=True, exist_ok=True) + + with open(central_path, "w") as file: + file.write(f"Bucket: {bucket_name}\nRegion: {aws_region}") + + +def connect_aws_with_logging( + cfg: Configs, + message_on_sucessful_connection: bool = True, +) -> None: + """ + Connect to AWS S3 using rclone by testing access to the remote. 
+ This assumes rclone has already been configured properly. + """ + remote = cfg.get_rclone_config_name("AWS S3") + + try: + subprocess.run( + ["rclone", "lsf", f"{remote}:"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + if message_on_sucessful_connection: + utils.print_message_to_user( + f"Connection to AWS S3 remote '{remote}' made successfully." + ) + + except Exception as e: + utils.log_and_raise_error( + f"Could not connect to AWS S3. Ensure that:\n" + f"1) You have run setup_aws_connection()\n" + f"2) Your rclone remote '{remote}' is correctly configured\n" + f"3) The bucket exists and credentials are valid.\n\n" + f"Error:\n{e}", + ConnectionError, + ) +def search_aws_remote_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Search for the search prefix in the search path over AWS S3. + Returns the list of matching folders, files are filtered out. + + Parameters + ----------- + + search_path : path to search for folders in + + search_prefix : search prefix for folder names e.g. "sub-*" + + cfg : project config object (provides bucket and credentials) + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. + """ + all_folder_names, all_filenames = get_list_of_folder_names_over_aws( + cfg, + search_path, + search_prefix, + verbose, + return_full_path, + ) + + return all_folder_names, all_filenames + +def get_list_of_folder_names_over_aws( + cfg: Configs, + search_path: Path, + search_prefix: str, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Use rclone to search a path over AWS S3 for folders. + Return the folder names. 
+ + Parameters + ---------- + + cfg : datashuttle project config object + + search_path : path to search for folders in (inside the bucket) + + search_prefix : prefix (can include wildcards) to search folder names + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. + + return_full_path : If `True`, return full rclone remote path, + else return only folder name + """ + remote_path = f"{cfg.get_rclone_config_name('AWS S3')}:{search_path.as_posix()}" + + all_folder_names = [] + all_filenames = [] + + try: + result = subprocess.run( + ["rclone", "lsjson", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + entries = json.loads(result.stdout) + + for entry in entries: + name = entry["Name"] + is_dir = entry.get("IsDir", False) + + if fnmatch.fnmatch(name, search_prefix): + to_append = ( + f"{remote_path}/{name}" if return_full_path else name + ) + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except subprocess.CalledProcessError as e: + if verbose: + utils.log_and_message(f"No file found at {remote_path}\n{e.stderr}") + + return all_folder_names, all_filenames + +def verify_aws_remote(bucket_name: str, aws_key_path: Path, log: bool = True) -> bool: + """ + Prompt user to trust and save an AWS S3 bucket for future use. + """ + success, _ = get_remote_aws_key(bucket_name) + if not success: + utils.print_message_to_user("Unable to access the AWS S3 bucket. 
Make sure it exists and is accessible.") + return False + + message = ( + f"You're about to trust this AWS S3 bucket: {bucket_name}\n" + "If you trust it, type 'y' to save and proceed: " + ) + input_ = utils.get_user_input(message) + + if input_ == "y": + save_aws_key_locally(bucket_name, aws_key_path) + if log: + utils.log(f"AWS S3 bucket {bucket_name} trusted and saved at {aws_key_path}") + utils.print_message_to_user("AWS bucket accepted.") + return True + else: + utils.print_message_to_user("Bucket not accepted. No connection made.") + return False diff --git a/datashuttle/utils/decorators.py b/datashuttle/utils/decorators.py index cacf54914..f0b29c2c3 100644 --- a/datashuttle/utils/decorators.py +++ b/datashuttle/utils/decorators.py @@ -27,6 +27,45 @@ def wrapper(*args, **kwargs): return wrapper +def requires_aws_configs(func): + """ + Decorator to check AWS configs are loaded before running the function. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if ( + not args[0].cfg["aws_bucket_name"] + or not args[0].cfg["aws_region"] + ): + log_and_raise_error( + "Cannot setup AWS connection, 'aws_bucket_name', " + "'aws_access_key', or 'aws_secret_key' is not set in " + "the configuration file.", + ConfigError, + ) + else: + return func(*args, **kwargs) + return wrapper + + +def requires_gdrive_configs(func): + """ + Decorator to check Google Drive configs are loaded before running the function. 
+ """ + @wraps(func) + def wrapper(*args, **kwargs): + if ( + not args[0].cfg["gdrive_folder_id"] + ): + log_and_raise_error( + "Cannot setup Google Drive connection, 'gdrive_folder_id', " + "'gdrive_client_id', or 'gdrive_client_secret' is not set in " + "the configuration file.", + ConfigError, + ) + else: + return func(*args, **kwargs) + return wrapper def check_configs_set(func): """ diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index df6cdb061..3ca3e364b 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -20,7 +20,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import ssh, utils, validation +from datashuttle.utils import ssh, utils, validation, aws, gdrive from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -514,14 +514,44 @@ def search_for_folders( verbose : If `True`, when a search folder cannot be found, a message will be printed with the missing path. 
""" - if local_or_central == "central" and cfg["connection_method"] == "ssh": - all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, - ) + if local_or_central == "central": + if cfg["connection_method"] == "ssh": + all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + elif cfg["connection_method"] == "AWS S3": + all_folder_names, all_filenames = aws.search_aws_remote_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + elif cfg["connection_method"] == "Google Drive": + all_folder_names, all_filenames = gdrive.search_gdrive_remote_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) + else: + # Default to filesystem search if no valid method found + if not search_path.exists(): + if verbose: + utils.log_and_message( + f"No file found at {search_path.as_posix()}" + ) + return [], [] + + all_folder_names, all_filenames = search_filesystem_path_for_folders( + search_path / search_prefix, return_full_path + ) + else: if not search_path.exists(): if verbose: @@ -533,6 +563,7 @@ def search_for_folders( all_folder_names, all_filenames = search_filesystem_path_for_folders( search_path / search_prefix, return_full_path ) + return all_folder_names, all_filenames diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py new file mode 100644 index 000000000..2c8e7dd7b --- /dev/null +++ b/datashuttle/utils/gdrive.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import subprocess +from pathlib import Path +from typing import Any, List, Optional, Tuple +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datashuttle.configs.config_class import Configs +from datashuttle.utils import utils + +import fnmatch +import os + +def get_remote_gdrive_key(folder_id: str) -> Tuple[bool, str]: + """ + Attempt to list 
contents of the Google Drive folder to check access. + """ + remote_path = f"gdrive_remote:{folder_id}" + try: + subprocess.run( + ["rclone", "lsf", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return True, "" + except subprocess.CalledProcessError as e: + return False, e.stderr.decode() + +def save_gdrive_key_locally(folder_id: str, remote_name: str, central_path: Path) -> None: + """ + Save the trusted Google Drive folder ID and remote name in the central path. + """ + central_path.parent.mkdir(parents=True, exist_ok=True) + + with open(central_path, "w") as file: + file.write(f"Folder ID: {folder_id}\nRemote: {remote_name}") + +def connect_gdrive_with_logging( + cfg: Configs, + message_on_sucessful_connection: bool = True, +) -> None: + """ + Connect to Google Drive using rclone by testing access to the remote. + This assumes rclone has already been configured properly. + """ + remote = cfg.get_rclone_config_name("Google Drive") + + try: + subprocess.run( + ["rclone", "lsf", f"{remote}:"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + if message_on_sucessful_connection: + utils.print_message_to_user( + f"Connection to Google Drive remote '{remote}' made successfully." + ) + + except Exception as e: + utils.log_and_raise_error( + f"Could not connect to Google Drive. Ensure that:\n" + f"1) You have run setup_gdrive_connection()\n" + f"2) Your rclone remote '{remote}' is correctly configured\n" + f"3) The folder ID exists and access is authorized.\n\n" + f"Error:\n{e}", + ConnectionError, + ) + +def search_gdrive_remote_for_folders( + search_path: Path, + search_prefix: str, + cfg: Configs, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Search for the search prefix in the search path over Google Drive. + Returns the list of matching folders, files are filtered out. 
+ + Parameters + ----------- + + search_path : path to search for folders in + + search_prefix : search prefix for folder names e.g. "sub-*" + + cfg : project config object (provides remote and credentials) + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. + """ + all_folder_names, all_filenames = get_list_of_folder_names_over_gdrive( + cfg, + search_path, + search_prefix, + verbose, + return_full_path, + ) + + return all_folder_names, all_filenames + +def get_list_of_folder_names_over_gdrive( + cfg: Configs, + search_path: Path, + search_prefix: str, + verbose: bool = True, + return_full_path: bool = False, +) -> Tuple[List[Any], List[Any]]: + """ + Use rclone to search a path over Google Drive for folders. + Return the folder names. + + Parameters + ---------- + + cfg : datashuttle project config object + + search_path : path to search for folders in (inside the GDrive folder) + + search_prefix : prefix (can include wildcards) to search folder names + + verbose : If `True`, if a search folder cannot be found, a message + will be printed with the un-found path. 
+ + return_full_path : If `True`, return full rclone remote path, + else return only folder name + """ + remote_path = f"{cfg.get_rclone_config_name('Google Drive')}:{search_path.as_posix()}" + + all_folder_names = [] + all_filenames = [] + + try: + result = subprocess.run( + ["rclone", "lsjson", remote_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + entries = json.loads(result.stdout) + + for entry in entries: + name = entry["Name"] + is_dir = entry.get("IsDir", False) + + if fnmatch.fnmatch(name, search_prefix): + to_append = ( + f"{remote_path}/{name}" if return_full_path else name + ) + if is_dir: + all_folder_names.append(to_append) + else: + all_filenames.append(to_append) + + except subprocess.CalledProcessError as e: + if verbose: + utils.log_and_message(f"No file found at {remote_path}\n{e.stderr}") + + return all_folder_names, all_filenames + +def verify_gdrive_remote(folder_id: str, gdrive_key_path: Path, log: bool = True) -> bool: + """ + Prompt user to trust and save a GDrive folder ID for future use. + """ + success, _ = get_remote_gdrive_key(folder_id) + if not success: + utils.print_message_to_user("Unable to access the Google Drive folder. Make sure it's shared and reachable.") + return False + + message = ( + f"You're about to trust this Google Drive folder ID: {folder_id}\n" + "If you trust it, type 'y' to save and proceed: " + ) + input_ = utils.get_user_input(message) + + if input_ == "y": + save_gdrive_key_locally(folder_id, gdrive_key_path) + if log: + utils.log(f"Google Drive folder ID {folder_id} trusted and saved at {gdrive_key_path}") + utils.print_message_to_user("Google Drive folder accepted.") + return True + else: + utils.print_message_to_user("Folder not accepted. 
No connection made.") + return False diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index 9644c103c..d7b6fe937 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -107,6 +107,82 @@ def setup_rclone_config_for_ssh( if log: log_rclone_config_output() +def setup_rclone_config_for_gdrive( + cfg: Configs, + rclone_config_name: str, + gdrive_folder_id: str, + log: bool = True, +): + """ + RClone sets remote targets in a config file that are + used at transfer. For Google Drive, this sets the root folder. + The relative path is supplied later during transfer. + + Parameters + ---------- + cfg : Configs + datashuttle configs UserDict. + + rclone_config_name : str + Canonical config name, generated by + datashuttle.cfg.get_rclone_config_name() + + gdrive_folder_id : str + The Google Drive Folder ID to use as the root directory. + + log : bool + Whether to log, if True logger must already be initialised. + """ + call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"scope drive " + f"root_folder_id {gdrive_folder_id}", + pipe_std=True, + ) + + if log: + log_rclone_config_output() + + +def setup_rclone_config_for_gdrive( + cfg: Configs, + rclone_config_name: str, + gdrive_folder_id: str, + log: bool = True, +): + """ + RClone sets remote targets in a config file that are + used for Google Drive. The relative path is supplied at transfer time. + + Parameters + ---------- + cfg : Configs + datashuttle configs UserDict. + + rclone_config_name : str + Canonical config name, generated by + datashuttle.cfg.get_rclone_config_name(). + + gdrive_folder_id : str + The Google Drive Folder ID where files will be stored. + + log : bool + Whether to log, if True logger must already be initialized. 
+ """ + call_rclone( + f"config create " + f"{rclone_config_name} " + f"drive " + f"scope drive " + f"root_folder_id {gdrive_folder_id}", + pipe_std=True, + ) + + if log: + log_rclone_config_output() + def log_rclone_config_output(): output = call_rclone("config file", pipe_std=True) @@ -175,32 +251,42 @@ def transfer_data( A list of options to pass to Rclone's copy function. see `cfg.make_rclone_transfer_options()`. """ - assert upload_or_download in [ - "upload", - "download", - ], "must be 'upload' or 'download'" + assert upload_or_download in ["upload", "download"], "must be 'upload' or 'download'" - local_filepath = cfg.get_base_folder("local", top_level_folder).as_posix() + connection_method = cfg["connection_method"] + local_filepath = cfg.get_base_folder("local", top_level_folder).as_posix() central_filepath = cfg.get_base_folder( "central", top_level_folder ).as_posix() - extra_arguments = handle_rclone_arguments(rclone_options, include_list) + # AWS S3 Path Formatting + if connection_method == "AWS S3": + central_filepath = f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + + # Google Drive Path Formatting + elif connection_method == "Google Drive": + central_filepath = f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + + # Default (SSH or Local Filesystem) + else: + central_filepath = f"{cfg.get_rclone_config_name()}:{central_filepath}" + + extra_arguments = handle_rclone_arguments( + rclone_options, include_list, connection_method + ) if upload_or_download == "upload": output = call_rclone( f"{rclone_args('copy')} " - f'"{local_filepath}" "{cfg.get_rclone_config_name()}:' - f'{central_filepath}" {extra_arguments}', + f'"{local_filepath}" "{central_filepath}" {extra_arguments}', pipe_std=True, ) elif upload_or_download == "download": output = call_rclone( f"{rclone_args('copy')} " - f'"{cfg.get_rclone_config_name()}:' - f'{central_filepath}" "{local_filepath}" {extra_arguments}', + f'"{central_filepath}" "{local_filepath}" 
{extra_arguments}', pipe_std=True, ) @@ -288,21 +374,24 @@ def perform_rclone_check( cfg: Configs, top_level_folder: TopLevelFolder ) -> str: """ - Use Rclone's `check` command to build a list of files that - are the same ("="), different ("*"), found in local only ("+") - or central only ("-"). The output is formatted as " \n". + Perform RClone check with AWS S3 and Google Drive support. """ - local_filepath = cfg.get_base_folder( - "local", top_level_folder - ).parent.as_posix() - central_filepath = cfg.get_base_folder( - "central", top_level_folder - ).parent.as_posix() + connection_method = cfg["connection_method"] + + local_filepath = cfg.get_base_folder("local", top_level_folder).parent.as_posix() + central_filepath = cfg.get_base_folder("central", top_level_folder).parent.as_posix() + + if connection_method == "AWS S3": + central_filepath = f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + elif connection_method == "Google Drive": + central_filepath = f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + else: + central_filepath = f"{cfg.get_rclone_config_name()}:{central_filepath}" output = call_rclone( f'{rclone_args("check")} ' f'"{local_filepath}" ' - f'"{cfg.get_rclone_config_name()}:{central_filepath}"' + f'"{central_filepath}"' f" --combined -", pipe_std=True, ) From efdef7e02bf32befe0089b368710c1e5de9df145 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Mar 2025 17:29:43 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- datashuttle/configs/canonical_configs.py | 9 +- datashuttle/configs/config_class.py | 4 +- datashuttle/datashuttle_class.py | 9 +- datashuttle/tui/configs.py | 183 ++++++++++++++++------- datashuttle/tui/interface.py | 22 +-- datashuttle/tui/screens/setup_aws.py | 7 +- datashuttle/tui/screens/setup_gdrive.py | 16 +- datashuttle/tui/tooltips.py | 6 +- 
datashuttle/utils/aws.py | 36 +++-- datashuttle/utils/decorators.py | 15 +- datashuttle/utils/folders.py | 50 ++++--- datashuttle/utils/gdrive.py | 32 ++-- datashuttle/utils/rclone.py | 30 +++- 13 files changed, 285 insertions(+), 134 deletions(-) diff --git a/datashuttle/configs/canonical_configs.py b/datashuttle/configs/canonical_configs.py index 28af03bb2..617cd3e77 100644 --- a/datashuttle/configs/canonical_configs.py +++ b/datashuttle/configs/canonical_configs.py @@ -39,7 +39,9 @@ def get_canonical_configs() -> dict: canonical_configs = { "local_path": Union[str, Path], "central_path": Optional[Union[str, Path]], - "connection_method": Optional[Literal["ssh", "local_filesystem", "aws", "gdrive"]], + "connection_method": Optional[ + Literal["ssh", "local_filesystem", "aws", "gdrive"] + ], "central_host_id": Optional[str], "central_host_username": Optional[str], "aws_bucket_name": Optional[str], @@ -141,7 +143,10 @@ def check_dict_values_raise_on_fail(config_dict: Configs) -> None: ) # Check Google Drive settings - if config_dict["connection_method"] == "gdrive" and not config_dict["gdrive_folder_id"]: + if ( + config_dict["connection_method"] == "gdrive" + and not config_dict["gdrive_folder_id"] + ): utils.log_and_raise_error( "'gdrive_folder_id' is required if 'connection_method' is 'gdrive'.", ConfigError, diff --git a/datashuttle/configs/config_class.py b/datashuttle/configs/config_class.py index 91a1d094e..9c082f19e 100644 --- a/datashuttle/configs/config_class.py +++ b/datashuttle/configs/config_class.py @@ -240,7 +240,9 @@ def init_paths(self) -> None: self.aws_key_path = datashuttle_path / f"{self.project_name}_aws_key" - self.gdrive_key_path = datashuttle_path / f"{self.project_name}_gdrive_key" + self.gdrive_key_path = ( + datashuttle_path / f"{self.project_name}_gdrive_key" + ) self.hostkeys_path = datashuttle_path / "hostkeys" diff --git a/datashuttle/datashuttle_class.py b/datashuttle/datashuttle_class.py index 35e0c853f..6c2586ce3 100644 --- 
a/datashuttle/datashuttle_class.py +++ b/datashuttle/datashuttle_class.py @@ -36,14 +36,14 @@ from datashuttle.configs.config_class import Configs from datashuttle.datashuttle_functions import _format_top_level_folder from datashuttle.utils import ( + aws, ds_logger, folders, formatting, + gdrive, getters, rclone, ssh, - gdrive, - aws, utils, validation, ) @@ -55,9 +55,9 @@ from datashuttle.utils.decorators import ( # noqa check_configs_set, check_is_not_local_project, - requires_ssh_configs, requires_aws_configs, requires_gdrive_configs, + requires_ssh_configs, ) # ----------------------------------------------------------------------------- @@ -1467,7 +1467,7 @@ def _setup_rclone_central_local_filesystem_config(self) -> None: rclone.setup_rclone_config_for_local_filesystem( self.cfg.get_rclone_config_name("local_filesystem"), ) - + def _setup_rclone_central_aws_config(self, log: bool) -> None: rclone.setup_rclone_config_for_aws( self.cfg, @@ -1614,7 +1614,6 @@ def setup_aws_connection(self) -> None: ds_logger.close_log_filehandler() - @requires_gdrive_configs @check_is_not_local_project def setup_gdrive_connection(self) -> None: diff --git a/datashuttle/tui/configs.py b/datashuttle/tui/configs.py index 4f55a5a15..ce164450e 100644 --- a/datashuttle/tui/configs.py +++ b/datashuttle/tui/configs.py @@ -25,7 +25,12 @@ from datashuttle.tui.custom_widgets import ClickableInput from datashuttle.tui.interface import Interface -from datashuttle.tui.screens import modal_dialogs, setup_ssh, setup_aws, setup_gdrive +from datashuttle.tui.screens import ( + modal_dialogs, + setup_aws, + setup_gdrive, + setup_ssh, +) from datashuttle.tui.tooltips import get_tooltip @@ -69,7 +74,7 @@ def compose(self) -> ComposeResult: `connection_method`. `self.config_aws_widgets` are AWS-related widgets. - + `self.config_gdrive_widgets` are Google Drive-related widgets. 
`config_screen_widgets` are core config-related widgets that are @@ -112,7 +117,9 @@ def compose(self) -> ComposeResult: ] self.config_gdrive_widgets = [ - Label("Google Drive Folder ID", id="configs_gdrive_folder_id_label"), + Label( + "Google Drive Folder ID", id="configs_gdrive_folder_id_label" + ), ClickableInput( self.parent_class.mainwindow, placeholder="e.g. 1A2B3C4D5E6F7G8H", @@ -133,11 +140,17 @@ def compose(self) -> ComposeResult: ), Label("Connection Method", id="configs_connect_method_label"), RadioSet( - RadioButton("Local Filesystem", id="configs_local_filesystem_radiobutton"), + RadioButton( + "Local Filesystem", + id="configs_local_filesystem_radiobutton", + ), RadioButton("SSH", id="configs_ssh_radiobutton"), RadioButton("AWS S3", id="configs_aws_radiobutton"), RadioButton("Google Drive", id="configs_gdrive_radiobutton"), - RadioButton("No connection (local only)", id="configs_local_only_radiobutton"), + RadioButton( + "No connection (local only)", + id="configs_local_only_radiobutton", + ), id="configs_connect_method_radioset", ), *self.config_ssh_widgets, @@ -155,10 +168,22 @@ def compose(self) -> ComposeResult: ), Horizontal( Button("Save", id="configs_save_configs_button"), - Button("Setup SSH Connection", id="configs_setup_ssh_connection_button"), - Button("Setup AWS Connection", id="configs_setup_aws_connection_button"), - Button("Setup Google Drive Connection", id="configs_setup_gdrive_connection_button"), - Button("Go to Project Screen", id="configs_go_to_project_screen_button"), + Button( + "Setup SSH Connection", + id="configs_setup_ssh_connection_button", + ), + Button( + "Setup AWS Connection", + id="configs_setup_aws_connection_button", + ), + Button( + "Setup Google Drive Connection", + id="configs_setup_gdrive_connection_button", + ), + Button( + "Go to Project Screen", + id="configs_go_to_project_screen_button", + ), id="configs_bottom_buttons_horizontal", ), ] @@ -184,7 +209,9 @@ def compose(self) -> ComposeResult: ] if not 
self.interface: - config_screen_widgets = init_only_config_screen_widgets + config_screen_widgets + config_screen_widgets = ( + init_only_config_screen_widgets + config_screen_widgets + ) yield Container(*config_screen_widgets, id="configs_container") @@ -198,19 +225,30 @@ def on_mount(self) -> None: if self.interface: self.fill_widgets_with_project_configs() else: - self.query_one("#configs_local_filesystem_radiobutton").value = True + self.query_one("#configs_local_filesystem_radiobutton").value = ( + True + ) self.switch_ssh_widgets_display(display_ssh=False) self.switch_aws_widgets_display(display_aws=False) self.switch_gdrive_widgets_display(display_gdrive=False) - self.query_one("#configs_setup_ssh_connection_button").visible = False - self.query_one("#configs_setup_aws_connection_button").visible = False - self.query_one("#configs_setup_gdrive_connection_button").visible = False + self.query_one("#configs_setup_ssh_connection_button").visible = ( + False + ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False # Setup tooltips if not self.interface: id = "#configs_name_input" self.query_one(id).tooltip = get_tooltip(id) - assert self.query_one("#configs_local_filesystem_radiobutton").value is True + assert ( + self.query_one("#configs_local_filesystem_radiobutton").value + is True + ) self.set_central_path_input_tooltip("local_filesystem") else: method = self.interface.project.cfg["connection_method"] @@ -310,7 +348,6 @@ def set_central_path_input_tooltip(self, mode: str) -> None: self.query_one(id).tooltip = get_tooltip(tooltip_id) - def get_platform_dependent_example_paths( self, local_or_central: Literal["local", "central"], ssh: bool = False ) -> str: @@ -387,9 +424,13 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: # Show or hide GDrive setup button based on interface and mode if self.interface is None: - 
self.query_one("#configs_setup_gdrive_connection_button").visible = False + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False else: - self.query_one("#configs_setup_gdrive_connection_button").visible = display_gdrive + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = display_gdrive # Set placeholder if empty if not self.query_one("#configs_central_path_input").value: @@ -397,7 +438,9 @@ def switch_gdrive_widgets_display(self, display_gdrive: bool) -> None: placeholder = "e.g. gdrive://project-folder-id" else: placeholder = f"e.g. {self.get_platform_dependent_example_paths('central', ssh=False)}" - self.query_one("#configs_central_path_input").placeholder = placeholder + self.query_one("#configs_central_path_input").placeholder = ( + placeholder + ) def switch_aws_widgets_display(self, display_aws: bool) -> None: """ @@ -419,9 +462,13 @@ def switch_aws_widgets_display(self, display_aws: bool) -> None: # Show or hide AWS setup button based on interface and mode if self.interface is None: - self.query_one("#configs_setup_aws_connection_button").visible = False + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) else: - self.query_one("#configs_setup_aws_connection_button").visible = display_aws + self.query_one("#configs_setup_aws_connection_button").visible = ( + display_aws + ) # Set placeholder if empty if not self.query_one("#configs_central_path_input").value: @@ -429,7 +476,9 @@ def switch_aws_widgets_display(self, display_aws: bool) -> None: placeholder = "e.g. s3://bucket-name/project-path" else: placeholder = f"e.g. 
{self.get_platform_dependent_example_paths('central', ssh=False)}" - self.query_one("#configs_central_path_input").placeholder = placeholder + self.query_one("#configs_central_path_input").placeholder = ( + placeholder + ) def on_button_pressed(self, event: Button.Pressed) -> None: """ @@ -473,7 +522,6 @@ def on_button_pressed(self, event: Button.Pressed) -> None: ), ) - def handle_input_fill_from_select_directory( self, path_: Path, local_or_central: Literal["local", "central"] ) -> None: @@ -519,6 +567,7 @@ def setup_ssh_connection(self) -> None: self.parent_class.mainwindow.push_screen( setup_ssh.SetupSshScreen(self.interface) ) + def setup_gdrive_connection(self) -> None: """ Set up the `SetupGdriveScreen` screen. @@ -535,6 +584,7 @@ def setup_gdrive_connection(self) -> None: self.parent_class.mainwindow.push_screen( setup_gdrive.SetupGdriveScreen(self.interface) ) + def setup_aws_connection(self) -> None: """ Set up the `SetupAwsScreen` screen. @@ -599,12 +649,18 @@ def setup_configs_for_a_new_project(self) -> None: self.interface = interface - self.query_one("#configs_go_to_project_screen_button").visible = True + self.query_one("#configs_go_to_project_screen_button").visible = ( + True + ) if cfg_kwargs["connection_method"] == "ssh": - self.query_one("#configs_setup_ssh_connection_button").visible = True - self.query_one("#configs_setup_ssh_connection_button").disabled = False + self.query_one( + "#configs_setup_ssh_connection_button" + ).visible = True + self.query_one( + "#configs_setup_ssh_connection_button" + ).disabled = False message = ( "A datashuttle project has now been created.\n\n " @@ -615,8 +671,12 @@ def setup_configs_for_a_new_project(self) -> None: elif cfg_kwargs["connection_method"] == "aws": - self.query_one("#configs_setup_aws_connection_button").visible = True - self.query_one("#configs_setup_aws_connection_button").disabled = False + self.query_one( + "#configs_setup_aws_connection_button" + ).visible = True + self.query_one( + 
"#configs_setup_aws_connection_button" + ).disabled = False message = ( "A datashuttle project has now been created.\n\n " @@ -627,8 +687,12 @@ def setup_configs_for_a_new_project(self) -> None: elif cfg_kwargs["connection_method"] == "gdrive": - self.query_one("#configs_setup_gdrive_connection_button").visible = True - self.query_one("#configs_setup_gdrive_connection_button").disabled = False + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = True + self.query_one( + "#configs_setup_gdrive_connection_button" + ).disabled = False message = ( "A datashuttle project has now been created.\n\n " @@ -653,7 +717,6 @@ def setup_configs_for_a_new_project(self) -> None: else: self.parent_class.mainwindow.show_modal_error_dialog(output) - def setup_configs_for_an_existing_project(self) -> None: """ If the project already exists, we are on the TabbedContent @@ -668,21 +731,45 @@ def setup_configs_for_an_existing_project(self) -> None: # Show relevant setup button depending on selected method connection_method = cfg_kwargs.get("connection_method", "") if connection_method == "ssh": - self.query_one("#configs_setup_ssh_connection_button").visible = True - self.query_one("#configs_setup_aws_connection_button").visible = False - self.query_one("#configs_setup_gdrive_connection_button").visible = False + self.query_one("#configs_setup_ssh_connection_button").visible = ( + True + ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False elif connection_method == "aws": - self.query_one("#configs_setup_aws_connection_button").visible = True - self.query_one("#configs_setup_ssh_connection_button").visible = False - self.query_one("#configs_setup_gdrive_connection_button").visible = False + self.query_one("#configs_setup_aws_connection_button").visible = ( + True + ) + self.query_one("#configs_setup_ssh_connection_button").visible = ( + False + ) + 
self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False elif connection_method == "gdrive": - self.query_one("#configs_setup_gdrive_connection_button").visible = True - self.query_one("#configs_setup_ssh_connection_button").visible = False - self.query_one("#configs_setup_aws_connection_button").visible = False + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = True + self.query_one("#configs_setup_ssh_connection_button").visible = ( + False + ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) else: - self.query_one("#configs_setup_ssh_connection_button").visible = False - self.query_one("#configs_setup_aws_connection_button").visible = False - self.query_one("#configs_setup_gdrive_connection_button").visible = False + self.query_one("#configs_setup_ssh_connection_button").visible = ( + False + ) + self.query_one("#configs_setup_aws_connection_button").visible = ( + False + ) + self.query_one( + "#configs_setup_gdrive_connection_button" + ).visible = False success, output = self.interface.set_configs_on_existing_project( cfg_kwargs @@ -698,7 +785,6 @@ def setup_configs_for_an_existing_project(self) -> None: else: self.parent_class.mainwindow.show_modal_error_dialog(output) - def fill_widgets_with_project_configs(self) -> None: """ If a configured project already exists, we want to fill the @@ -801,7 +887,6 @@ def fill_widgets_with_project_configs(self) -> None: ) input.value = value - def get_datashuttle_inputs_from_widgets(self) -> Dict: """ Get the configs to pass to `make_config_file()` from @@ -859,12 +944,8 @@ def get_datashuttle_inputs_from_widgets(self) -> Dict: None if aws_bucket_name == "" else aws_bucket_name ) - aws_region = self.query_one( - "#configs_aws_region_input" - ).value - cfg_kwargs["aws_region"] = ( - None if aws_region == "" else aws_region - ) + aws_region = self.query_one("#configs_aws_region_input").value + cfg_kwargs["aws_region"] = None if aws_region == "" 
else aws_region gdrive_folder_id = self.query_one( "#configs_gdrive_folder_id_input" diff --git a/datashuttle/tui/interface.py b/datashuttle/tui/interface.py index f3b2f3828..8b3f64856 100644 --- a/datashuttle/tui/interface.py +++ b/datashuttle/tui/interface.py @@ -11,7 +11,7 @@ from datashuttle import DataShuttle from datashuttle.configs import load_configs -from datashuttle.utils import ssh, gdrive, aws +from datashuttle.utils import aws, gdrive, ssh class Interface: @@ -203,7 +203,9 @@ def transfer_entire_project(self, upload: bool) -> InterfaceOutput: else self.project.download_entire_project ) transfer_func( - overwrite_existing_files=self.tui_settings["overwrite_existing_files"], + overwrite_existing_files=self.tui_settings[ + "overwrite_existing_files" + ], dry_run=self.tui_settings["dry_run"], ) @@ -232,7 +234,6 @@ def transfer_entire_project(self, upload: bool) -> InterfaceOutput: except BaseException as e: return False, str(e) - def transfer_top_level_only( self, selected_top_level_folder: str, upload: bool ) -> InterfaceOutput: @@ -277,7 +278,6 @@ def transfer_top_level_only( except BaseException as e: return False, str(e) - def transfer_custom_selection( self, @@ -313,14 +313,18 @@ def transfer_custom_selection( if connection_method in ["ssh", "local_filesystem"]: transfer_func = ( - self.project.upload_custom if upload else self.project.download_custom + self.project.upload_custom + if upload + else self.project.download_custom ) transfer_func( selected_top_level_folder, sub_names=sub_names, ses_names=ses_names, datatype=datatype, - overwrite_existing_files=self.tui_settings["overwrite_existing_files"], + overwrite_existing_files=self.tui_settings[ + "overwrite_existing_files" + ], dry_run=self.tui_settings["dry_run"], ) @@ -349,7 +353,6 @@ def transfer_custom_selection( except BaseException as e: return False, str(e) - # Setup SSH # ---------------------------------------------------------------------------------- @@ -510,9 +513,7 @@ def 
get_aws_hostkey(self) -> InterfaceOutput: Retrieve AWS host credentials for verification. """ try: - key = aws.get_remote_aws_key( - self.project.cfg["aws_bucket_name"] - ) + key = aws.get_remote_aws_key(self.project.cfg["aws_bucket_name"]) return True, key except BaseException as e: return False, str(e) @@ -594,4 +595,3 @@ def setup_gdrive_folder_and_rclone_config( except BaseException as e: return False, str(e) - diff --git a/datashuttle/tui/screens/setup_aws.py b/datashuttle/tui/screens/setup_aws.py index 9b5859fa9..3b65568d7 100644 --- a/datashuttle/tui/screens/setup_aws.py +++ b/datashuttle/tui/screens/setup_aws.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: from textual.app import ComposeResult + from datashuttle.tui.interface import Interface from textual.containers import Container, Horizontal @@ -105,7 +106,6 @@ def ask_user_to_accept_aws_bucket(self) -> None: self.query_one("#messagebox_message_label").update(message) self.stage += 1 - def save_aws_bucket_and_prompt_region_input(self) -> None: """ Once the AWS bucket is accepted, prompt the user for the region. @@ -129,7 +129,6 @@ def save_aws_bucket_and_prompt_region_input(self) -> None: self.query_one("#messagebox_message_label").update(message) self.stage += 1 - def use_aws_bucket_and_region_to_setup_aws_connection(self) -> None: """ Use the AWS bucket name and region to complete the setup. 
@@ -139,7 +138,9 @@ def use_aws_bucket_and_region_to_setup_aws_connection(self) -> None: bucket_name = self.query_one("#setup_aws_bucket_name_input").value region = self.query_one("#setup_aws_region_input").value - success, output = self.interface.setup_aws_bucket_and_rclone_config(bucket_name, region) + success, output = self.interface.setup_aws_bucket_and_rclone_config( + bucket_name, region + ) if success: message = ( diff --git a/datashuttle/tui/screens/setup_gdrive.py b/datashuttle/tui/screens/setup_gdrive.py index de2889cab..32c5c9692 100644 --- a/datashuttle/tui/screens/setup_gdrive.py +++ b/datashuttle/tui/screens/setup_gdrive.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: from textual.app import ComposeResult + from datashuttle.tui.interface import Interface from textual.containers import Container, Horizontal @@ -41,7 +42,10 @@ def compose(self) -> ComposeResult: ), id="messagebox_message_container", ), - Input(placeholder="Google Drive Folder ID", id="setup_gdrive_folder_input"), + Input( + placeholder="Google Drive Folder ID", + id="setup_gdrive_folder_input", + ), Horizontal( Button("OK", id="setup_gdrive_ok_button"), Button("Cancel", id="setup_gdrive_cancel_button"), @@ -101,13 +105,14 @@ def ask_user_to_accept_gdrive_folder(self) -> None: self.query_one("#messagebox_message_label").update(message) self.stage += 1 - def save_gdrive_folder_and_prompt_setup(self) -> None: """ Once the Google Drive folder ID is accepted, confirm setup. No additional credentials are needed after this step. """ - success, output = self.interface.save_gdrive_key_locally(self.folder_id) + success, output = self.interface.save_gdrive_key_locally( + self.folder_id + ) if success: message = ( @@ -124,7 +129,6 @@ def save_gdrive_folder_and_prompt_setup(self) -> None: self.query_one("#messagebox_message_label").update(message) self.stage += 1 - def use_folder_id_to_setup_gdrive_connection(self) -> None: """ Use the OAuth folder_id to complete the Google Drive setup. 
@@ -133,7 +137,9 @@ def use_folder_id_to_setup_gdrive_connection(self) -> None: """ folder_id = self.query_one("#setup_gdrive_folder_id_input").value - success, output = self.interface.setup_gdrive_folder_and_rclone_config(folder_id) + success, output = self.interface.setup_gdrive_folder_and_rclone_config( + folder_id + ) if success: message = ( diff --git a/datashuttle/tui/tooltips.py b/datashuttle/tui/tooltips.py index dee58d733..9583eb89e 100644 --- a/datashuttle/tui/tooltips.py +++ b/datashuttle/tui/tooltips.py @@ -35,7 +35,9 @@ def get_tooltip(id: str) -> str: # AWS S3 radiobutton elif id == "#configs_aws_radiobutton": - tooltip = "Use AWS S3 when planning to connect with your AWS storage bucket." + tooltip = ( + "Use AWS S3 when planning to connect with your AWS storage bucket." + ) # Google Drive radiobutton elif id == "#configs_gdrive_radiobutton": @@ -45,7 +47,7 @@ def get_tooltip(id: str) -> str: elif id == "#configs_local_only_radiobutton": tooltip = "No connection to a central project is made.\nTransfer functionality will not be available." 
- # AWS S3 Inputs + # AWS S3 Inputs elif id == "#configs_aws_bucket_name_input": tooltip = ( "Name of your AWS S3 bucket where project data will be stored.\n\n" diff --git a/datashuttle/utils/aws.py b/datashuttle/utils/aws.py index ee3bd2803..477212f9d 100644 --- a/datashuttle/utils/aws.py +++ b/datashuttle/utils/aws.py @@ -2,17 +2,16 @@ import subprocess from pathlib import Path -from typing import Any, List, Optional, Tuple -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List, Tuple if TYPE_CHECKING: from datashuttle.configs.config_class import Configs +import fnmatch + from datashuttle.utils import utils -import fnmatch -import os def get_remote_aws_key(bucket_name: str) -> Tuple[bool, str]: """ @@ -30,7 +29,10 @@ def get_remote_aws_key(bucket_name: str) -> Tuple[bool, str]: except subprocess.CalledProcessError as e: return False, e.stderr.decode() -def save_aws_key_locally(bucket_name: str, aws_region: str, central_path: Path) -> None: + +def save_aws_key_locally( + bucket_name: str, aws_region: str, central_path: Path +) -> None: """ Save the AWS bucket name and region in the central path, following SSH-style storage. 
""" @@ -73,6 +75,8 @@ def connect_aws_with_logging( f"Error:\n{e}", ConnectionError, ) + + def search_aws_remote_for_folders( search_path: Path, search_prefix: str, @@ -106,6 +110,7 @@ def search_aws_remote_for_folders( return all_folder_names, all_filenames + def get_list_of_folder_names_over_aws( cfg: Configs, search_path: Path, @@ -132,7 +137,9 @@ def get_list_of_folder_names_over_aws( return_full_path : If `True`, return full rclone remote path, else return only folder name """ - remote_path = f"{cfg.get_rclone_config_name('AWS S3')}:{search_path.as_posix()}" + remote_path = ( + f"{cfg.get_rclone_config_name('AWS S3')}:{search_path.as_posix()}" + ) all_folder_names = [] all_filenames = [] @@ -162,17 +169,24 @@ def get_list_of_folder_names_over_aws( except subprocess.CalledProcessError as e: if verbose: - utils.log_and_message(f"No file found at {remote_path}\n{e.stderr}") + utils.log_and_message( + f"No file found at {remote_path}\n{e.stderr}" + ) return all_folder_names, all_filenames -def verify_aws_remote(bucket_name: str, aws_key_path: Path, log: bool = True) -> bool: + +def verify_aws_remote( + bucket_name: str, aws_key_path: Path, log: bool = True +) -> bool: """ Prompt user to trust and save an AWS S3 bucket for future use. """ success, _ = get_remote_aws_key(bucket_name) if not success: - utils.print_message_to_user("Unable to access the AWS S3 bucket. Make sure it exists and is accessible.") + utils.print_message_to_user( + "Unable to access the AWS S3 bucket. Make sure it exists and is accessible." 
+ ) return False message = ( @@ -184,7 +198,9 @@ def verify_aws_remote(bucket_name: str, aws_key_path: Path, log: bool = True) -> if input_ == "y": save_aws_key_locally(bucket_name, aws_key_path) if log: - utils.log(f"AWS S3 bucket {bucket_name} trusted and saved at {aws_key_path}") + utils.log( + f"AWS S3 bucket {bucket_name} trusted and saved at {aws_key_path}" + ) utils.print_message_to_user("AWS bucket accepted.") return True else: diff --git a/datashuttle/utils/decorators.py b/datashuttle/utils/decorators.py index f0b29c2c3..b257e319d 100644 --- a/datashuttle/utils/decorators.py +++ b/datashuttle/utils/decorators.py @@ -27,16 +27,15 @@ def wrapper(*args, **kwargs): return wrapper + def requires_aws_configs(func): """ Decorator to check AWS configs are loaded before running the function. """ + @wraps(func) def wrapper(*args, **kwargs): - if ( - not args[0].cfg["aws_bucket_name"] - or not args[0].cfg["aws_region"] - ): + if not args[0].cfg["aws_bucket_name"] or not args[0].cfg["aws_region"]: log_and_raise_error( "Cannot setup AWS connection, 'aws_bucket_name', " "'aws_access_key', or 'aws_secret_key' is not set in " @@ -45,6 +44,7 @@ def wrapper(*args, **kwargs): ) else: return func(*args, **kwargs) + return wrapper @@ -52,11 +52,10 @@ def requires_gdrive_configs(func): """ Decorator to check Google Drive configs are loaded before running the function. """ + @wraps(func) def wrapper(*args, **kwargs): - if ( - not args[0].cfg["gdrive_folder_id"] - ): + if not args[0].cfg["gdrive_folder_id"]: log_and_raise_error( "Cannot setup Google Drive connection, 'gdrive_folder_id', " "'gdrive_client_id', or 'gdrive_client_secret' is not set in " @@ -65,8 +64,10 @@ def wrapper(*args, **kwargs): ) else: return func(*args, **kwargs) + return wrapper + def check_configs_set(func): """ Check that configs have been loaded (i.e. 
diff --git a/datashuttle/utils/folders.py b/datashuttle/utils/folders.py index 3ca3e364b..ddc1d2649 100644 --- a/datashuttle/utils/folders.py +++ b/datashuttle/utils/folders.py @@ -20,7 +20,7 @@ from pathlib import Path from datashuttle.configs import canonical_folders, canonical_tags -from datashuttle.utils import ssh, utils, validation, aws, gdrive +from datashuttle.utils import aws, gdrive, ssh, utils, validation from datashuttle.utils.custom_exceptions import NeuroBlueprintError # ----------------------------------------------------------------------------- @@ -516,28 +516,34 @@ def search_for_folders( """ if local_or_central == "central": if cfg["connection_method"] == "ssh": - all_folder_names, all_filenames = ssh.search_ssh_central_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, + all_folder_names, all_filenames = ( + ssh.search_ssh_central_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) ) elif cfg["connection_method"] == "AWS S3": - all_folder_names, all_filenames = aws.search_aws_remote_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, + all_folder_names, all_filenames = ( + aws.search_aws_remote_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) ) elif cfg["connection_method"] == "Google Drive": - all_folder_names, all_filenames = gdrive.search_gdrive_remote_for_folders( - search_path, - search_prefix, - cfg, - verbose, - return_full_path, + all_folder_names, all_filenames = ( + gdrive.search_gdrive_remote_for_folders( + search_path, + search_prefix, + cfg, + verbose, + return_full_path, + ) ) else: # Default to filesystem search if no valid method found @@ -548,8 +554,10 @@ def search_for_folders( ) return [], [] - all_folder_names, all_filenames = search_filesystem_path_for_folders( - search_path / search_prefix, return_full_path + all_folder_names, all_filenames = ( + search_filesystem_path_for_folders( + 
search_path / search_prefix, return_full_path + ) ) else: diff --git a/datashuttle/utils/gdrive.py b/datashuttle/utils/gdrive.py index 2c8e7dd7b..ca3beb3c8 100644 --- a/datashuttle/utils/gdrive.py +++ b/datashuttle/utils/gdrive.py @@ -2,15 +2,14 @@ import subprocess from pathlib import Path -from typing import Any, List, Optional, Tuple -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List, Tuple if TYPE_CHECKING: from datashuttle.configs.config_class import Configs +import fnmatch + from datashuttle.utils import utils -import fnmatch -import os def get_remote_gdrive_key(folder_id: str) -> Tuple[bool, str]: """ @@ -28,7 +27,10 @@ def get_remote_gdrive_key(folder_id: str) -> Tuple[bool, str]: except subprocess.CalledProcessError as e: return False, e.stderr.decode() -def save_gdrive_key_locally(folder_id: str, remote_name: str, central_path: Path) -> None: + +def save_gdrive_key_locally( + folder_id: str, remote_name: str, central_path: Path +) -> None: """ Save the trusted Google Drive folder ID and remote name in the central path. 
""" @@ -37,6 +39,7 @@ def save_gdrive_key_locally(folder_id: str, remote_name: str, central_path: Path with open(central_path, "w") as file: file.write(f"Folder ID: {folder_id}\nRemote: {remote_name}") + def connect_gdrive_with_logging( cfg: Configs, message_on_sucessful_connection: bool = True, @@ -71,6 +74,7 @@ def connect_gdrive_with_logging( ConnectionError, ) + def search_gdrive_remote_for_folders( search_path: Path, search_prefix: str, @@ -104,6 +108,7 @@ def search_gdrive_remote_for_folders( return all_folder_names, all_filenames + def get_list_of_folder_names_over_gdrive( cfg: Configs, search_path: Path, @@ -160,17 +165,24 @@ def get_list_of_folder_names_over_gdrive( except subprocess.CalledProcessError as e: if verbose: - utils.log_and_message(f"No file found at {remote_path}\n{e.stderr}") + utils.log_and_message( + f"No file found at {remote_path}\n{e.stderr}" + ) return all_folder_names, all_filenames -def verify_gdrive_remote(folder_id: str, gdrive_key_path: Path, log: bool = True) -> bool: + +def verify_gdrive_remote( + folder_id: str, gdrive_key_path: Path, log: bool = True +) -> bool: """ Prompt user to trust and save a GDrive folder ID for future use. """ success, _ = get_remote_gdrive_key(folder_id) if not success: - utils.print_message_to_user("Unable to access the Google Drive folder. Make sure it's shared and reachable.") + utils.print_message_to_user( + "Unable to access the Google Drive folder. Make sure it's shared and reachable." 
+ ) return False message = ( @@ -182,7 +194,9 @@ def verify_gdrive_remote(folder_id: str, gdrive_key_path: Path, log: bool = True if input_ == "y": save_gdrive_key_locally(folder_id, gdrive_key_path) if log: - utils.log(f"Google Drive folder ID {folder_id} trusted and saved at {gdrive_key_path}") + utils.log( + f"Google Drive folder ID {folder_id} trusted and saved at {gdrive_key_path}" + ) utils.print_message_to_user("Google Drive folder accepted.") return True else: diff --git a/datashuttle/utils/rclone.py b/datashuttle/utils/rclone.py index d7b6fe937..5e400093a 100644 --- a/datashuttle/utils/rclone.py +++ b/datashuttle/utils/rclone.py @@ -107,6 +107,7 @@ def setup_rclone_config_for_ssh( if log: log_rclone_config_output() + def setup_rclone_config_for_gdrive( cfg: Configs, rclone_config_name: str, @@ -251,7 +252,10 @@ def transfer_data( A list of options to pass to Rclone's copy function. see `cfg.make_rclone_transfer_options()`. """ - assert upload_or_download in ["upload", "download"], "must be 'upload' or 'download'" + assert upload_or_download in [ + "upload", + "download", + ], "must be 'upload' or 'download'" connection_method = cfg["connection_method"] @@ -262,11 +266,15 @@ def transfer_data( # AWS S3 Path Formatting if connection_method == "AWS S3": - central_filepath = f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + central_filepath = ( + f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + ) # Google Drive Path Formatting elif connection_method == "Google Drive": - central_filepath = f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + central_filepath = ( + f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + ) # Default (SSH or Local Filesystem) else: @@ -378,13 +386,21 @@ def perform_rclone_check( """ connection_method = cfg["connection_method"] - local_filepath = cfg.get_base_folder("local", top_level_folder).parent.as_posix() - central_filepath = cfg.get_base_folder("central", 
top_level_folder).parent.as_posix() + local_filepath = cfg.get_base_folder( + "local", top_level_folder + ).parent.as_posix() + central_filepath = cfg.get_base_folder( + "central", top_level_folder + ).parent.as_posix() if connection_method == "AWS S3": - central_filepath = f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + central_filepath = ( + f"{cfg.get_rclone_config_name('AWS S3')}:{central_filepath}" + ) elif connection_method == "Google Drive": - central_filepath = f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + central_filepath = ( + f"{cfg.get_rclone_config_name('Google Drive')}:{central_filepath}" + ) else: central_filepath = f"{cfg.get_rclone_config_name()}:{central_filepath}"