Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions datashuttle/datashuttle_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
check_is_not_local_project,
requires_aws_configs,
requires_ssh_configs,
with_logging,
)

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -117,6 +118,7 @@ def _set_attributes_after_config_load(self) -> None:
# -------------------------------------------------------------------------

@check_configs_set
@with_logging(conditional_param="log")
def create_folders(
self,
top_level_folder: TopLevelFolder,
Expand Down Expand Up @@ -202,9 +204,6 @@ def create_folders(
project.create_folders("rawdata", "sub-002@TO@005", ["ses-001", "ses-002"], ["ephys", "behav"])

"""
if log:
self._start_log("create-folders", local_vars=locals())

self._check_top_level_folder(top_level_folder)

if ses_names is None and datatype != "":
Expand Down Expand Up @@ -251,7 +250,6 @@ def create_folders(
f"For log of all created folders, "
f"please see {self.cfg.logging_path}"
)
ds_logger.close_log_filehandler()

return created_paths

Expand Down Expand Up @@ -306,6 +304,7 @@ def _format_and_validate_names(

@check_configs_set
@check_is_not_local_project
@with_logging(conditional_param="init_log")
def upload_custom(
self,
top_level_folder: TopLevelFolder,
Expand Down Expand Up @@ -355,9 +354,6 @@ def upload_custom(
(e.g. in a calling function).

"""
if init_log:
self._start_log("upload-custom", local_vars=locals())

self._check_top_level_folder(top_level_folder)

TransferData(
Expand All @@ -372,11 +368,9 @@ def upload_custom(
log=True,
)

if init_log:
ds_logger.close_log_filehandler()

@check_configs_set
@check_is_not_local_project
@with_logging(conditional_param="init_log")
def download_custom(
self,
top_level_folder: TopLevelFolder,
Expand Down Expand Up @@ -426,9 +420,6 @@ def download_custom(
(e.g. in a calling function).

"""
if init_log:
self._start_log("download-custom", local_vars=locals())

self._check_top_level_folder(top_level_folder)

TransferData(
Expand All @@ -443,16 +434,14 @@ def download_custom(
log=True,
)

if init_log:
ds_logger.close_log_filehandler()

# Specific top-level folder
# ----------------------------------------------------------------------------------
# A set of convenience functions are provided to abstract
# away the 'top_level_folder' concept.

@check_configs_set
@check_is_not_local_project
@with_logging()
def upload_rawdata(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
Expand All @@ -479,10 +468,12 @@ def upload_rawdata(
"rawdata",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
init_log=False,
)

@check_configs_set
@check_is_not_local_project
@with_logging()
def upload_derivatives(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
Expand All @@ -509,10 +500,12 @@ def upload_derivatives(
"derivatives",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
init_log=False,
)

@check_configs_set
@check_is_not_local_project
@with_logging()
def download_rawdata(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
Expand All @@ -539,10 +532,12 @@ def download_rawdata(
"rawdata",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
init_log=False,
)

@check_configs_set
@check_is_not_local_project
@with_logging()
def download_derivatives(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
Expand All @@ -569,10 +564,12 @@ def download_derivatives(
"derivatives",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
init_log=False,
)

@check_configs_set
@check_is_not_local_project
@with_logging()
def upload_entire_project(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
Expand All @@ -596,14 +593,13 @@ def upload_entire_project(
transfer was taking place, but no files will be moved.

"""
self._start_log("upload-entire-project", local_vars=locals())
self._transfer_entire_project(
"upload", overwrite_existing_files, dry_run
)
ds_logger.close_log_filehandler()

@check_configs_set
@check_is_not_local_project
@with_logging()
def download_entire_project(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
Expand All @@ -627,14 +623,13 @@ def download_entire_project(
transfer was taking place, but no files will be moved.

"""
self._start_log("download-entire-project", local_vars=locals())
self._transfer_entire_project(
"download", overwrite_existing_files, dry_run
)
ds_logger.close_log_filehandler()

@check_configs_set
@check_is_not_local_project
@with_logging()
def upload_specific_folder_or_file(
self,
filepath: Union[str, Path],
Expand Down Expand Up @@ -664,16 +659,13 @@ def upload_specific_folder_or_file(
transfer was taking place, but no files will be moved.

"""
self._start_log("upload-specific-folder-or-file", local_vars=locals())

self._transfer_specific_file_or_folder(
"upload", filepath, overwrite_existing_files, dry_run
)

ds_logger.close_log_filehandler()

@check_configs_set
@check_is_not_local_project
@with_logging()
def download_specific_folder_or_file(
self,
filepath: Union[str, Path],
Expand Down Expand Up @@ -704,16 +696,10 @@ def download_specific_folder_or_file(
transfer was taking place, but no files will be moved.

"""
self._start_log(
"download-specific-folder-or-file", local_vars=locals()
)

self._transfer_specific_file_or_folder(
"download", filepath, overwrite_existing_files, dry_run
)

ds_logger.close_log_filehandler()

def _transfer_top_level_folder(
self,
upload_or_download: Literal["upload", "download"],
Expand Down
89 changes: 89 additions & 0 deletions datashuttle/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import wraps
from typing import Optional

from datashuttle.utils.custom_exceptions import ConfigError
from datashuttle.utils.utils import log_and_raise_error
Expand Down Expand Up @@ -88,3 +89,91 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


def with_logging(
command_name: Optional[str] = None,
store_in_temp_folder: bool = False,
conditional_param: Optional[str] = None,
):
"""Automatically handle logging for DataShuttle methods.

This decorator:
1. Starts logging at the beginning of the function
2. Captures local variables for logging
3. Ensures logging is closed even if an exception occurs

Parameters
----------
command_name
Name of the command for logging. If None, uses the function name
with underscores replaced by hyphens.
store_in_temp_folder
If True, store logs in temp folder instead of project logging path.
conditional_param
Name of parameter that controls whether logging occurs (e.g., "log").
If specified and that parameter is False, logging is skipped.

Examples
--------
@check_configs_set
@with_logging()
def upload_rawdata(self, ...):
...

@with_logging(conditional_param="log")
def create_folders(self, ..., log: bool = True):
...

"""

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
import inspect

from datashuttle.utils import ds_logger

# Get the DataShuttle instance (first argument)
self = args[0]

# Check if logging should be skipped based on conditional parameter
if conditional_param:
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
if not bound_args.arguments.get(conditional_param, True):
# Skip logging - just run the function
return func(*args, **kwargs)

# Determine command name
log_command_name = (
command_name
if command_name
else func.__name__.replace("_", "-")
)

# Capture local variables for logging
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
local_vars = dict(bound_args.arguments)

# Start logging
self._start_log(
log_command_name,
local_vars=local_vars,
store_in_temp_folder=store_in_temp_folder,
)

try:
# Execute the function
result = func(*args, **kwargs)
return result
finally:
# Always close logging, even if exception occurs
ds_logger.close_log_filehandler()

return wrapper

return decorator